Jax: Tips for debugging NaNs in gradient?

Created on 4 Mar 2019 · 21 Comments · Source: google/jax

Hi there,

I am running an optimisation using gradients from Jax, and everything goes well for a number of steps until the gradients returned are all nan. I am having a bit of a hard time tracking down where the problem is; the forward calculations all seem to be fine.

Is there some way I can work out which operation is causing the nans from grad? This would be really useful.

Thanks!

bug


All 21 comments

Great question. I don't think we have a solution in place for this right now, but I think we can make one.

There are at least two things to solve here:

  1. set up the equivalent of np.seterr(invalid="raise")
  2. catch nans on the backward pass, and associate them helpfully with user code

Thanks Matt, great that you think it's worthwhile to enhance this!

Although it's cool that you have a clear roadmap, I am actually really blocked by this at the moment and was wondering if there is anything I could do in the meantime? I'd be happy to dig into some of the backend if required. I've already switched to float64, which has helped but not resolved things.

Thanks for letting us know. Any chance you can share a small repro? I just want to make sure we provide the right pointers or tools.

Blocking a user is the worst feeling! That's a magic word to get us to help you out ASAP :)

I added some basic nan debugging machinery in #482. As with other config options there are a few ways to turn it on:

  1. you can set the JAX_DEBUG_NANS environment variable to something truthy,
  2. you can add from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file,
  3. you can add from jax.config import config and config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True.

Switching that option on adds a nan check to every floating-point value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked, and if a nan is present the function is re-run in de-optimized op-by-op mode (effectively removing one level of @jit at a time).

There could be tricky situations that arise, like nans that only occur under a @jit but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute, so we can dig in deeper.

If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you'll be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. If it's not immediately obvious, you can poke around a bit to find the primitive that's producing the nan by doing things in an interactive debugger like p eqn.primitive in that stack frame.

How does that sound? This is a good opportunity to add exactly the tooling we want; JAX is tiny and easy to instrument, so there's no reason not to get this right.
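To make the workflow concrete, here's a minimal sketch (the function is made up for illustration; its nan appears only in the backward pass, which is exactly the situation described above):

import jax
import jax.numpy as jnp

# Option 2 above; jax.config.update flips the same switch as
# `from jax.config import config; config.update(...)`.
jax.config.update("jax_debug_nans", True)

def f(x):
  # The forward value at x = 0.0 is a finite 0.0 ...
  return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)

# ... but the backward pass multiplies a zero cotangent by the infinite
# derivative of sqrt at zero, so with the flag on this raises
# FloatingPointError instead of silently returning nan.
jax.grad(f)(0.0)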

Oh wow, thanks so much! Can't wait to try this out. So sorry, I didn't mean to make you feel bad! It's very likely this is some stupid mistake on my end, but I am very grateful that you added tooling so quickly!

Hah don't worry, I was joking about feeling bad. But it does light a fire under us whenever someone is blocked :)

Check out this comment for a toy example of how to use this. It's going to be a bit hairier to debug nans in the backward pass, but hopefully not too bad.

When you get back to it, let us know how it goes, and any additional issues you run into. This kind of feedback is incredibly helpful, and it's going to pay off a lot in the future when it helps us build an awesome debugging experience.

Hi Matt,

Thanks for this! Indeed, it now crashes rather than returning nan, which is great. From the stack trace below, it looks like a mul operation is what raises the error:

/Users/ingramm/Projects/software/jax/jax/lib/xla_bridge.py:128: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
/Users/ingramm/Projects/software/jax/jax/numpy/linalg.py:51: UserWarning: numpy.linalg support is experimental and may cause silent failures or wrong outputs
  warnings.warn(_EXPERIMENTAL_WARNING)
Log posterior is Traced<ConcreteArray(-7522.515563681398)>with<JVPTrace(level=1/0)>.
Log determinant is Traced<ConcreteArray(247.07115427818476)>with<JVPTrace(level=1/0)>
Traceback (most recent call last):
  File "nan_gradient.py", line 179, in <module>
    data['l'], data['b'], data['n_c']))
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 206, in grad_f
    ans, g = value_and_grad_f(*args, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 243, in value_and_grad_f
    g = vjp_py(onp.ones((), onp.result_type(ans)))
  File "/Users/ingramm/Projects/software/jax/jax/api_util.py", line 56, in apply_jaxtree_fun
    ans = fun(*args)
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 570, in out_vjp_packed
    return out_vjp(cotangent_in)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 81, in vjp_
    _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 139, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 338, in bilinear_transpose
    out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/lax.py", line 242, in mul
    return mul_p.bind(x, y)
  File "/Users/ingramm/Projects/software/jax/jax/core.py", line 75, in bind
    return self.impl(*args, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 54, in apply_primitive
    return compiled_fun(*args)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 85, in execute_compiled_primitive
    return result_handler(compiled.Execute(input_bufs, not core.skip_checks))
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 105, in handle_result
    raise FloatingPointError("invalid value")
FloatingPointError: invalid value

I've put up a little reproducible example here:
error.zip
I hope I included all the relevant files, but please let me know if there are any problems with it!

The error happens when I try to calculate the gradient of the log determinant of the negative Hessian. This seems to run fine in Autograd. I'd be really grateful for any ideas.
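For reference, one standard way to set up this kind of computation (a generic sketch, not the code from the repro; H stands for the negative Hessian, assumed symmetric positive definite):

import jax.numpy as jnp

def logdet_spd(H):
  # For a symmetric positive-definite H with Cholesky factor L (H = L L^T),
  # log|H| = 2 * sum(log(diag(L))); working through the Cholesky factor
  # avoids forming the determinant itself.
  L = jnp.linalg.cholesky(H)
  return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))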

Woohoo, the nans were caught just like we wanted! Now we've got this bug on the run.

Thanks for the repro. My guess is that this is coming from multiplying 0 * inf. Maybe there's a canonical value we'd choose here (like 0 if the covector being pulled back is 0, no matter what the Jacobian is). But that might be missing the larger issue, namely that things might be becoming ill-conditioned.

Is the matrix for which you're computing the log determinant becoming very ill-conditioned, or even indefinite? It could be that JAX's linalg has slightly different numerics than Autograd's, and maybe we could improve the stability of some of our Jacobian calculations.

I relabeled the issue as a bug because now we're trying to figure out why JAX eventually produces NaNs here where Autograd might not. I suspect it's a question of numerical stability.

I can't look at your code right away, but I plan to get to it later.

Thanks Matt! I haven't read as much about numerical linear algebra as I would like to yet, but it looks like the largest eigenvalue of the matrix is about 5.2E5 and the smallest is 1.76E0, which I guess means we have a condition number of about 3E5 (?). Does that sound problematic? There don't seem to be any obvious issues computing the Cholesky etc.; I don't run into errors there.
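A quick check of that arithmetic:

lam_max, lam_min = 5.2e5, 1.76e0
print(lam_max / lam_min)   # about 2.95e5, so "about 3e5" checks out
# (for a symmetric positive-definite matrix this eigenvalue ratio is the
#  2-norm condition number; jax.numpy.linalg.cond computes it directly)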

That doesn't sound bad at all, no. Hmm...

(By the way, totally coincidentally I'm flying to Melbourne a week from today.)

Oh awesome, we should meet up if you have any time to spare!

Hi there, just wondering if there might be an update on this? No rush at all. I've tried some other things like changing the determinant by rearranging the equations, as well as using different Cholesky decompositions to calculate the determinant, but have not had any luck so far. No problem though if there are more pressing things / there's no time to look at this right now!

TensorFlow recently made a couple of changes so that, in ops where the Jacobian-transpose can be infinite, the gradient multiplications (products of each Jacobian-transpose and seed) use a special multiplication op in which 0 * inf is 0. I wonder if that might be the way to go here.

@jekbradbury That's what we do with lax._safe_mul, as in 58749c0. It could be that we need to use it in some more JVPs.
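For readers following the thread, the semantics being discussed look roughly like the sketch below. This is an illustration only, not the actual lax._safe_mul implementation; the double where keeps inf * 0 out of both the forward value and the backward pass.

import jax.numpy as jnp

def safe_mul_sketch(x, y):
  # Wherever either operand is exactly zero, define the product to be zero,
  # and substitute 1.0 for the operands so that neither the forward value
  # nor the gradient ever evaluates inf * 0.
  force_zero = (x == 0) | (y == 0)
  safe_x = jnp.where(force_zero, 1.0, x)
  safe_y = jnp.where(force_zero, 1.0, y)
  return jnp.where(force_zero, 0.0, safe_x * safe_y)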

Hi @mattjj, I'm having the same issue, but with a bit of a different use case. In my case, I think the NaNs are happening because there is an inf * 0 somewhere. I'd like to define that to be zero.

The context is that I am doing learning in an HMM. I wrote the forward pass to compute a log-normalizer, and I'm using grad to compute expectations. It's actually quite similar to this gist, but I have Gaussian observations. At certain points during EM, the value of the log-normalizer will compute just fine, but the gradients will have NaNs in them.

Here's my log-normalizer function:

import jax.numpy as np
from jax.experimental import loops  # provides the stateful Scope/range helpers used below

def _log_normalizer(log_A, log_likelihoods):
  # Move the transition matrix and likelihoods from log space to probability space.
  A, likes = map(np.exp, (log_A, log_likelihoods))
  N, K = likes.shape

  with loops.Scope() as s:
    s.alpha_p = np.ones(K)  # previous (unnormalized) forward message
    s.log_prob = 0.0        # running log-normalizer
    for t in s.range(N):
      # Incorporate the observation likelihoods at time t.
      alpha_c = s.alpha_p * likes[t]

      # Normalize, accumulate the log-normalizer, and push the message
      # through the transition matrix.
      Zt = np.sum(alpha_c)
      s.log_prob += np.log(Zt)
      alpha_c /= Zt
      s.alpha_p = alpha_c @ A

    return s.log_prob

Do you have a recommendation here? Maybe if I masked out the elements of log_likelihoods that are -inf, JAX would know to ignore them in the gradient?
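One way to implement that masking idea is sketched below (hypothetical, using the same double-where trick as the safe multiply above so the -inf entries never reach exp or the backward pass; whether it actually removes the NaNs depends on where the inf * 0 is happening):

import jax.numpy as np  # matching the alias used in the function above

def masked_likes(log_likelihoods):
  # The inner where keeps -inf out of exp entirely; the outer where pins the
  # masked entries, and their gradient contributions, to exactly zero.
  finite = log_likelihoods > -np.inf
  return np.where(finite, np.exp(np.where(finite, log_likelihoods, 0.0)), 0.0)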

FWIW, I am also experiencing NaNs in a similar program that uses loops.Scope() which happens during the backward pass on the while primitive. I have been trying to debug the state, but it seems a bit hard to understand what the buffers do. Any tips on how to handle this?

Thanks!

@mattjj Is there a way to debug NaNs in complex primitives like while-loops? As far as I understand, the NaN checks happen outside the loop, so the body is opaque. Would it make sense to allow compiling the while-loop as a series of xla_calls to the body function (with checks) for debugging purposes in the JAX XLA interpreter? E.g., replacing

while_loop(cond_fn, body_fn, coll)
with
xla_call(body_fn, coll[0]) ... xla_call(body_fn, coll[n - 1]) etc.?

I am willing to help look into the implementation myself, but I need a bit of guidance on what can be done.

Thanks in advance!
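One rough way to approximate that idea at debug time, assuming the loop is bounded and runs on concrete values, is sketched below (not an existing JAX feature; note it only exposes the forward iterations, not the transposed loop of the backward pass):

import jax

def debug_while_loop(cond_fn, body_fn, init_val, max_iters=100_000):
  # Debug-time stand-in for lax.while_loop: each jitted body call is checked
  # separately when jax_debug_nans is on, instead of the whole loop being a
  # single opaque primitive.
  body = jax.jit(body_fn)
  val = init_val
  for _ in range(max_iters):
    if not bool(cond_fn(val)):
      break
    val = body(val)
  return val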

Never mind, I have written a high-level jaxpr interpreter (based on the documentation): https://github.com/aleatory-science/jaxinterp. I think that will help me debug :)

@mattjj I just wanted to say a huge thank you for this feature.

This just got me unstuck after banging my head against the wall for a week.
