The following code fails on the last line:
import jax
import jax.numpy as jnp
import numpy as np
f = lambda i: jnp.zeros((3, 3))[i, :]
g = lambda i: np.zeros((3, 3))[i, :]
a = np.array([1, 2])
f(a) # Okay
jax.jit(f)(a) # Okay
g(a) # Okay
jax.jit(g)(a) # Fail
with the standard error message:
Tracer can't be used with raw numpy functions. You might have
import numpy as np
instead of
import jax.numpy as np
The cause of the error is attempting to trace the __getitem__ method of a raw numpy array with a tracer as the index. Normally, "Tracer can't be used ..." errors are easy to spot because the offending call starts with np., but this error is more subtle and takes longer to track down. Also, binary operations that mix numpy and JAX arrays work fine, so this is an exceptional case.
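A minimal sketch of the asymmetry described above, plus the usual workaround. The constant `c` and the function names are hypothetical. My understanding (not confirmed in this thread) is that the binary operation works because NumPy defers `ndarray + tracer` to the tracer's reflected `__radd__`, while `ndarray.__getitem__` has no such deferral and instead calls `__array__` on the tracer.

```python
import jax
import jax.numpy as jnp
import numpy as np

c = np.arange(3.0)  # raw NumPy constant captured by the jitted functions


@jax.jit
def mixed_binop(x):
    # ndarray.__add__ defers to the tracer's reflected __radd__, so JAX
    # performs the addition and tracing succeeds.
    return c + x


@jax.jit
def mixed_index(i):
    # ndarray.__getitem__ does not defer: NumPy tries to turn the traced
    # index into a concrete array via __array__, which JAX forbids.
    return c[i]


@jax.jit
def fixed_index(i):
    # Converting the constant first means JAX's __getitem__ handles the
    # traced index. jnp.take(c, i) would also work here.
    return jnp.asarray(c)[i]


idx = jnp.array([0, 2])
print(mixed_binop(jnp.ones(3)))
print(fixed_index(idx))
try:
    mixed_index(idx)
except Exception as err:  # the exact exception type depends on the JAX version
    print(type(err).__name__)
```

The same fix applies to g in the original report: replacing np.zeros((3, 3))[i, :] with jnp.asarray(np.zeros((3, 3)))[i, :] keeps the indexing inside JAX.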
Is there any way to improve this error message or detect this case? At the extreme end, could JAX do without implementing the __array__ method for implicit conversions (replacing it with an explicit conversion method), to reduce the mental overhead associated with these conversions?
I think the only way to specialize this error would be to do some sort of call stack tracing in __array__. The most useful thing may be to print the context of the call as part of the error; for example, something like this:
import inspect
import numpy as np
class MyClass:
    def __array__(self):
        frame = inspect.currentframe()
        call_frame = inspect.getouterframes(frame, 3)
        prefix = lambda i: '--> ' if i == call_frame[1].index else ' '
        lines = [prefix(i) + line for i, line in enumerate(call_frame[1].code_context)]
        raise ValueError('Tracer error:\n\n' + ''.join(lines))

x = np.zeros(10)
m = MyClass()
x[m]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-70e4e8c61a15> in <module>()
11 x = np.zeros(10)
12 m = MyClass()
---> 13 x[m]
<ipython-input-36-70e4e8c61a15> in __array__(self)
7 prefix = lambda i: '--> ' if i == call_frame[1].index else ' '
8 lines = [prefix(i) + line for i, line in enumerate(call_frame[1].code_context)]
----> 9 raise ValueError('Tracer error:\n\n' + ''.join(lines))
10
11 x = np.zeros(10)
ValueError: Tracer error:
x = np.zeros(10)
m = MyClass()
--> x[m]
I don't think the array indexing itself leaves a frame in the call stack, unfortunately. A more typical example would be something like jnp.exp(1 + 0.5 * x[m]), or something longer. In that case, Jake's class will make it look as though it is flagging jnp.exp, when really it is flagging x[m].
The message would be something like --> jnp.exp(1 + 0.5 * x[m]), which, together with the error message's suggestion that you have the wrong import, will point pretty strongly at the exp call.
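A possible refinement, assuming Python 3.11 or later: there, `inspect.getframeinfo` exposes a `positions` attribute with the column span of the instruction the caller's frame is executing. Since NumPy's indexing runs in C and adds no Python frame, that instruction should be the subscript itself, so the span can single out x[m] instead of the whole jnp.exp(...) line. `FakeTracer` is a hypothetical stand-in with the same `__array__` hook as Jake's `MyClass`, not JAX's actual tracer, and `np.exp` replaces `jnp.exp` so the sketch needs only NumPy.

```python
import inspect
import linecache

import numpy as np


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer; only __array__ matters here."""

    def __array__(self, dtype=None, copy=None):
        # NumPy's __getitem__ is C code, so the frame that triggered this call
        # is the user's own frame, paused on the subscript expression.
        caller = inspect.currentframe().f_back
        info = inspect.getframeinfo(caller, context=1)
        msg = (f"Tracer converted to a NumPy array at "
               f"{info.filename}, line {info.lineno}")
        # `positions` only exists on Python 3.11+, and its fields may be None.
        pos = getattr(info, "positions", None)
        if (pos is not None and pos.col_offset is not None
                and pos.end_col_offset is not None
                and pos.lineno == pos.end_lineno):
            msg += f", columns {pos.col_offset}-{pos.end_col_offset}"
            source = linecache.getline(info.filename, pos.lineno).rstrip("\n")
            if source:
                # Column offsets are UTF-8 byte offsets; fine for ASCII source.
                width = pos.end_col_offset - pos.col_offset
                marker = " " * pos.col_offset + "^" * width
                msg += f":\n\n{source}\n{marker}"
        raise ValueError(msg)


x = np.zeros(10)
m = FakeTracer()
try:
    np.exp(1 + 0.5 * x[m])
except ValueError as err:
    print(err)
```

On Python 3.11+ with the source file available, the carets should land under x[m] rather than under the exp call; on older versions the sketch falls back to reporting only the file and line.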
The error message could probably be improved either way, e.g.:
Tracer can't be used with raw numpy functions or methods on numpy arrays
import jax.numpy as np
We could just call out indexing in the error message too (in all cases, not just when we can detect it). I think we've seen this multiple times.