With jax versions > 0.1.24, random samplers are slow. I suspect this is due to the recent changes to how jit handles static args/kwargs. A script to reproduce:
import time
from jax import random

t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))  # first call: compiles
print(time.time() - t)

t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))  # second call: still slow
print(time.time() - t)
which prints 0.12923526763916016 and 0.12221717834472656 — the second call is as slow as the first.
However, if we wrap the sampler in a function, subsequent calls are fast. For example,
def f():
    return random.normal(random.PRNGKey(0), shape=(1,))
t = time.time()
f()
print(time.time() - t)
t = time.time()
f()
print(time.time() - t)
will return 0.12787413597106934 and 0.0010831356048583984.
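(As explained below, the difference comes down to object identity: in CPython, a tuple literal inside a function body is a compile-time constant, so every call to `f` passes the very same `(1,)` object. A hypothetical illustration, with a stand-in `shape` function:)

```python
# In CPython, a constant tuple literal in a function body is stored once
# in the code object's constants, so each call returns the same object.
def shape():
    return (1,)

s1, s2 = shape(), shape()
print(s1 is s2)  # True: the identical constant object on every call
```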
I think there is a small bug somewhere that forces the sampler to recompile. If this is the expected behaviour, then which function should we use to wrap these samplers so that they are not recompiled?
cc @neerajprad
Thanks so much for catching this! We will fix it asap.
(Longer term we aim to set up a continuous benchmarking solution to avoid regressions.)
This was an issue with all uses of static_argnums, though it only cropped up recently in random.py because of how we changed random.py following the kwargs-handling change. #692 should fix it.
The use of static_argnums was only resulting in compilation cache hits based on _object identity_ equivalence. That meant that random.py functions were only caching on object identity equivalence (i.e. x is y instead of equality x == y) of their shape parameters. In your first code example, which passed in (1,) the first time and a fresh literal (1,) the second time, those two objects didn't have the same identity (though they would have compared equal). In your second code example, the same literal was used twice and so the cache was being hit on object identity.
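(The distinction is easy to see in plain Python. Building two equal tuples at runtime gives objects that compare equal but fail an identity check, which is exactly the situation the identity-based cache misses on:)

```python
# Two tuples built at runtime compare equal but are distinct objects,
# so a cache keyed on identity (x is y) misses even though x == y.
a = tuple([1])
b = tuple([1])
print(a == b)   # True: equal values
print(a is b)   # False: different objects
```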
The fix in #692 is just to make static_argnums use equality checks when possible (i.e. when argument objects corresponding to static argnums have __hash__ and, by implication, __eq__), and fall back to object identity equivalence when not. This should still allow arbitrary objects to be passed as static args (i.e. even unhashable ones), but also let us get cache hits where appropriate, especially for shape args like in random.py.
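(A minimal sketch of that cache-key strategy — not the actual #692 implementation, and `static_arg_key` is a hypothetical name — keying on the value when the argument is hashable and falling back to identity otherwise:)

```python
# Hypothetical sketch: derive a cache key for a static argument.
def static_arg_key(arg):
    try:
        hash(arg)               # hashable implies a usable __eq__/__hash__
        return ("value", arg)   # key on the value -> equal args hit the cache
    except TypeError:
        return ("id", id(arg))  # unhashable -> fall back to object identity

# Equal shape tuples now produce equal cache keys,
assert static_arg_key((1,)) == static_arg_key(tuple([1]))
# while an unhashable object is still cached on its identity.
lst = [1, 2]
assert static_arg_key(lst) == static_arg_key(lst)
```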
Does that make sense? What do you think?
I can't thank you enough for spotting this and providing such a clear repro. Your sleuthing made this an easy fix, and without it we would have missed it for who knows how long!
Can you verify the fix on your end, and let me know what you see?
(I just pushed 0.1.28 out to pypi with this fix. Hopefully it actually worked!)
Whoa, it is so fast. Thank you so much for the very clear explanation! On my end, the samplers are fast again now. :)
Thanks for such a quick resolution and for cutting a new release, @mattjj! :)
JAX bugs fixed in 30 minutes or less, or your money* back!
*_JAX costs $0_