The following script took more than 1 minute to compile on my system (with the newest versions of jax & jaxlib). A very strange phenomenon: the compile time drops to about 2s after some seemingly innocuous changes to the body function (e.g. replacing 0.5 * level by level).
import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(s):
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = np.clip(exp_val, a_min=0)
        level = y_t / s[0] + 0.5 * level
        s = y_t / level + 0.5 * s
        return (level, s), exp_val
    return lax.scan(scan_fn, (1., s), np.ones(80))[1].sum()

jit(grad(f))(np.ones(2))
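For clarity, "compile time" here means the time of the first call, since jit compiles lazily on first invocation. A minimal sketch for timing it (assuming the f defined above; block_until_ready just forces JAX's asynchronous result to finish):

import time

x = np.ones(2)
t0 = time.time()
jit(grad(f))(x).block_until_ready()  # first call: trace f, stage out the XLA graph, compile, run
print("trace + compile + first run:", time.time() - t0, "s")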
This feels pretty strange to me. Just running make_jaxpr(grad(f))(np.ones(2)) also exhibits the 2min+ computation time, which I believe is mostly spent in LLVM (it seems to be caused by an enormous HLO graph that gets compiled somewhere during grad-of-scan; I haven't looked into why yet).
The jaxpr for grad(f) itself is pretty small and looks fine, though.
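A rough sketch for comparing the two representations yourself (this uses jax.xla_computation, the usual way to look at the staged-out HLO in this era of JAX; whether the returned object exposes as_hlo_text() depends on the jax/jaxlib version):

from jax import make_jaxpr, xla_computation

print(make_jaxpr(grad(f))(np.ones(2)))                   # the jaxpr: compact and readable
hlo_text = xla_computation(grad(f))(np.ones(2)).as_hlo_text()
print(len(hlo_text.splitlines()), "lines of HLO text")   # the staged-out HLO, reportedly enormous here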
This looks like it's either an LLVM bug or an XLA bug. The computation isn't all that large, but lots of time is being spent in the LLVM SelectionDAG logic. I filed an (internal) bug for the XLA team.
Thanks @jekbradbury and @hawkinsp for looking into this issue!
I agree that the computation isn't that large in absolute terms, but it's still surprising to me (the computation that takes a lot of time in LLVM looks vaguely unrolled at the HLO level, or something like that?)
We can probably hope for XLA/LLVM to get faster, but also we can improve the kind of graph we're staging out.
I noticed (because @hawkinsp pointed it out in related circumstances) that we're adding a lot of scalar literals to the loop-carry tuple, when instead those could be staged into the XLA computation as literals. @dougalm and I recently added some logic to instantiate scalar literals as XLA literals rather than hoisting them into the loop carry (in d27bc0a, part of #704), but we conservatively only switched that on for the Python types int and float. In particular, DeviceArray constants still got hoisted.
I noticed that this computation was doing a good amount of scalar-hoisting, and so in #780 (specifically 9c931dd) I sketched out some logic that allows more types to be treated as literals in jaxprs (and hence in the staged-out XLA computations). That seems to make the compile time for the OP's code essentially instantaneous.
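To see the difference at the jaxpr level, here is a toy sketch (my own illustration, not code from #780) that closes over a scalar constant inside a scanned body. Printing the two jaxprs shows whether the constant appears inline as a literal or gets hoisted as a constant input to the scan, which is exactly the Python-float vs DeviceArray distinction described above (the details depend on the JAX version):

import jax.numpy as np
from jax import lax, make_jaxpr

def g(x, c):
    def body(carry, _):
        return carry * c, carry      # `c` is a constant closed over by the scanned body
    return lax.scan(body, x, np.ones(3))[0]

print(make_jaxpr(lambda x: g(x, 0.5))(np.float32(1.0)))            # Python float constant
print(make_jaxpr(lambda x: g(x, np.array(0.5)))(np.float32(1.0)))  # DeviceArray constant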
I want to look over that code with fresh eyes tomorrow, but I'm optimistic it (or something like it) will handle this scan compilation time issue and a lot of related ones.
@fehiepsi can you sync past #780 and verify that it solved this issue for you? If not, let's reopen and keep exploring!
Thanks so much, @mattjj! #780 has solved the compilation issue for the script in this topic. However, I am now getting a shape error in the original script (after the update), so I can't test that one yet. I'll isolate the error and get back to you.
@mattjj Here is a repro code to trigger that error:
import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(init_s):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)
        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    y = np.ones(80)
    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()

f(np.ones(38))
jit(f)(np.ones(38))
grad(f)(np.ones(38))
jit(grad(f))(np.ones(38))  # TypeError: (ShapedArray(float32[]), ())
The good news is that your fix runs well for many other scripts. :D I think it is just a typo somewhere in that PR.
This script triggers an error in the forward pass:
def f(init_s):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)
        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        # level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        # level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    y = np.ones(80)
    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()

f(np.ones(38))
jit(f)(np.ones(38))  # TypeError
Thanks for the repros! I think this is exercising a bug that I noticed while working on #780 (but which, as far as I know, wasn't caused by #780... though maybe you have evidence to the contrary!). I'm going to open a separate issue.