I found an issue with `jit` compilation when the result of `np.arange` is indexed dynamically. In particular, the code below produces an error:

```python
import jax.numpy as np
from jax import jit

def fun(x):
    r = np.arange(x.shape[0])[x]
    return r

jit(fun)(np.array([0, 1, 2], dtype=np.int32))
```
Thanks for raising this, and for the simplified repro. This is currently expected behavior, but it's pretty lame, and maybe we can improve it.

The issue is that `np.arange(x.shape[0])` creates a plain-old NumPy `ndarray`, and indexing an `ndarray` with one of our tracers (as `x` will be here) means that the `ndarray` is asked what to do with a type it has never seen before.
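To make that failure mode concrete, here is a minimal sketch written against a recent JAX release (an assumption: on the version discussed here, `jax.numpy.arange` itself fell back to NumPy, so the `jnp` variant below would not have helped then). One function builds the index table with raw NumPy (`onp`), the other with `jnp.arange`; only the raw `ndarray` version fails under `jit`. The function names are illustrative.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def index_numpy(x):
    # onp.arange returns a raw numpy.ndarray; indexing it with the tracer `x`
    # makes NumPy try to convert the tracer, which JAX refuses under jit.
    return onp.arange(x.shape[0])[x]


def index_jax(x):
    # x.shape[0] is a static Python int, so jnp.arange builds a JAX array;
    # indexing that array with a tracer is handled by JAX itself.
    return jnp.arange(x.shape[0])[x]


x = jnp.array([2, 0, 1], dtype=jnp.int32)

try:
    jax.jit(index_numpy)(x)
    numpy_failed = False
except Exception:  # the exact exception class depends on the JAX version
    numpy_failed = True

print(numpy_failed)                    # True
print(jax.jit(index_jax)(x).tolist())  # [2, 0, 1]
```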
A workaround for this case, which doesn't have all the functionality of `np.arange`, is to use `lax.iota`:

```python
import jax.numpy as np
from jax import jit, lax

def fun(x):
    # lax.iota(dtype, size) builds [0, 1, ..., size - 1] as a JAX array
    r = lax.iota(np.int32, x.shape[0])[x]
    return r

jit(fun)(np.array([0, 1, 2], dtype=np.int32))
```
I think we can cover this use case better if we finish off the commented-out `np.arange` implementation in terms of `lax.iota`, which won't be able to handle all cases but could handle ones like these.
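A minimal sketch of that idea, assuming static Python-int arguments and an integer dtype. `arange_via_iota` is a hypothetical helper, not JAX's actual `jnp.arange`: it computes the element count the way NumPy does, then shifts and scales `lax.iota`, so the result is always a JAX array that a tracer can index.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def arange_via_iota(start, stop=None, step=1, dtype=jnp.int32):
    """Hypothetical helper: np.arange for static Python ints, built on lax.iota."""
    if stop is None:  # arange(n) means arange(0, n)
        start, stop = 0, start
    if step == 0:
        raise ValueError("step must be nonzero")
    # Same element count as NumPy: ceil((stop - start) / step), never negative.
    n = max(0, math.ceil((stop - start) / step))
    # lax.iota yields the JAX array [0, 1, ..., n - 1]; shift and scale it.
    return start + step * lax.iota(dtype, n)


def fun(x):
    # The length comes from a static shape, so this traces fine under jit.
    return arange_via_iota(x.shape[0])[x]


print(arange_via_iota(2, 11, 3).tolist())  # [2, 5, 8]
print(jax.jit(fun)(jnp.array([2, 0, 1], dtype=jnp.int32)).tolist())  # [2, 0, 1]
```

As the comment above says, this only covers the easy cases: float arguments and traced bounds are out of scope for the sketch.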
Ah, thanks Matt, makes sense! It turns out one can also do:

```python
import jax.numpy as np
from jax import device_put, jit

def fun(x):
    # device_put turns the raw ndarray into a JAX array before indexing
    r = device_put(np.arange(x.shape[0]))[x]
    return r

jit(fun)(np.array([0, 1, 2], dtype=np.int32))
```

which might be useful, if less efficient, in cases where the more advanced functionality of `np.arange` is required.
You can probably also call `np.take` here. So long as we're not dispatching on a raw `ndarray`, things are okay.
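A minimal sketch of the `np.take` route, assuming a recent JAX and using `jnp` for `jax.numpy` and `onp` for plain NumPy. The table is still built as a raw `ndarray`, but the tracer only ever reaches a `jax.numpy` function, which converts the table itself.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def fun(x):
    table = onp.arange(x.shape[0])  # still a raw numpy.ndarray
    # jnp.take converts `table` to a JAX array itself, so the tracer `x`
    # is never passed to a NumPy method.
    return jnp.take(table, x)


out = jax.jit(fun)(jnp.array([2, 0, 1], dtype=jnp.int32))
print(out.tolist())  # [2, 0, 1]
```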
Relatedly, by falling back to `onp` we generate weird error messages for things like `jit(np.arange)(3)` (which is an error, but the message should be clear).
It seems this was fixed long ago!
I'm getting a similar error when doing the following:

```python
import jax.numpy as jnp
from jax import jit

jit(jnp.arange)(1, 2)
```

Doing the same thing without `jit` is fine (this is on '0.1.69').
The error message is more verbose, but I think it has the same gist?
```
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
```
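For the `jit(jnp.arange)(1, 2)` case, a sketch of one common way around it (not something proposed in this thread, and checked only against the assumption of a recent JAX release). The output shape of `arange` depends on its arguments, so they have to be static rather than traced. `make_range` is an illustrative name; `static_argnums` is the standard `jax.jit` parameter for this.

```python
import jax
import jax.numpy as jnp


def make_range(start, stop):
    # The output shape (stop - start,) depends on these values, so they
    # must be concrete Python values at trace time.
    return jnp.arange(start, stop)


# Marking both arguments static makes jit retrace per distinct (start, stop)
# instead of turning them into tracers.
make_range_jit = jax.jit(make_range, static_argnums=(0, 1))
print(make_range_jit(1, 2).tolist())  # [1]

try:
    jax.jit(make_range)(1, 2)  # start/stop become tracers here
    traced_failed = False
except Exception:  # a concretization error in recent JAX versions
    traced_failed = True
print(traced_failed)  # True
```

Each distinct `(start, stop)` pair triggers a fresh compilation, which is the usual trade-off of static arguments.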