Currently the API of jax.random is quite different from that of np.random. For example:
jax.random.normal(key, shape, dtype=onp.float32)
numpy.random.normal(loc=0.0, scale=1.0, size=None)
Should we make them the same in the future? For JAX, I think we could assign a default value to the random number seed.
Thanks for asking this!
In general we like to introduce as few APIs as possible, and we do that by sticking close to NumPy. But the np.random API is different because it either uses global state and side effects (as in the example you wrote) or requires users to thread RandomState instances through their code to ensure reproducibility. Its default PRNG algorithms are also not amenable to vectorization, meaning they'd be much slower when executed on a GPU or TPU, and using different algorithms on different backends would harm reproducibility.
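To make the global-state point concrete, here is a minimal sketch of the two NumPy styles mentioned above. The helper name `sample` is only for illustration and is not part of NumPy.

```python
import numpy as np

# Global-state style: np.random.seed mutates a hidden module-level generator,
# so every call has the side effect of advancing that shared state.
np.random.seed(0)
a = np.random.normal(size=3)
b = np.random.normal(size=3)  # differs from a because the hidden state moved on

# Explicit-state style: a RandomState instance is threaded through the code.
def sample(rng, n):
    """Hypothetical helper: draw n standard normal values from rng."""
    return rng.normal(size=n)

rng = np.random.RandomState(0)
c = sample(rng, 3)  # same stream as the seeded global generator above
d = sample(rng, 3)
```

In both styles the result of a call depends on how many draws came before it, which is the kind of implicit ordering that is hard to reproduce once code runs in parallel.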
For those reasons, and the fact that parallel random numbers can be as easy as 1, 2, 3, we wrote our own PRNG system, which requires a slightly different API than numpy.random. To provide a more detailed explanation, we just uploaded a design doc that explains our reasoning. Feedback is always welcome!
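As a rough illustration of what the key-based API looks like in practice, here is a small sketch using `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.normal`. The seed value and shapes are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

# The PRNG state is an explicit key rather than hidden global state.
key = jax.random.PRNGKey(0)

# Sampling is a pure function of the key: the same key gives the same values.
x = jax.random.normal(key, (3,))
y = jax.random.normal(key, (3,))

# To get fresh values, split the key and sample with the new subkey.
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (3,))
```

Because nothing is mutated behind the scenes, reusing a key reproduces a sample exactly, and independent streams come from splitting rather than from call order.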
I suspect that our PRNG design will always need a slightly different API than numpy.random, but there are likely several ways we can improve it: making it more familiar to NumPy users, adding checks for some easily made mistakes (like #192), and writing more documentation. If you have specific ideas for how to bring our PRNG API closer to NumPy's (while still fitting our underlying design), or improve it in other ways, we'd love to hear them!
If this answers your question, let's close this issue, but discuss specific API ideas as they come up in new threads or pull requests.
Sure, now I see the problem. Thanks for that!