Thank you for your work; this is a very promising project. Recently, however, I tried to use it for a gradient calculation and sadly concluded that I cannot build my project on top of it, because in its current state it is just extremely slow.
I created a very detailed, self-contained Jupyter Notebook with all the steps required to reproduce the problem; you can check it here. I tested the performance of the gradient calculation of the same function across different libraries: JAX and PyTorch. The function is a simplification of the real one, and no external data is used; all data is generated at runtime.
I was faced with two issues:
Thank you for your report. In your example,
for i in range(1, t):
    f = ops.index_update(f, i, x[i] + f[i-1])
is unrolled by JAX when jitting, which explains why the jit time increases with t. Did you specifically intend to profile a sequence of index_update calls? If you instead want to keep the loop structure, you'd have to use a lax.scan operation:
f, _ = lax.scan(lambda f, i: (ops.index_update(f, i, x[i] + f[i-1]), None), f, np.arange(1, t))
I have not looked into the grad cost issue, but I am guessing that after the unrolling of the update operations, the compiler cannot quite understand the update pattern of the f array and then it ends up keeping multiple copies of f. I believe that you'd see different performance with the scan version.
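For concreteness, here's a minimal self-contained sketch of what that scan-based version might look like. The initial condition and the final reduction are placeholders (the full function from your notebook isn't reproduced here), so adjust those to match your real code:
import jax.numpy as np
from jax import lax, ops

def func_scan(x):
    t = x.shape[0]
    f = np.zeros(t)
    f = ops.index_update(f, 0, x[0])  # placeholder initial condition

    def body(f, i):
        # the same update as the Python loop, but kept rolled up by lax.scan
        return ops.index_update(f, i, x[i] + f[i - 1]), None

    f, _ = lax.scan(body, f, np.arange(1, t))
    return np.sum(f)  # placeholder reduction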
Thanks for bringing this up!
To reiterate what @gnecula said: for the jit compilation time, the fact that JAX's tracing effectively unrolls the loop, paired with XLA's compile-time scaling, causes the problem here. (Because XLA does a lot of optimization, its compilation time scales super-linearly with the input program size, especially on CPU, where XLA is by far the least optimized.) Without jit, this program makes a lot of copies: I'd expect one full copy of f for every call to ops.index_update, since while those updates become in-place updates under a jit, without jit they'll be real copies.
In terms of getting a good micro-benchmark measurement, you should add a .block_until_ready() so you're not just timing dispatch time (due to asynchronous dispatch). I'm not sure if PyTorch needs something similar. Moreover, by passing -r 1 -n 1 you might only be measuring compilation time.
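Concretely, the measurement pattern I have in mind looks something like this (a sketch with a stand-in function, not the one from your notebook): compile once up front, then time only steady-state calls, blocking on the result each time.
import jax.numpy as np
from jax import jit, grad

def f(x):
    return np.sum(x ** 2)  # stand-in for the real function

g = jit(grad(f))
g(np.ones(10)).block_until_ready()            # warm-up call pays tracing + compilation
# %timeit g(np.ones(10)).block_until_ready()  # steady-state timing: dispatch + compute only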
You can compute the same thing much faster by writing the program differently, though I suspect that if the original program is a simplified model of a workload you actually care about, this particular rewrite won't be practically useful (see below for another rewrite):
def func_jax2(x):
    return np.sum(np.cumsum(x[1:]))
Here's the timing I see with that (which, again, may just be academic if your real use case needs loops):
In [1]: import jax.numpy as np
In [2]: def func_jax2(x):
   ...:     return np.sum(np.cumsum(x[1:]))
   ...:
In [3]: from jax import grad
In [4]: timeit grad(func_jax2)(np.ones(10)).block_until_ready()
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:120: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
1.88 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [5]: timeit grad(func_jax2)(np.ones(20)).block_until_ready()
1.84 ms ± 54.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [6]: timeit grad(func_jax2)(np.ones(30)).block_until_ready()
1.84 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [7]: from jax import jit
In [8]: @jit
   ...: def func_jax2(x):
   ...:     return np.sum(np.cumsum(x[1:]))
   ...:
In [9]: timeit grad(func_jax2)(np.ones(30)).block_until_ready()
717 µs ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [10]: timeit grad(func_jax2)(np.ones(100)).block_until_ready()
709 µs ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
The work here is still O(n), where n is the length of the input array, but the program size is O(1).
It seems likely that in your real use case you might need a loop rather than just being able to use np.cumsum as in this toy model. Here's a way we can write this toy computation without using np.cumsum, but still avoiding the ops.index_update calls which are likely causing copies when outside a jit:
def func_jax3(x):
    t = len(x)
    f = 0.
    tot = 0.
    for i in range(1, t):
        f = f + x[i]
        tot = tot + f
    return tot
Here, the timings for n=10, 20, 30 are 14.7ms, 32.2ms, 54.3ms on my machine. That's slow, but not 100x! We could probably get that down to similar numbers as PyTorch by working on JAX's Python overheads, but we haven't spent nearly as much time optimizing those overheads because for many workloads users just rely on the jit sledgehammer.
If we want to use jit but avoid the long compile times for loop-based code, we'll need to use lax.scan to keep the loop rolled as @gnecula mentioned. That's less convenient than using native Python loops, but sometimes jit demands concessions. Here's one way to write it:
from jax import lax

@jit
def func_jax4(x):
    def body(c, x_i):
        return c + x_i, c + x_i
    _, f = lax.scan(body, 0., x)
    return np.sum(f)
In [47]: timeit grad(func_jax4)(np.ones(10)).block_until_ready()
672 µs ± 34.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [48]: timeit grad(func_jax4)(np.ones(20)).block_until_ready()
682 µs ± 35.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [49]: timeit grad(func_jax4)(np.ones(30)).block_until_ready()
710 µs ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
To summarize my current understanding, JAX didn't work "out of the box" for you here because the program's use of ops.index_update meant falling off a performance cliff without jit due to extra copies, and with jit the program's unrolled loop meant compile times would get bad. By rewriting the code we're able to get great performance, but that meant either relying on our ability to write our program in terms of NumPy functions (func_jax2), or else using lax.scan together with jit (func_jax4) which is more awkward than writing native Python loops.
More generally, JAX isn't going to be faster on all microbenchmarks. JAX and XLA are optimized for certain workloads, and even certain ways of expressing programs. It's a lot like NumPy itself: if your program needs Python loops over scalar computations, or is otherwise hard to express, it's probably not in the performance sweet spot. That's why NumPy has grown "escape hatches" like numba and cython. (As an example of what's _in_ the sweet spot, aside from just neural networks, take a look at NumPyro; the paper compares against PyTorch, TF, and Stan.)
What do you think? What'd I miss?
I forgot to mention: jit-of-grad will be faster than grad-of-jit, since the more you can put under a jit, the better!
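Roughly, the difference looks like this (a sketch using the same scan-based computation as func_jax4, just without the decorator so the two orderings are easy to compare):
from jax import jit, grad, lax
import jax.numpy as np

def f(x):
    def body(c, x_i):
        return c + x_i, c + x_i
    _, ys = lax.scan(body, 0., x)
    return np.sum(ys)

grad_of_jit = grad(jit(f))  # less of the computation ends up under the jit
jit_of_grad = jit(grad(f))  # the whole forward + backward computation is under the jit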
@gnecula, @mattjj, thank you for your detailed answers. I made the suggested changes in the notebook and they help a lot. Now the execution time for a 10000-element array is almost the same as for a 10-element array. It's just amazing!
What do you think: would it be possible for the compiler to figure out that the multiple copies of the f array are unnecessary if they are not used anywhere else inside the function and are not accessible outside of it?
One other recommendation: when you use jit, you should probably do it once outside the timing loop. Otherwise the compilation cache is getting cleared and you're paying the recompilation cost each time. (You can see when you're getting cache hits by setting the environment variable JAX_LOG_COMPILES=1, or otherwise setting the config option using from jax.config import config; config.update("jax_log_compiles", 1).)
That is, when you write %timeit jit(grad(...)) ... you're probably doing a lot of recompilation.
I'm thinking of something like this (also with a tweaked version of func_jax, just because it's interesting; it computes the same thing in the same way):
from jax import jit, grad
import jax.numpy as np
from jax import lax, ops
def func_jax(x):
    def body(f_prev, x_i):
        return f_prev + x_i, f_prev + x_i
    _, f = lax.scan(body, 0., x)
    return np.sum(f)

gradfun = jit(grad(func_jax))
Then I run timeit gradfun(np.ones(10000)) and get compilation cache hits (and faster timings).
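Continuing that snippet, here's a sketch of how you could confirm the cache behaviour with the logging option mentioned above:
from jax.config import config
config.update("jax_log_compiles", 1)          # log a message on every compilation cache miss

gradfun(np.ones(10000)).block_until_ready()   # first call: expect one compile log line
gradfun(np.ones(10000)).block_until_ready()   # second call: cache hit, no log line
# timeit gradfun(np.ones(10000)).block_until_ready()  # steady-state timing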
What do you think: would it be possible for the compiler to figure out that the multiple copies of the f array are unnecessary if they are not used anywhere else inside the function and are not accessible outside of it?
XLA does exactly that sort of optimization. Actually, since the XLA programming model is functionally pure, XLA programs only deal with values (rather than dealing with buffers explicitly). Given a program, XLA attempts to minimize the cost of evaluating that program, and the cost includes factoring in the number and size of the buffers it needs. (It also factors in how expensive things are to compute, since in some cases it can be worth paying some extra memory to save compute, or trading them off in some way.) Part of that cost optimization means noticing that when a value isn't used elsewhere it can generate in-place updates to the underlying buffer.
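To make that concrete, here's a small sketch (the buffer reuse happens inside XLA and isn't directly visible from Python, but this is the kind of program it applies to): a chain of index_update calls whose intermediate results are never read again can be lowered to writes into a single buffer under jit, whereas run eagerly each call makes a fresh copy.
import jax.numpy as np
from jax import jit, ops

def fill(x):
    # Semantically each index_update returns a new array; under jit, XLA is free
    # to reuse one output buffer because no intermediate value is ever read again.
    for i in range(x.shape[0]):
        x = ops.index_update(x, i, 1.0)
    return x

fill_jit = jit(fill)
fill_jit(np.zeros(8))  # same result as the eager fill(np.zeros(8)), without the copies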
I'm not totally sure how much optimization the XLA:CPU backend does right now; as I mentioned before, it's the least developed of the backends, with XLA:GPU being pretty sophisticated and XLA:TPU being absolutely incredible. It's a very smart compiler!
Thanks for raising this! Should we close the issue now that things are resolved?
Some of our JAX code takes anywhere from several seconds to half a minute to jit-compile. In my opinion, long jit times spoil the user experience of JAX for more complex code.
Just a random idea: I wonder if it would be possible to cache jit results, so you can load the jit precompiled code, similar to precompilation of CUDA or OpenCL kernels?