It would be nice to have a random weighted choice in jax. We need this to sample bitstrings from a quantum wave function.
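For context, sampling bitstrings from a wave function means drawing basis states with Born-rule probabilities |psi_i|^2. Below is a minimal sketch of that use case. It uses the weighted `jax.random.choice` (the `p` argument) available in current JAX releases. `sample_bitstrings` is a hypothetical helper, and it assumes amplitude index `i` encodes its bitstring with the most significant bit first.

```python
import jax
import jax.numpy as jnp

def sample_bitstrings(key, psi, num_samples):
    """Draw bitstrings with probability |psi_i|^2 (Born rule).

    Hypothetical helper. Assumes psi has length 2**n_qubits and that index i
    encodes its bitstring with the most significant bit first.
    """
    n_qubits = int(psi.shape[0]).bit_length() - 1
    probs = jnp.abs(psi) ** 2
    probs = probs / probs.sum()  # guard against small normalization drift
    # Weighted sampling with replacement over basis-state indices.
    idx = jax.random.choice(key, psi.shape[0], shape=(num_samples,), p=probs)
    # Decode each index into its bits, most significant bit first.
    shifts = jnp.arange(n_qubits - 1, -1, -1)
    return (idx[:, None] >> shifts) & 1
```

For example, a state with all amplitude on index 2 of a two-qubit register always yields the bitstring `[1, 0]`.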
I ran into the same issue; I got around it by implementing the function described in this blog post. Ideally jax.random.choice would be implemented, but I suppose it's tricky to match the API of np.random.choice. Perhaps jax.random could offer a simpler function for sampling without replacement?
Edit: It seems #2066 discusses similar issues around matching the NumPy version.
This is a significant missing feature. @Thenerdstation, I'm curious whether you also need sampling without replacement, or whether a single sample or sampling with replacement (both much easier) is enough for you.
Sampling with replacement is what we need for our use case.
Seconded.
Should this take a probability vector p? categorical takes a logits vector. (Was that for performance reasons?)
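The two parameterizations carry the same information: `jax.random.categorical` takes unnormalized log-probabilities, so passing `jnp.log(p)` samples from the probability vector `p`. Zero-probability entries become `-inf` logits and are never drawn. A minimal sketch of that conversion follows. `choice_from_probs` is a hypothetical name, and this says nothing about why categorical was designed around logits.

```python
import jax
import jax.numpy as jnp

def choice_from_probs(key, p, num_samples):
    """Sample indices with replacement from a probability vector p by
    converting it to the logits that jax.random.categorical expects.
    (Hypothetical helper, for illustration.)"""
    logits = jnp.log(p)  # zero-probability entries become -inf and are never drawn
    return jax.random.categorical(key, logits, shape=(num_samples,))
```

Because categorical adds Gumbel noise to the logits and takes an argmax, working in log space also avoids renormalizing `p`.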
I'm guessing the API would be something like:
jax.random.choice(key, x, size=None, replace=True, p=None, axis=0)
Potentially relevant: permutation was requested in https://github.com/google/jax/issues/1526, with a PR pending at https://github.com/google/jax/pull/1568.
I have also needed this several times, in both the with- and without-replacement cases.
I wrote an implementation of unweighted random.choice in #3463.
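The unweighted case reduces to primitives jax.random already has. Here is a minimal sketch, not necessarily how #3463 does it. With replacement, each draw is an independent `randint` index. Without replacement, a random permutation truncated to `k` gives `k` distinct indices. `unweighted_choice` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def unweighted_choice(key, a, k, replace=True):
    """Draw k elements uniformly from the 1-D array a. (Hypothetical helper.)"""
    n = a.shape[0]
    if replace:
        # Each draw is an independent uniform index in [0, n).
        idx = jax.random.randint(key, (k,), 0, n)
    else:
        if k > n:
            raise ValueError("cannot take more samples than elements without replacement")
        # A uniformly random permutation truncated to k gives k distinct indices.
        idx = jax.random.permutation(key, n)[:k]
    return a[idx]
```

The permutation approach costs O(n) even when `k` is small, which is usually acceptable on accelerators.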
Weighted will take a bit more thought. Weighted sampling with replacement should be fairly straightforward: draw uniform values and use searchsorted to locate them within the cumulative weights. I'm not sure how best to implement weighted sampling without replacement.
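Here is a minimal sketch of the searchsorted idea for weighted sampling with replacement. Index `i` owns the interval `(cdf[i-1], cdf[i]]` of the cumulative weights, so a uniform draw lands in it with probability `w[i] / sum(w)`, and zero-weight entries own an empty interval. `weighted_choice_with_replacement` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def weighted_choice_with_replacement(key, w, num_samples):
    """Sample indices i with probability w[i] / sum(w), with replacement.
    (Hypothetical helper.)"""
    cdf = jnp.cumsum(w)  # unnormalized cumulative weights
    # Uniform draws in (0, total]; using 1 - u avoids r == 0, which would
    # select index 0 even when w[0] == 0.
    u = jax.random.uniform(key, (num_samples,))
    r = cdf[-1] * (1.0 - u)
    # side="left" (the default) returns the first index with cdf[i] >= r.
    return jnp.searchsorted(cdf, r)
```

The weights do not need to be normalized, because the draws are scaled by the total weight `cdf[-1]`.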
Thanks @jacobjinkelly – that's really cool!
Can we merge the two solutions, so that we have choice both with and without replacement, and both weighted and unweighted?
Yes, that’s my plan.
OK, #3463 now implements random.choice with or without replacement, and with or without weights.
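For reference, this is how the merged function is called in current JAX releases. Note that the keyword is `shape` rather than NumPy's `size`. The arrays and keys below are illustrative only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jnp.arange(5)
p = jnp.array([0.1, 0.2, 0.3, 0.4, 0.0])  # last element can never be drawn

a = jax.random.choice(k1, x, shape=(8,))                       # uniform, with replacement
b = jax.random.choice(k2, x, shape=(3,), replace=False)        # uniform, without replacement
c = jax.random.choice(k3, x, shape=(3,), replace=False, p=p)   # weighted, without replacement
```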
@jakevdp I'm not sure if this is the best way, but based on this blog post I implemented weighted sampling without replacement (swor) here.
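One common approach to weighted sampling without replacement, not necessarily the one in the linked post, is the Gumbel-top-k trick. You perturb each log-weight with independent Gumbel noise and keep the `k` largest. This has the same distribution as drawing one item at a time in proportion to its weight and removing it from the pool. Here is a minimal sketch. `weighted_choice_without_replacement` is a hypothetical name, and `k` must be a static Python int for `jax.lax.top_k`.

```python
import jax
import jax.numpy as jnp

def weighted_choice_without_replacement(key, w, k):
    """Draw k distinct indices, successively proportional to the weights w,
    via the Gumbel-top-k trick. (Hypothetical helper.)"""
    perturbed = jnp.log(w) + jax.random.gumbel(key, w.shape)
    # Zero-weight items get -inf scores, so they are only picked if k
    # exceeds the number of nonzero weights.
    _, idx = jax.lax.top_k(perturbed, k)
    return idx
```

Every step is vectorized and uses a single random draw per element, so it composes with `jax.jit` and `jax.vmap`.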