I'm trying to extract the Jacobian of a function f that consists of numerous numpy operations (such as dot, matmul, einsum, etc.) via jacfwd. I've also compiled the jacfwd of my function with J = jit(jacfwd(f)).
Since f is a relatively small function, the call J = jit(jacfwd(f)) itself is lightning fast, but to my surprise the first call to J took 10 s, while subsequent calls ran on the order of 1 ms or less.
Is this normal behavior? I'm really hoping to lower the first-run time significantly, as this is meant for speed-critical applications. If that's unavoidable, is there any way I can save this "warmed-up" function somewhere so that it incurs the compilation cost only once?
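A minimal sketch of the setup described above (f here is a hypothetical stand-in for the real function, just a small mix of numpy-style ops):

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    # stand-in small function built from dot/einsum, like the one described
    return jnp.einsum("i,ij->j", x, jnp.eye(x.shape[0])) + jnp.dot(x, x)

J = jax.jit(jax.jacfwd(f))      # returns almost instantly: nothing is compiled yet

x = jnp.ones(4)

t0 = time.perf_counter()
J(x).block_until_ready()        # first call: trace + XLA compile (slow)
first = time.perf_counter() - t0

t0 = time.perf_counter()
J(x).block_until_ready()        # subsequent calls reuse the compiled executable
later = time.perf_counter() - t0
```

The `block_until_ready()` calls matter when timing: JAX dispatches asynchronously, so without them you'd be measuring dispatch, not execution.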
Disclaimer: I'm still a bit new to this and also not a contributor so someone might come and correct what I'm saying here 馃檪
What I think happens is that calling jax.jit on a function isn't what actually does the compilation; it just prepares the function for JIT compilation. See the documentation for jit here: https://jax.readthedocs.io/en/latest/jax.html#jax.jit
For the actual compilation to happen, the function has to be traced -- this happens on your first call to J. That's why you're seeing the first call take 10 seconds. Unless startup time itself really matters, you could run a first iteration with fake data to trigger compilation, so that all of your real iterations run fast.
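A sketch of that warm-up pattern (f is a hypothetical placeholder for the real function):

```python
import jax
import jax.numpy as jnp

def f(x):                      # stand-in for the small function in question
    return jnp.dot(x, x) * x

J = jax.jit(jax.jacfwd(f))

# Warm up at startup with fake data of the same shape and dtype as the real
# inputs; the compiled executable is cached per (shape, dtype) signature,
# so the real calls below reuse it without recompiling.
dummy = jnp.zeros(4, dtype=jnp.float32)
J(dummy).block_until_ready()

real = jnp.arange(4, dtype=jnp.float32)
result = J(real)               # fast: cache hit, no recompilation
```

Note the fake data must match the real inputs' shapes and dtypes; a call with a different signature triggers a fresh compilation.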
As for saving the compiled function, I don't think that's supported yet. But you might be interested in the discussions in #476, #679, #1566, and #4300.
To tackle the slow first run (i.e. the compilation step), I'm considering running the jitted function's first call while some of my processor cores are idling. I tried to achieve this via Python's multiprocessing with the "spawn" start method, but it failed with "Can't pickle local object". Is there any recommended way to go about this?
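The "Can't pickle local object" error comes from the "spawn" start method itself, which pickles everything sent to the worker. A minimal stdlib-only repro (no JAX needed; `make_jitted` is a hypothetical factory standing in for code that builds jit(jacfwd(f))):

```python
import pickle

def make_jitted():
    def f(x):          # a local (nested) function, like a closure wrapping jit
        return x + 1
    return f

fn = make_jitted()

err = ""
try:
    # multiprocessing's "spawn" method pickles the target and its arguments
    # the same way pickle.dumps does, hence the identical error message
    pickle.dumps(fn)
except Exception as e:
    err = str(e)
```

Anything defined inside another function (or returned by a transformation like jit) is not importable by name from the module, which is what the stdlib pickler requires.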
We don't support pickling of JAX objects.
Can you say a bit more about why simply warming up the cache as suggested above wouldn't work?
@hawkinsp I have a customer-facing application that's run-time sensitive; currently the warm-up takes about 15% of the overall run time on an AWS instance. Unfortunately, a new Python interpreter is started for every case the customer sends in, so the cache only helps if the computation occurs within the same case that did the warm-up. One way around this would be to run the Python interpreter in serving mode, warm it up once, and that would solve the problem.
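A sketch of that warm-once, serve-many pattern in a single long-lived process (all names here are hypothetical; a plain list stands in for the incoming cases, which a real deployment would read from a queue or socket):

```python
import jax
import jax.numpy as jnp

def f(x):                              # stand-in for the real model function
    return jnp.dot(x, x) * x

J = jax.jit(jax.jacfwd(f))

# Pay the compilation cost once at process startup...
J(jnp.zeros(4)).block_until_ready()

# ...then handle many cases in the same process. As long as each case has the
# same input shapes/dtypes, every call is a cache hit.
cases = [jnp.full(4, float(i)) for i in range(3)]
results = [J(c) for c in cases]
```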
I have also tried cloudpickle (as suggested elsewhere), but as you mentioned it didn't really help: it did allow me to pickle a jitted (and warmed-up) function, but upon reloading and running it, it recompiled again.
I have also tried to warm up the function on an idle core well before it is needed, as a preparation step, using multiprocessing, but I ran into problems passing back a DeviceArray that was declared as a static argument (I understand now that this is asking for trouble).
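For context on why a DeviceArray as a static argument is asking for trouble: static arguments are treated as compile-time constants and become part of the compilation-cache key, so they must be hashable, which arrays are not. A small sketch (assuming JAX is installed):

```python
import jax
import jax.numpy as jnp

# static_argnums=1 tells jit to treat the second argument as a compile-time
# constant; it must be hashable so it can key the compilation cache.
g = jax.jit(lambda x, w: x * w, static_argnums=1)

try:
    g(jnp.ones(3), jnp.ones(3))   # an array passed as a static argument
    failed = False
except (ValueError, TypeError):
    failed = True                 # JAX rejects non-hashable static arguments
```

Array-valued inputs should be left as ordinary (traced) arguments; static arguments are meant for things like Python ints, bools, or shapes that change the generated program.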