Repro code:
```python
import time

from jax import jit, random
from jax.config import config

config.update("jax_platform_name", "cpu")

def f(key, i):
    for _ in range(i):
        _, key = random.split(key)
    return key

for i in range(17):
    t = time.time()
    key = jit(f, static_argnums=(1,))(random.PRNGKey(0), i).copy()
    print("split {} times takes {}".format(i, time.time() - t))
```
which returns
```
split 0 times takes 0.008060216903686523
split 1 times takes 0.05638694763183594
split 2 times takes 0.13771629333496094
split 3 times takes 0.284271240234375
split 4 times takes 0.4720790386199951
split 5 times takes 0.8609561920166016
split 6 times takes 1.5761287212371826
split 7 times takes 2.615913152694702
split 8 times takes 4.20055079460144
split 9 times takes 6.581374168395996
split 10 times takes 9.83530855178833
split 11 times takes 15.45705771446228
split 12 times takes 20.33790898323059
split 13 times takes 27.418986320495605
split 14 times takes 35.04498267173767
split 15 times takes 44.33582162857056
split 16 times takes 53.186384439468384
```
@mattjj I think this is what caused the regression of the gamma sampler in our last discussion. This issue is a big problem when compiling models with many random latent variables in NumPyro.
Thanks for opening this. We'll look into it.
Can you send me a master list of JAX bugs that are blocking or otherwise inconveniencing your work? I want to prioritize them, but I fear that I've lost track of the most important ones.
This isn't a great workaround, but just to make sure you're aware: if you make just one `split` call, things are faster, i.e. `all_keys = random.split(key, 16)`.
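For example, a minimal sketch of that workaround (the count of 16 here is just for illustration):

```python
from jax import random

key = random.PRNGKey(0)

# One split call produces all 16 subkeys at once, instead of
# tracing the underlying hash function 16 separate times.
all_keys = random.split(key, 16)  # shape (16, 2): one row per subkey

# Index individual subkeys as needed.
first_key = all_keys[0]
```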
I looked at a profile and found that all the time is being spent in the XLA compiler. Each call to `random.split` is inlining a call to the hash function underlying `random.split`, and that inlining means that we're building large XLA graphs (and XLA compile times can be super-linear in the input size). There may be things we can do on the JAX side, and I also pinged the XLA folks for ideas.
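One possible user-side mitigation (just a sketch, not a committed fix) is to roll the loop with `lax.scan`, so the split body is traced once regardless of the iteration count and the XLA graph stays small:

```python
from jax import jit, lax, random

def f_scan(key, n):
    # lax.scan traces the body a single time, so the hash function
    # is inlined once rather than n times in the XLA graph.
    def body(k, _):
        _unused, k = random.split(k)  # same per-step update as the repro's f
        return k, None
    key, _ = lax.scan(body, key, None, length=n)
    return key

key = jit(f_scan, static_argnums=(1,))(random.PRNGKey(0), 16)
```

Note this makes the trip count a runtime loop in XLA rather than unrolled Python, which is exactly what keeps compile time flat as `n` grows.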
Thanks Matt, that's a great idea! We'll think more about whether we can go with that solution. cc @neerajprad
> Can you send me a master list of JAX bugs that are blocking or otherwise inconveniencing your work?
Currently, besides this issue, we only have some problems with fastmath mode, which makes it impossible (due to NaN issues) for us to do stochastic/MCMC inference on some models involving e.g. `logsumexp`. I'll sort out some repro code and make a separate issue for it soon. Other than that, JAX <3 works beautifully for us.