Jax: Gradient of `np.exp` sometimes causes invalid values

Created on 27 May 2019 · 3 comments · Source: google/jax

A similar issue to #764 occurs for a sigmoid:

import jax
import jax.numpy as np
from jax import grad
print(jax.version.__version__)
x = np.linspace(-200, 0, 4)
sig = lambda x: np.sum(1/(1+np.exp(-x)))
f = grad(sig)
print(x)
f(x)

Output:

0.1.35
[-200.         -133.33333333  -66.66666667    0.        ]
DeviceArray([       nan,        nan,        nan, 0.24999999],
            dtype=float32)
Label: bug

All 3 comments

Interesting! It seems this is not a problem with exp or its gradient, but with taking the reciprocal of inf. On my system,
1 / np.array([np.inf, 1., 1., 1.]) returns nan, while 1 / np.array([np.inf, 1., 1.]) returns 0.
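
In case it helps, here is that check as a small runnable snippet (a sketch; the nan result only appears on affected jaxlib/XLA builds where CPU fast math is enabled):

import jax.numpy as np

# Reciprocal of inf: the length-4 case hits the vectorized path and may
# return nan on affected builds, while the length-3 case returns 0.
print(1 / np.array([np.inf, 1., 1., 1.]))  # observed here: [nan  1.  1.  1.]
print(1 / np.array([np.inf, 1., 1.]))      # observed here: [0. 1. 1.]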

Thanks for the report! I filed an (internal) bug with the XLA team. I'm guessing this is a bug with vectorization not honoring infs correctly in fast-math mode.

As a workaround, you might try either disabling fast math by setting the environment variable:

XLA_FLAGS=--xla_cpu_enable_fast_math=false

or by defining sigmoid using a rescaled tanh:

sigmoid = lambda x: 0.5 + 0.5*np.tanh(0.5*x)

which should work fine already because of the fix to #764.
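
For reference, a quick check of the tanh-based workaround (a sketch; the exact printed values are an assumption, but the gradient should be finite everywhere):

import jax.numpy as np
from jax import grad

# Rescaled tanh is mathematically identical to the logistic sigmoid,
# but avoids the 1/(1 + exp(-x)) reciprocal that misbehaves here.
sigmoid = lambda x: 0.5 + 0.5 * np.tanh(0.5 * x)
f = grad(lambda x: np.sum(sigmoid(x)))

x = np.linspace(-200, 0, 4)
print(f(x))  # expected: finite values (~[0., 0., 0., 0.25]), no nans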

We updated XLA to a new version that turns off approximate semantics for reciprocals by default. This fixes the problem in this issue, although it requires that you either rebuild Jaxlib from source or wait for us to make a new release. Hope that helps!

