Currently, while_loop/fori_loop doesn't seem to cache the body_fn or benefit from an already-jitted body_fn. It would be nice if these loops took advantage of the cached function.
The following repro illustrates the problem:

```python
import time

import jax.numpy as np
from jax import jit, lax

def f(x):
    return np.mean(x - np.ones(10000))

x = np.array(0.)

t = time.time()
f(x)
print("time to evaluate f :", time.time() - t)

jf = jit(f)
jf(x)  # warm up: the first call traces and compiles
t = time.time()
jf(x)
print("time to evaluate jitted f:", time.time() - t)

t = time.time()
lax.while_loop(lambda x: x < 0, jf, np.array(0.))
print("time for while_loop f :", time.time() - t)

t = time.time()
lax.while_loop(lambda x: x < 0, jf, np.array(0.))
print("time for while_loop f :", time.time() - t)
```
which yields

```
time to evaluate f : 0.0552217960357666
time to evaluate jitted f: 0.0003039836883544922
time for while_loop f : 0.026403188705444336
time for while_loop f : 0.020539522171020508
```
cc @mattjj
Summarizing my understanding of the problem at the moment:
When executing lax.fori_loop or lax.while_loop in op-by-op mode, even if the user's body function is itself jitted, the constructed XLA While computation and its function arguments are recompiled on every call. What we want instead is for While to behave like the simpler lax-wrapped HLOs in op-by-op mode and have its compilation cached across calls.
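As an aside (not part of the original repro), one way to sidestep the op-by-op recompilation is to jit the entire loop, so the whole While computation is staged out and compiled once under jit's cache. A minimal sketch:

```python
import jax.numpy as np
from jax import jit, lax

def cond(x):
    return x < 0

def body(x):
    return np.mean(x - np.ones(10000))

# Jitting the outer function stages out the whole While computation,
# which is then compiled once and reused via jit's cache.
@jit
def run_loop(x0):
    return lax.while_loop(cond, body, x0)

run_loop(np.array(0.))  # first call: trace + compile
run_loop(np.array(0.))  # cache hit, no recompilation
```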
Here's what I'm seeing on this script on master:
```
In [1]: run issue587
('time to evaluate f :', 0.031208038330078125)
('time to evaluate jitted f:', 0.00018405914306640625)
('time for while_loop f :', 0.024088144302368164)
('time for while_loop f :', 0.018099069595336914)
```
The reason the second example doesn't see a speedup is that it's getting recompiled each time. As also explained in the disabled testCaching2 test, each time lax.while_loop(lambda x: x < 0, jf, np.array(0.)) is evaluated, lambda x: x < 0 evaluates to a fresh function object, and the only caching mechanism we're using in control flow compares the identity of function objects.
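A minimal sketch of the identity issue, and of the hoisting pattern that avoids it (assuming the same jf as in the repro above):

```python
import jax.numpy as np
from jax import jit, lax

jf = jit(lambda x: np.mean(x - np.ones(10000)))

# Two evaluations of the same lambda expression produce distinct
# function objects, so an identity-based cache can never match them:
cond_a = lambda x: x < 0
cond_b = lambda x: x < 0
print(cond_a is cond_b)  # False

# Hoisting cond out and reusing the same object (along with the same
# jf) on every call gives the identity-based cache a chance to hit:
cond = lambda x: x < 0
lax.while_loop(cond, jf, np.array(0.))  # traces and compiles
lax.while_loop(cond, jf, np.array(0.))  # same objects: cache hit
```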
As @j-towns first observed in #1221, we could get something of a cache hit for the latter case, though it wouldn't be quite as good: we would still have to trace the Python function lambda x: x < 0 to a jaxpr, but once in that form if we implement nontrivial (i.e. not identity-based) equality testing on jaxprs we could notice that we already have a compiled version of the corresponding XLA computation and not redo that work. That is, we'd be re-forming a jaxpr every time, but avoiding the XLA compile time.
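A toy illustration of that idea, using printed-form comparison as a crude stand-in for real structural equality on jaxprs (not how this would actually be implemented):

```python
import jax.numpy as np
from jax import make_jaxpr

# Two fresh-but-equivalent lambdas trace to structurally identical
# jaxprs even though the function objects themselves differ:
jaxpr_a = make_jaxpr(lambda x: x < 0)(np.array(0.))
jaxpr_b = make_jaxpr(lambda x: x < 0)(np.array(0.))

# Comparing printed forms is a crude proxy; a real cache key would need
# proper (alpha-equivalence-aware) equality on jaxprs.
print(str(jaxpr_a) == str(jaxpr_b))  # True: the compiled XLA computation could be reused
```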
I understand the situation now. I think using a lambda isn't necessary here, and the current behaviour is what I expected. Thanks a lot, Matt!
I think that we can close this for now. :)