As a disclaimer, I'm not sure whether this is a bug that should be fixed in JAX or something the JAX "Sharp Bits" documentation should discuss.
I originally ran into an issue where my jitted function behaved strangely, and differently from the non-jitted version. After going over the Sharp Bits, I think this relates to the "functions with argument-value dependent shapes" section. At first I thought it was simply my mistake, but then I realized that the function does not in fact have value-dependent shapes.
Here's a reproducible example:
```python
import jax
import jax.numpy as jnp

def fn(a, qs, idx):
    return a[idx, :, qs].mean(axis=-1).squeeze()

jit_fn = jax.jit(fn)
a = jnp.ones((24, 8, 11))
qs = jnp.arange(11)
fn(a, qs, 4).shape      # outputs (8,) as expected
jit_fn(a, qs, 4).shape  # outputs (11,)
```
Note that the shape of fn's output does not actually depend on the values of qs. In fact, none of the intermediate steps of the function produce an array whose shape depends on the values of qs. I fixed the problem by rewriting the function as follows:
```python
def better_fn(a, idx):
    return a[idx, :].mean(axis=-1).squeeze()

jit_better_fn = jax.jit(better_fn)
b = a[:, :, qs]
idx = 4
better_fn(b, idx).shape      # outputs (8,)
jit_better_fn(b, idx).shape  # outputs (8,)
```
To me, this is further evidence that the Sharp Bits section applies to this case; however, I think the documentation's statement that "specializing on argument shapes is ok" is ambiguous.
Thanks, this is definitely a bug, and I think it's unrelated to that Sharp Bits section. It's something about the XLA translation rule for NumPy's fancy indexing.
By the way, any time you do have value-dependent shapes (not the case here), JAX will raise a clear error message rather than silently doing the wrong thing.
This just looks like a really surprising indexing bug...
It turns out the mean and squeeze are superfluous; the issue is a transposed result in jitted code that combines scalar indexing, slicing, and fancy indexing:
```python
import jax
import jax.numpy as jnp

def fn(a, i, q):
    return a[i, :, q]

jit_fn = jax.jit(fn)
a = jnp.arange(6).reshape(1, 2, 3)
q = jnp.arange(3)
print(fn(a, 0, q))
# [[0 1 2]
#  [3 4 5]]
print(jit_fn(a, 0, q))
# [[0 3]
#  [1 4]
#  [2 5]]
```
Thanks for the quick responses! It was a tricky one to debug; I was sure I was doing something wrong.
Via offline chat with @mattjj: it looks like the abstract eval is returning the wrong shape:
```python
jax.eval_shape(fn, a, 0, q)
# ShapeDtypeStruct(shape=(3, 2), dtype=int32)
```
(using my simplified repro).
Indeed, that's not causative, but it shows that XLA isn't to blame.
[EDIT: oops, Jake already observed this, my comment is redundant] I think this is a plain-old bug in our _rewriting_take which translates NumPy indexing to XLA operations. If we remove the mean(axis=-1).squeeze() part, we see that the result of the indexing expression should be of shape (8, 11), but we're producing something of shape (11, 8).
Update: the issue actually appears to be on the non-jit code path:
```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.arange(6).reshape(1, 2, 3)

def fn(a, i):
    return a[i, :, np.arange(3)]

print(fn(a, 0))
# [[0 1 2]
#  [3 4 5]]
print(fn(a, jnp.array(0)))
# [[0 3]
#  [1 4]
#  [2 5]]
```
Note that the latter actually matches numpy's behavior:
```python
import numpy as np
a = np.arange(6).reshape(1, 2, 3)
a[0, :, np.arange(3)]
# array([[0, 3],
#        [1, 4],
#        [2, 5]])
```
To be honest, I don't entirely understand why numpy gives this result. I would have expected it to be the same as this, but it is not:
```python
a[0, :, :]
# array([[0, 1, 2],
#        [3, 4, 5]])
```
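A minimal NumPy sketch of the advanced-indexing rule that appears to explain this result: in `a[0, :, idx]` the integer `0` is combined with the array index `idx`, so both act as advanced indices. Because they are separated by a slice, their broadcast shape `(3,)` moves to the front, followed by the sliced axis of length 2. Indexing in two steps keeps the sliced axis first. The variable names below are illustrative only.

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 3)
idx = np.arange(3)

# 0 and idx are both advanced indices; they are separated by a slice,
# so their broadcast shape (3,) comes first, then the sliced axis (2,).
mixed = a[0, :, idx]
print(mixed.shape)  # (3, 2)

# Two steps: a[0] is basic indexing (shape (2, 3)); [:, idx] then has a
# single advanced index, which stays in place.
chained = a[0][:, idx]
print(chained.shape)  # (2, 3)

# The two results differ only by a transpose.
print(np.array_equal(mixed, chained.T))  # True

# Contrast: separated advanced indices move to the front, while adjacent
# advanced indices replace the indexed axes in place.
b = np.zeros((2, 5, 4))
print(b[0, :, idx].shape)  # (3, 5)
print(b[:, 0, idx].shape)  # (2, 3)
```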
Even further simplified:
```python
import jax
import jax.numpy as jnp

a = jnp.arange(6).reshape(1, 2, 3)
print(a[0, :, jnp.arange(3)])
# [[0 1 2]
#  [3 4 5]]
print(a[jnp.array(0), :, jnp.arange(3)])
# [[0 3]
#  [1 4]
#  [2 5]]
```
@mattjj points out that the issue is related to this comment: https://github.com/google/jax/blob/0660939ab016fe93aa0a897ba2f05343c6f3f380/jax/numpy/lax_numpy.py#L3979-L3981
We are treating DeviceArray scalars as advanced indices, unlike Python integer scalars; the fix is to treat all scalars the same way.
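For reference, a quick NumPy check (a sketch, not part of the fix itself) showing that NumPy gives the same result whether the scalar index is a Python `int`, a NumPy scalar, or a 0-d array. This consistency is what treating all scalars the same way should reproduce in `jax.numpy`.

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 3)
idx = np.arange(3)

# Three spellings of the scalar index 0.
scalars = [0, np.int64(0), np.array(0)]

shapes = [a[s, :, idx].shape for s in scalars]
print(shapes)  # [(3, 2), (3, 2), (3, 2)]

# The values agree as well, not just the shapes.
reference = a[0, :, idx]
print(all(np.array_equal(a[s, :, idx], reference) for s in scalars))  # True
```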
The fix is in #4556. Note, however, that with the fix the return shape is (11,) both with and without jit, which agrees with NumPy. On the branch for that PR, this is what I get:
```python
import jax
import jax.numpy as jnp
import numpy as np

def fn(a, qs, idx):
    return a[idx, :, qs].mean(axis=-1).squeeze()

jit_fn = jax.jit(fn)

a = np.ones((24, 8, 11))
qs = np.arange(11)
print(fn(a, qs, 4).shape)
# (11,)

a = jnp.ones((24, 8, 11))
qs = jnp.arange(11)
print(fn(a, qs, 4).shape)
# (11,)
print(jit_fn(a, qs, 4).shape)
# (11,)
```
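If the goal is the (8,) shape the original fn was expected to return, one option is to apply the scalar index first so that qs becomes the only advanced index. This is a sketch that assumes the intent is "take row idx, then pick columns qs from the remaining 2-D array"; `chained_fn` is a hypothetical name. It uses NumPy, and `jax.numpy` should follow the same semantics here once the fix above is in.

```python
import numpy as np

def chained_fn(a, qs, idx):
    # a[idx] is basic indexing: shape (8, 11) for a of shape (24, 8, 11).
    # [:, qs] then has a single advanced index, so the sliced axis stays first.
    return a[idx][:, qs].mean(axis=-1).squeeze()

a = np.ones((24, 8, 11))
qs = np.arange(11)
print(chained_fn(a, qs, 4).shape)  # (8,)

# Equivalent alternative: np.take along the last axis after the scalar index.
print(np.take(a[4], qs, axis=-1).mean(axis=-1).shape)  # (8,)
```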