Jax: Compile error on updating to JAX-0.2.0

Created on 25 Sep 2020 · 7 Comments · Source: google/jax

A piece of a larger algorithm fails to compile with JAX==0.2.0 installed from pip, but works with JAX<=0.1.77.
It fails at one line:

print(depth, type(depth))  # 4 <class 'int'>
num_splittings = 2 ** (depth - 1) - 1
keys = random.split(key, num_splittings)

where depth is a concrete Python int, as shown.
JAX treats num_splittings as a Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=2/1)>, even though depth is a constant and is not an argument to any jitted function.
Any insights would be helpful. What changed in JAX-0.2.0?

Here is the attached traceback: error.log

EDIT:

I note that at one point depth is computed as:

num_clusters = sampler_state.mu.shape[0]  # 2^(depth-1)
depth = jnp.log2(num_clusters) + 1
depth = depth.astype(jnp.int_)

Why isn't depth constant folded?

documentation question

All 7 comments

JAX==0.2.0 :) We're not even at 1.0 yet.

Why isn't depth constant folded?

The purpose of #3370 was to stop doing any constant folding in Python. Staging out more to XLA means we can generate much more memory-efficient code in a lot of circumstances. But since staged-out computations are delayed and not available at Python tracing time, that means there are some programs that could be successfully traced before (due to Python trace-time constant folding) that can't be successfully traced now.
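A minimal sketch of the effect described above, assuming a recent JAX release in which omnistaging is always enabled. The function name `probe` and the `seen` dict are hypothetical, and `jax.core.Tracer` is used only to check whether a value was staged out:

```python
import numpy as np
import jax
import jax.numpy as jnp

seen = {}  # hypothetical holder for what we observe at trace time

@jax.jit
def probe(x):
    n = x.shape[0]  # shapes are static Python ints, even under jit
    # jnp ops are staged out to XLA, so the result is an abstract tracer.
    seen["jnp_is_tracer"] = isinstance(jnp.log2(n) + 1, jax.core.Tracer)
    # Plain NumPy runs immediately in Python and yields a concrete value.
    seen["np_is_tracer"] = isinstance(np.log2(n) + 1, jax.core.Tracer)
    return x

probe(jnp.zeros(8))
print(seen)
```

The `jnp` result has no concrete value while tracing, which is why it cannot be used as the `num` argument of `random.split`.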

As a result, anything you want evaluated at trace time should use the original numpy (as in import numpy as np and depth = np.log2(num_clusters) + 1).

Can you try making that change in your code? (There might be other spots to use np instead of jnp too; the error message suggests several but I'm not sure if it's being overzealous.)
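The suggested change might look like the sketch below. This is an illustration under assumptions, not the reporter's actual code: the function name `split_keys`, the `(8, 2)` array standing in for `sampler_state.mu`, and the use of `jax.jit` around the whole computation are all hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import random

@jax.jit
def split_keys(key, mu):
    num_clusters = mu.shape[0]  # 2^(depth-1); a static Python int
    # np.log2 is evaluated in Python at trace time, so depth stays concrete.
    depth = int(np.log2(num_clusters)) + 1
    num_splittings = 2 ** (depth - 1) - 1
    # num_splittings is a plain int, as random.split requires.
    return random.split(key, num_splittings)

# Hypothetical stand-in for sampler_state.mu with 8 clusters.
keys = split_keys(random.PRNGKey(0), jnp.zeros((8, 2)))
print(keys.shape[0])
```

With 8 clusters, depth is 4 and num_splittings is 7, so seven keys come back.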

By the way, we bumped the version to 0.2.0 exactly because this is a big change.

I have an explainer doc that I need to make public...

Nice, I will have to update to the omnistaging workflow, but I see the value. It will take some care, because some things must stay concrete while others need not. I'll post back once I've updated and the error goes away.

If you need to get unstuck ASAP, you can do

from jax.config import config
config.disable_omnistaging()

(Or set the JAX_OMNISTAGING=0 env var, or use the --jax_omnistaging=0 flag if you parse flags.)

But unless you're under time pressure I suggest trying to make your code work with omnistaging.

We're eager to help if you run into any issues, so don't be shy about them!

See #4410 for a doc explaining common issues. I think it might have helped here, but if not let's improve it!

Now that #4410 is merged, I'm going to close this issue. But please open new ones for anything you need help with!

Btw, this was solved by updating to omnistaging compliance.
