Jax: Tracing a function that indexes into Numpy array gives a poor error message

Created on 18 May 2020 · 4 Comments · Source: google/jax

The following code fails on the last line:

import numpy as np
import jax
import jax.numpy as jnp

f = lambda i: jnp.zeros((3, 3))[i, :]
g = lambda i: np.zeros((3, 3))[i, :]

a = np.array([1, 2])

f(a)           # Okay
jax.jit(f)(a)  # Okay

g(a)           # Okay
jax.jit(g)(a)  # Fail

with the standard error message:

Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as np

The cause of the error is the call to __getitem__ on a raw numpy array: numpy tries to convert the traced index into an array (via the tracer's __array__ method), which a tracer cannot support. Normally "Tracer can't be used ..." errors are easy to spot because the offending call starts with np., but this one is more subtle and takes longer to track down. Also, binary operations that mix numpy and JAX arrays work fine, so this is an exceptional case.
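The mechanism can be reproduced with NumPy alone. The sketch below uses a hypothetical FakeTracer class (not JAX's real Tracer) whose __array__ method refuses conversion, as a tracer does. Indexing a NumPy array with it fails even though no np. function is applied to the index itself.

```python
import numpy as np


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer; not JAX's actual class.

    Like a tracer, it has no concrete value, so it refuses to be
    converted into a NumPy array.
    """

    def __array__(self, dtype=None, copy=None):
        raise TypeError("Tracer can't be used with raw numpy functions.")


def g(i):
    # No np.* function is applied to i, but NumPy's __getitem__ still
    # tries to turn the index i into an array, which calls i.__array__().
    return np.zeros((3, 3))[i, :]


def indexing_raises(index):
    """Return True if g(index) fails while converting the index."""
    try:
        g(index)
    except TypeError:
        return True
    return False


print(indexing_raises(np.array([1, 2])))  # False: a real array is a valid index
print(indexing_raises(FakeTracer()))      # True: __array__ was invoked
```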

Is there any way to improve this error message / detect this case? At the extreme end, could jax do without implementing the __array__ method for implicit conversions (and replace with an explicit conversion method), to reduce the mental overhead associated with these conversions?
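For the original repro, a common workaround (not one proposed in this thread) is to convert the NumPy constant to a JAX array before indexing, so JAX's own __getitem__ handles the traced index. This is a minimal sketch assuming a recent JAX release; table and g_fixed are illustrative names.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A precomputed NumPy constant, e.g. a lookup table built outside JAX.
table = np.arange(9.0).reshape(3, 3)


def g_fixed(i):
    # jnp.asarray turns the NumPy constant into a JAX array, so the
    # indexing below goes through JAX's __getitem__, which accepts tracers.
    return jnp.asarray(table)[i, :]


a = np.array([1, 2])
rows = jax.jit(g_fixed)(a)
print(rows.shape)  # (2, 3)
```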

better_errors

All 4 comments

I think the only way to specialize this error would be to do some sort of call stack tracing in __array__. The most useful thing may be to print the context of the call as part of the error; for example, something like this:

import inspect

import numpy as np

class MyClass:
  def __array__(self):
    frame = inspect.currentframe()
    call_frame = inspect.getouterframes(frame, 3)
    prefix = lambda i: '--> ' if i == call_frame[1].index else '    '
    lines = [prefix(i) + line for i, line in enumerate(call_frame[1].code_context)]
    raise ValueError('Tracer error:\n\n' + ''.join(lines))

x = np.zeros(10)
m = MyClass()
x[m]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-70e4e8c61a15> in <module>()
     11 x = np.zeros(10)
     12 m = MyClass()
---> 13 x[m]

<ipython-input-36-70e4e8c61a15> in __array__(self)
      7     prefix = lambda i: '--> ' if i == call_frame[1].index else '    '
      8     lines = [prefix(i) + line for i, line in enumerate(call_frame[1].code_context)]
----> 9     raise ValueError('Tracer error:\n\n' + ''.join(lines))
     10 
     11 x = np.zeros(10)

ValueError: Tracer error:

    x = np.zeros(10)
    m = MyClass()
--> x[m]

I don't think the array indexing itself leaves a frame in the call stack, unfortunately. A more typical case would be something like jnp.exp(1 + 0.5 * x[m]), or longer. There, Jake's class would appear to flag jnp.exp, when the real culprit is x[m].

The message would show something like --> jnp.exp(1 + 0.5 * x[m]), which, combined with the error message's suggestion that you have the wrong import, would point pretty strongly at exp rather than at the indexing.
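The frame observation can be checked directly. In this sketch (FrameRecorder and model are hypothetical names, and np.exp stands in for jnp.exp), the frame immediately below __array__ belongs to the Python function containing the whole expression, because NumPy's indexing runs in C and adds no Python frame of its own. Any context printed from that frame is therefore the full line, not the x[m] subexpression.

```python
import inspect

import numpy as np


class FrameRecorder:
    """Hypothetical index object that records which frame asked for conversion."""

    def __init__(self):
        self.caller = None

    def __array__(self, dtype=None, copy=None):
        # f_back is the nearest Python frame below this call. NumPy's
        # __getitem__ is implemented in C, so no indexing frame sits in between.
        self.caller = inspect.currentframe().f_back.f_code.co_name
        raise TypeError("cannot convert a tracer-like object to a NumPy array")


def model(x, m):
    # Mirrors the thread's jnp.exp(1 + 0.5 * x[m]) example with plain NumPy.
    return np.exp(1 + 0.5 * x[m])


rec = FrameRecorder()
try:
    model(np.zeros(10), rec)
except TypeError:
    pass
print(rec.caller)  # model
```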

The error message could probably be improved either way:

  1. It should probably mention methods too, e.g. "Tracer can't be used with raw numpy functions or methods on numpy arrays".
  2. We no longer recommend import jax.numpy as np (the convention is now import jax.numpy as jnp).

We could just call out indexing in the error message too (in all cases, not just when we can detect it). I think we've seen this multiple times.
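Combining these wording suggestions with the call-context idea from earlier in the thread, a tracer-style __array__ could raise a message that names methods and indexing explicitly and appends the offending source line when it can be read. This is a hypothetical sketch: GENERIC, tracer_error_message and ConversionSketch are illustrative names, and JAX's actual wording may differ.

```python
import inspect
import linecache

import numpy as np

# Hypothetical wording that follows the suggestions in this thread.
GENERIC = (
    "Tracer can't be used with raw numpy functions or methods on numpy arrays, "
    "including indexing a numpy array with a traced value (np_array[i]). "
    "You might have used numpy where you meant jax.numpy "
    "(import jax.numpy as jnp)."
)


def tracer_error_message(frame):
    """Build the proposed message, plus the caller's source line if readable."""
    line = linecache.getline(frame.f_code.co_filename, frame.f_lineno).strip()
    return f"{GENERIC}\n\n--> {line}" if line else GENERIC


class ConversionSketch:
    """Hypothetical tracer stand-in that raises the proposed message."""

    def __array__(self, dtype=None, copy=None):
        caller = inspect.currentframe().f_back
        raise TypeError(tracer_error_message(caller))


try:
    np.zeros((3, 3))[ConversionSketch(), :]
except TypeError as err:
    message = str(err)
print(message)
```

When the source file is unavailable (for example, code run from a string), the sketch falls back to the generic text alone.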
