It would be nice if we could get jax.random functions analogous to np.random.permutation and np.random.shuffle.
There is already a jax.random.shuffle. IIUC, a jax.random.permutation could be a simple wrapper around it.
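A rough sketch of what such a `permutation` helper might look like follows. It is hypothetical, not an existing or proposed JAX API. Rather than wrapping `random.shuffle` as suggested above, it swaps in a different technique: draw one uniform random number per row with `jax.random.uniform` and reorder the rows by `jnp.argsort` of those numbers. Like `np.random.permutation`, it accepts either an integer `n` (returning a shuffled `arange(n)`) or an array (shuffling along axis 0). Unlike `np.random.shuffle`, it cannot work in place, because JAX arrays are immutable, so it returns a new array.

```python
import jax
import jax.numpy as jnp


def permutation(key, x):
    """Hypothetical helper mirroring np.random.permutation (not a JAX API).

    If x is an int n, returns a random ordering of jnp.arange(n).
    Otherwise returns a copy of x with its rows (axis 0) shuffled.
    """
    if isinstance(x, int):
        x = jnp.arange(x)
    x = jnp.asarray(x)
    # One uniform draw per row; sorting these draws gives a random ordering.
    # (Exact ties between floats are possible but vanishingly unlikely.)
    sort_keys = jax.random.uniform(key, (x.shape[0],))
    perm = jnp.argsort(sort_keys)
    # Index with the permutation; this builds a new array, x is unchanged.
    return x[perm]


key = jax.random.PRNGKey(0)
print(permutation(key, 5))
print(permutation(key, jnp.array([[1, 2], [3, 4], [5, 6]])))
```

As with other `jax.random` functions, the result is fully determined by `key`: calling it twice with the same key gives the same ordering, so you would split the key to get independent shuffles.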
I'd be happy to take a look at this.