A similar issue to #764, but for a sigmoid:
import jax
import jax.numpy as np
from jax import grad

print(jax.version.__version__)

x = np.linspace(-200, 0, 4)
sig = lambda x: np.sum(1 / (1 + np.exp(-x)))  # summed logistic sigmoid
f = grad(sig)  # gradient should be ~0 for large negative x and 0.25 at x = 0

print(x)
f(x)
Output:
0.1.35
[-200.         -133.33333333  -66.66666667    0.        ]
DeviceArray([nan, nan, nan, 0.24999999], dtype=float32)
Interesting! It seems this is not a problem with exp or its gradient, but with taking the reciprocal of inf. On my system,
1 / np.array([np.inf, 1., 1., 1.]) returns nan, while 1 / np.array([np.inf, 1., 1.]) returns 0.
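A self-contained version of that check (a minimal sketch; the exact behavior presumably depends on the jaxlib version and the CPU fast-math settings reproduced above):

import jax.numpy as np

# Four elements: hits the vectorized fast-math path; 1/inf reportedly comes out as nan.
print(1 / np.array([np.inf, 1., 1., 1.]))
# Three elements: reportedly avoids that path; 1/inf is 0. as expected.
print(1 / np.array([np.inf, 1., 1.]))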
Thanks for the report! I filed an (internal) bug with the XLA team. I'm guessing this is a bug with vectorization not honoring infs correctly in fast-math mode.
As a workaround, you might try either disabling fast math by setting the environment variable:
XLA_FLAGS=--xla_cpu_enable_fast_math=false
or defining sigmoid using a rescaled tanh:
sigmoid = lambda x: 0.5 + 0.5*np.tanh(0.5*x)
which should work fine already because of the fix to #764.
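For example, a quick check of the tanh-based workaround (a minimal sketch, reusing the inputs from the repro above; the expected values are what the correct gradient should be, not observed output):

import jax.numpy as np
from jax import grad

# Rescaled-tanh formulation of the logistic sigmoid: 0.5 + 0.5*tanh(0.5*x) == 1/(1 + exp(-x))
sigmoid = lambda x: 0.5 + 0.5 * np.tanh(0.5 * x)

x = np.linspace(-200, 0, 4)
print(grad(lambda x: np.sum(sigmoid(x)))(x))
# Expected: finite values close to [0., 0., 0., 0.25], with no nans.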
We updated XLA to a new version that turns off approximate semantics for reciprocals by default. This fixes the problem in this issue, although it requires that you either rebuild Jaxlib from source or wait for us to make a new release. Hope that helps!