The following script returns NaN when using jit(grad(f)) on CPU (it works on GPU).
from jax import numpy as np, grad, jit
from jax.config import config; config.update('jax_platform_name', 'cpu')
from jax.scipy.special import logsumexp
def f(params):
    x = params['loc'] + params['scale'] * np.array([-1.5, 0.])
    term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
    term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
    return term1 - logsumexp(term2, axis=-1)
params = {'loc': np.zeros(2), 'scale': np.ones(2)}
grad(f)(params), jit(grad(f))(params)
which returns
({'loc': DeviceArray([4.51388836, 0.], dtype=float32),
  'scale': DeviceArray([-6.77083254, 0.], dtype=float32)},
 {'loc': DeviceArray([nan, -0.], dtype=float32),
  'scale': DeviceArray([nan, 0.], dtype=float32)})
This seems related to fast-math mode. If you set the environment variable XLA_FLAGS=--xla_cpu_enable_fast_math=false, both calls give identical outputs.
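For reference, a minimal sketch of that workaround (assuming the flag named above; setting it via os.environ is just one way to do it): the variable has to be in the environment before jax, and therefore jaxlib/XLA, is first imported, so it needs to go at the very top of the script.

# Workaround sketch: disable XLA CPU fast-math before importing jax.
# The flag must be set before jax/jaxlib initialize XLA.
import os
os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'

from jax import numpy as np, grad, jit
# ... define f and params as above; grad(f)(params) and jit(grad(f))(params)
# should then agree on CPU.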
Interestingly, it also doesn't reproduce with a jaxlib built from head. I bisected the change to https://github.com/tensorflow/tensorflow/commit/af4eb9c864563a98cd12c2f731b06b722f17141d and I'm following up with the XLA folks to see whether we should declare this fixed or whether it is masking an underlying problem.
We haven't seen this issue in a long time and it doesn't reproduce at head. Closing.
Thanks, @hawkinsp! This has been resolved for a while now.