Jax: jitting grad gives wrong result on CPU

Created on 27 Jun 2019 · 4 Comments · Source: google/jax

The following script returns nan when using jit(grad(f)) on CPU (it works on GPU).

from jax import numpy as np, grad, jit
from jax.config import config; config.update('jax_platform_name', 'cpu')  # force the CPU backend
from jax.scipy.special import logsumexp

def f(params):
    # Affine transform of a fixed base point by the 'loc'/'scale' parameters.
    x = params['loc'] + params['scale'] * np.array([-1.5, 0.])
    term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
    term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
    return term1 - logsumexp(term2, axis=-1)

params = {'loc': np.zeros(2), 'scale': np.ones(2)}
# The un-jitted and jitted gradients should agree.
grad(f)(params), jit(grad(f))(params)

which returns

({'loc': DeviceArray([4.51388836, 0.       ], dtype=float32),
  'scale': DeviceArray([-6.77083254,  0.       ], dtype=float32)},
 {'loc': DeviceArray([nan, -0.], dtype=float32),
  'scale': DeviceArray([nan,  0.], dtype=float32)})

All 4 comments

This seems related to fast-math mode. If you set the environment variable XLA_FLAGS=--xla_cpu_enable_fast_math=false, both calls give identical outputs.
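
A minimal sketch of applying that workaround from within Python (assuming the flag is set before jax/jaxlib is first imported, so it reaches the XLA CPU backend; exporting XLA_FLAGS in the shell before launching the interpreter works as well):

import os
# Disable XLA's fast-math optimizations on CPU; this must happen before
# jax/jaxlib is imported.
os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'

from jax import numpy as np, grad, jit
from jax.config import config; config.update('jax_platform_name', 'cpu')
from jax.scipy.special import logsumexp

def f(params):
    x = params['loc'] + params['scale'] * np.array([-1.5, 0.])
    term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
    term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
    return term1 - logsumexp(term2, axis=-1)

params = {'loc': np.zeros(2), 'scale': np.ones(2)}
print(grad(f)(params))       # reference gradients
print(jit(grad(f))(params))  # with fast-math disabled, these should match (no nan)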

Interestingly, it also doesn't reproduce with a jaxlib built from head. I bisected the change to https://github.com/tensorflow/tensorflow/commit/af4eb9c864563a98cd12c2f731b06b722f17141d and I'm following up with the XLA folks to see whether we should declare this fixed or whether it is masking an underlying problem.

We haven't seen this issue in a long time and it doesn't reproduce at head. Closing.

Thanks, @hawkinsp! This seems to have been resolved for a while.
