A piece of a larger algorithm fails to compile with JAX==0.2.0 (installed from pip), but works with JAX<=0.1.77.
It fails at this point:
print(depth, type(depth))  # 4 <class 'int'>
num_splittings = 2 ** (depth - 1) - 1
keys = random.split(key, num_splittings)
where depth is a concrete Python int, as shown.
JAX treats num_splittings as a Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=2/1)>, even though depth is a constant and is not an argument to any jitted function.
Any insights would be helpful. What changed in JAX 0.2.0?
Here is the attached traceback: error.log
EDIT:
I notice that, at one point, depth is computed as:
num_clusters = sampler_state.mu.shape[0] # 2^(depth-1)
depth = jnp.log2(num_clusters) + 1
depth = depth.astype(jnp.int_)
Why isn't depth constant-folded?
JAX==0.2.0 :) We're not even at 1.0 yet.
Why isn't depth constant-folded?
The purpose of #3370 was to stop doing any constant folding in Python. Staging more computation out to XLA means we can generate much more memory-efficient code in many circumstances. But staged-out computations are delayed, so their values are not available at Python tracing time. That means some programs that could be traced successfully before (thanks to Python trace-time constant folding) can't be traced successfully now.
As a result, anything you want evaluated at trace time should use the original NumPy (as in import numpy as np and depth = np.log2(num_clusters) + 1).
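One way to see this staging is to compare jaxprs with jax.make_jaxpr. In the minimal sketch below, written against current JAX where omnistaging is always on, the function names and the (8, 2) input are hypothetical stand-ins for the sampler code. In both functions mu.shape[0] is a static Python int, but jnp.log2 applied to it is recorded as log equations in the jaxpr, whereas np.log2 runs immediately in Python and only its numeric result reaches JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np


def depth_with_jnp(mu):
    num_clusters = mu.shape[0]          # static Python int taken from the shape
    depth = jnp.log2(num_clusters) + 1  # staged out: recorded in the jaxpr
    return mu * depth


def depth_with_np(mu):
    num_clusters = mu.shape[0]          # static Python int taken from the shape
    depth = np.log2(num_clusters) + 1   # evaluated eagerly by NumPy at trace time
    return mu * depth


mu = jnp.zeros((8, 2))
jaxpr_jnp = jax.make_jaxpr(depth_with_jnp)(mu)
jaxpr_np = jax.make_jaxpr(depth_with_np)(mu)
print(jaxpr_jnp)  # contains log equations computed from the constant 8
print(jaxpr_np)   # no log equation: depth is already a constant
```

Anything that shows up as an equation in the jaxpr only gets a value when XLA runs it, which is why Python code that needs the value (such as a shape argument) cannot see it during tracing.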
Can you try making that change in your code? (There might be other spots to use np instead of jnp too; the error message suggests several but I'm not sure if it's being overzealous.)
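A minimal sketch of that change, assuming the count comes from the leading dimension of an array like sampler_state.mu (the function name split_for_tree and the (8, 3) input are hypothetical). Because depth is computed with NumPy and cast to a Python int, num_splittings is a concrete int at trace time and random.split accepts it even under jax.jit; doing the same arithmetic with jnp would produce a traced value and the error from the original report.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import random


@jax.jit
def split_for_tree(key, mu):
    # Shapes are static under jit, so num_clusters is a plain Python int.
    num_clusters = mu.shape[0]  # assumed to equal 2 ** (depth - 1)
    # np.log2 runs in Python at trace time; int() keeps depth a concrete int.
    depth = int(np.log2(num_clusters)) + 1
    # Pure Python arithmetic, so this is a concrete int rather than a tracer.
    num_splittings = 2 ** (depth - 1) - 1
    return random.split(key, num_splittings)


keys = split_for_tree(random.PRNGKey(0), jnp.zeros((8, 3)))
print(keys.shape[0])  # 7 keys, since depth = 4
```

The int() cast assumes num_clusters is an exact power of two, as the comment in the EDIT implies; for other sizes the truncation would silently round depth down.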
By the way, we bumped the version to 0.2.0 exactly because this is a big change.
I have an explainer doc that I need to make public...
Nice. I will have to update to the omnistaging workflow, but I see the value. It will take care, because some things must stay concrete while others need not. I'll post back once I've updated and the error goes away.
If you need to get unstuck ASAP, you can do
from jax.config import config
config.disable_omnistaging()
(Or set the JAX_OMNISTAGING=0 env var, or use the --jax_omnistaging=0 flag if you parse flags.)
But unless you're under time pressure I suggest trying to make your code work with omnistaging.
We're eager to help if you run into any issues, so don't be shy about them!
See #4410 for a doc explaining common issues. I think it might have helped here, but if not let's improve it!
Since merging #4410, I'm going to close this issue. But please open new ones with issues that you need help with!
By the way, this was solved by making the code omnistaging-compliant.