Hi everyone
I'm getting gradients that consist entirely of NaNs in the latest version(s) of jax/jaxlib. The gradients are fine with jax==0.1.59 and jaxlib==0.1.40.
I've tried following some of the tips in this related issue: Tips for debugging NaNs in gradient?
I've made two Colab notebooks to show the issue. The error happens at the bottom of the second notebook during grad compilation/evaluation. The notebooks are exactly the same except for the jax/jaxlib versions.
cloth_grad_working.ipynb (jax==0.1.59, jaxlib==0.1.40)
cloth_grad_nan.ipynb (jax==0.1.66, jaxlib==0.1.47)
I can also confirm that the NaN gradients occur with the current Colab default versions, jax==0.1.64 and jaxlib==0.1.45.
This differentiable cloth simulator looks really cool!
If you're able to simplify the broken code into a minimal example that shows the bug without unnecessary details, that would be very helpful for debugging.
Thanks for looking into this.
Here's a notebook with the same problem, but a lot less clutter:
simpler_repro.ipynb
It's not a minimal repro yet, but I hope that it helps. The error I get now is also slightly more informative:
FloatingPointError: invalid value (nan) encountered in mul
I'll try to create an even more minimal repro this weekend.
@shoyer I've narrowed it down a bit. Here's a minimal reproduction:
import jax
import jax.numpy as np
from jax import grad
from jax.config import config

config.update("jax_debug_nans", True)
config.update('jax_enable_x64', True)

def hooke(pa, pb):
    d = pb - pa
    l = np.linalg.norm(d)
    l = np.max([l, 0.001])
    return l

def loss(p):
    return hooke(p, p)

grad(loss)(np.ones(3))
Colab notebook: minimal_repro.ipynb
The code above raises the following error in jax==0.1.64 but not for jax==0.1.59:
FloatingPointError: invalid value (nan) encountered in mul
The problem seems to be related to taking the norm of a zero vector, so maybe the NaN grad is intended? I'm not sure; I would have expected the gradient to be zero in this case.
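To isolate it, the norm's gradient at the zero vector can be checked on its own; a minimal sketch (without the debug_nans flag, so it just prints the result instead of raising):

import jax
import jax.numpy as np

# The gradient of the 2-norm is x / ||x||, which is 0/0 at the zero vector.
print(jax.grad(np.linalg.norm)(np.zeros(3)))  # prints [nan nan nan]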
Thanks so much for the report, and clear repros! A differentiable cloth simulator is _the coolest_ idea.
Indeed this is known and in a way intentional, though not because we are big jerks who like inflicting nans on users. There's a fundamental array-level-autodiff problem here to do with how nans break linearity assumptions. This comment on #1052 has a more thorough description, and #2447 describes the canonical ways to avoid nans in these situations.
The "array-level" part of "array-level-autodiff" is salient here. Interestingly, using your last minimal repro (which I realize might not be a perfect representative of your real use case) we can avoid the issue by replacing this line
l = np.max([l, 0.001])
with this:
l = max(l, 0.001)
Then the gradient is zero as expected!
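For concreteness, here's a sketch of the repro with just that one-line change (illustration only, nothing else modified):

import jax.numpy as np
from jax import grad

def hooke(pa, pb):
    d = pb - pa
    l = np.linalg.norm(d)
    # Python's builtin max sees a concrete value for l during tracing (no jit here),
    # so it simply returns the constant 0.001, which has a zero gradient.
    l = max(l, 0.001)
    return l

def loss(p):
    return hooke(p, p)

print(grad(loss)(np.ones(3)))  # prints [0. 0. 0.]

This relies on concrete values being available while tracing, so it wouldn't work under jit.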
Otherwise, sticking with array operations rather than Python builtins and control flow, the fix in user code (following the pattern of #2447) would be something like this:
import jax
import jax.numpy as np
from jax import grad
from jax.config import config

config.update("jax_debug_nans", True)
config.update('jax_enable_x64', True)

def hooke(pa, pb):
    assert pa.ndim == pb.ndim == 1  # assuming 1D for simplicity
    d = pb - pa
    is_zero = np.allclose(d, 0.)
    d = np.where(is_zero, np.ones_like(d), d)  # replace d with ones if is_zero
    l = np.linalg.norm(d)
    l = np.where(is_zero, 0., l)  # replace norm with zero if is_zero
    l = np.max([l, 0.001])
    return l

def loss(p):
    return hooke(p, p)

print(grad(loss)(np.ones(3)))
Kinda awkward!
Perhaps a better fix on our end (rather than in user code) would be just to define a custom JVP rule for np.linalg.norm to encode a convention for differentiation at zero, like in this tutorial example. (EDIT: fixed link)
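For illustration, here's a rough user-level sketch of that idea (safe_norm is just a made-up name for this example, not an actual JAX API), picking "zero derivative at the zero vector" as the convention:

import jax
import jax.numpy as np

@jax.custom_jvp
def safe_norm(x):
    return np.linalg.norm(x)

@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    x, = primals
    t, = tangents
    norm = np.linalg.norm(x)
    is_zero = norm == 0.
    denom = np.where(is_zero, 1., norm)                 # dummy denominator to avoid 0/0
    d = np.where(is_zero, np.zeros_like(x), x / denom)  # convention: zero derivative at zero
    return norm, np.sum(d * t)

print(jax.grad(safe_norm)(np.zeros(3)))  # prints [0. 0. 0.]
print(jax.grad(safe_norm)(np.ones(3)))   # prints [0.577... 0.577... 0.577...]

The double-where inside the rule is the same trick as above, so the rule itself never produces a nan even with jax_debug_nans on.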
@Victorlouisdg FYI I had to fix the link to the tutorial just now, because I had linked to the version in my Google drive instead of the one on GitHub.
By the way, I think #2447 is the change that made these nans appear for you.
Interesting! Thanks for the clear explanation. Reading through the threads/tutorial I definitely understand how JAX works a bit better now :)
I was quite surprised that taking the norm of a vector was the source of the nans. It _feels_ like an innocent operation. But it makes sense because norm() uses sqrt() which isn't differentiable at zero.
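For example (a quick sketch, not from the notebooks):

import jax
import jax.numpy as np

# d/dx sqrt(x) = 1 / (2 * sqrt(x)), so the derivative blows up at zero...
print(jax.grad(np.sqrt)(0.0))  # prints inf
# ...and in norm() the chain rule multiplies that inf by 0, which is the nan "encountered in mul".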
I think I'll try to write a norm with a custom_jvp/vjp that's 1 at 0. And if I can't figure that out, I'll just use the workaround, so this issue is resolved for me.
Glad we were able to help! Don't hesitate to open more issues as they come up :)