JAX: compile time grows non-linearly w.r.t. the number of random.split ops

Created on 13 Aug 2019 · 3 comments · Source: google/jax

A repro:

from jax import jit, random
from jax.config import config; config.update("jax_platform_name", "cpu")
import time

def f(key, i):
    # Unroll i sequential splits; i is static, so each new value of i
    # triggers a fresh trace and XLA compilation.
    for _ in range(i):
        _, key = random.split(key)
    return key

for i in range(17):
    t = time.time()
    # .copy() pulls the result back to the host, so the timing covers
    # both compilation and execution.
    key = jit(f, static_argnums=(1,))(random.PRNGKey(0), i).copy()
    print("split {} times takes {}".format(i, time.time() - t))

which prints

split 0 times takes 0.008060216903686523
split 1 times takes 0.05638694763183594
split 2 times takes 0.13771629333496094
split 3 times takes 0.284271240234375
split 4 times takes 0.4720790386199951
split 5 times takes 0.8609561920166016
split 6 times takes 1.5761287212371826
split 7 times takes 2.615913152694702
split 8 times takes 4.20055079460144
split 9 times takes 6.581374168395996
split 10 times takes 9.83530855178833
split 11 times takes 15.45705771446228
split 12 times takes 20.33790898323059
split 13 times takes 27.418986320495605
split 14 times takes 35.04498267173767
split 15 times takes 44.33582162857056
split 16 times takes 53.186384439468384
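
As a sanity check that this time is dominated by compilation rather than execution, here is a minimal sketch reusing f from above: jit caches compiled programs per static-argument value, so a second call with the same i should be fast.

g = jit(f, static_argnums=(1,))
t = time.time()
g(random.PRNGKey(0), 16).copy()
print("first call (compiles):", time.time() - t)

t = time.time()
g(random.PRNGKey(0), 16).copy()
print("second call (cache hit):", time.time() - t)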

@mattjj I think this is what caused the gamma-sampler regression from our last discussion. This issue is a big problem when compiling models with many random latent variables in NumPyro.

Labels: bug, performance

All 3 comments

Thanks for opening this. We'll look into it.

Can you send me a master list of JAX bugs that are blocking or otherwise inconveniencing your work? I want to prioritize them, but I fear that I've lost track of the most important ones.

This isn't a great workaround, but just to make sure you're aware: things are faster if you make just one split call, e.g. all_keys = random.split(key, 16).
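
A minimal sketch of that pattern (the names all_keys and subkeys are just for illustration):

from jax import random

key = random.PRNGKey(0)
# One split op yields all 16 subkeys at once, so the underlying hash
# appears in the traced program only once instead of 16 times.
all_keys = random.split(key, 16)
subkeys = list(all_keys)  # use subkeys[0], subkeys[1], ... where needed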

I looked at a profile and found that all the time is being spent in the XLA compiler. Each call to random.split inlines the hash function underlying it, and that inlining means we're building large XLA graphs (and XLA compile times can be super-linear in the input size). There may be things we can do on the JAX side, and I've also pinged the XLA folks for ideas.
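
For anyone who wants to see the graph growth directly, here is a rough sketch using jax.make_jaxpr (its static_argnums parameter and the exact equation counts depend on the JAX version; newer versions wrap the hash in a single primitive, whereas at the time it was inlined as many ops):

import jax
from jax import random

def f(key, i):
    for _ in range(i):
        _, key = random.split(key)
    return key

# The traced program grows with i, and XLA compile time can then grow
# super-linearly in the program size.
for i in (1, 4, 8):
    closed = jax.make_jaxpr(f, static_argnums=(1,))(random.PRNGKey(0), i)
    print(i, "eqns:", len(closed.jaxpr.eqns))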

Thanks Matt, that's a great idea! We'll think more about whether we can go with that solution. cc @neerajprad

"Can you send me a master list of JAX bugs that are blocking or otherwise inconveniencing your work?"

Currently, besides this issue, we only have some problems with fastmath mode, which makes it impossible (due to NaN issues) for us to do stochastic/MCMC inference on some models involving e.g. logsumexp. I'll sort out some repro code and file a separate issue for it soon. Other than that, JAX <3 works beautifully for us.
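
For context on the logsumexp point: the usual stable formulation subtracts the max before exponentiating, and the NaN shows up when fastmath-style compiler assumptions interact badly with tricks like this. A minimal illustrative sketch (stable_logsumexp is a hypothetical name for this note; jax.scipy.special.logsumexp is the library version):

import jax.numpy as jnp

def stable_logsumexp(x):
    # Subtract the max so exp never overflows for large inputs.
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))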
