Jax: np.isnan doesn't work on CPU in fast-math mode

Created on 25 Jan 2019 · 12 comments · Source: google/jax

As a result, np.nan_to_num, np.nanmean, etc. don't work either.

import jax.numpy as np
a = np.zeros(1) / np.zeros(1)
print(a.__array__())
print(np.isnan(a).__array__())

[nan]
[False]
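For comparison, plain NumPy (which does not use fast math) handles the same computation as expected:

```python
import numpy as np  # standard NumPy, not jax.numpy

# 0/0 produces nan (NumPy emits a RuntimeWarning here)
a = np.zeros(1) / np.zeros(1)
print(a)            # [nan]
print(np.isnan(a))  # [ True]
```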

This bug only happens with the CPU-only build of JAX, i.e. when I see this warning: "warnings.warn('No GPU found, falling back to CPU.')"

question

Most helpful comment

We just pushed out jaxlib 0.1.13, which should fix this problem.

Parts of fast math are still enabled by default for performance, but the semantics of NaNs and Infs should now be honored. Please file a new issue if you see any further problems!

All 12 comments

This happens because XLA's CPU backend enables fast math mode by default, which does not preserve nan/inf semantics; the GPU backend does not enable it. Note the comment here:
https://github.com/google/jax/blob/master/jax/numpy/lax_numpy.py#L699

# Caution: If fast math mode is enabled, the semantics of inf and nan are not
# preserved by XLA/LLVM, and the behavior of inf/nan values is unpredictable.
# To disable fast math mode on CPU, set the environment variable
# XLA_FLAGS=--xla_cpu_enable_fast_math=false

The XLA_FLAGS environment variable above makes your example pass.
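For example, you can set the flag in the shell before launching Python (the flag name comes from the comment above; `your_script.py` is a placeholder):

```shell
XLA_FLAGS=--xla_cpu_enable_fast_math=false python your_script.py
```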

I guess the important question is: should we disable fast math mode by default? Are exact NaN semantics important to you?
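For context, the exact IEEE 754 NaN semantics that fast math relaxes can be illustrated in plain Python, with no JAX involved:

```python
import math

x = float("nan")
print(x == x)         # False: IEEE 754 says NaN compares unequal to everything, including itself
print(math.isnan(x))  # True: isnan relies on that property
# Under fast math, the compiler may assume no NaNs exist, so checks like these can silently break.
```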

I think consistency between CPU and GPU is more important than performance in this case. There can still be a _performance tips_ section that explains how to activate _fast math_.

I just got surprised by this, too. Maybe another option is to print a warning at startup, adding to the "No GPU found"?

A brief update on this bug: we tried disabling fast math in XLA/CPU by default, but found that it significantly regressed performance on some neural network benchmarks because it prevents vectorization in some important cases.

https://reviews.llvm.org/D57728 apparently fixes the performance problem, but it isn't in yet. I'm hoping we can simply disable fast math by default when that change makes it into LLVM.

A warning makes sense until we do so, I guess.

I also got surprised by this (I am using a CPU). Here is a simple example:

import numpy as onp # original numpy
import jax.numpy as np
print(np.isnan(np.nan)) # False (wrong)
print(onp.isnan(np.nan)) # True
print(np.isnan(onp.nan)) # False (wrong)
print(onp.isnan(onp.nan)) # True

Maybe worth mentioning the issue on the jax homepage (the comment is currently buried deep in the gotchas colab)

I also tried to set the environment variable, but to no avail (is my syntax correct?):

import os
os.environ["XLA_FLAGS"]="--xla_cpu_enable_fast_math=false"

print(np.isnan(np.nan)) # False (wrong)
print(onp.isnan(np.nan)) # True
print(np.isnan(onp.nan)) # False (wrong)
print(onp.isnan(onp.nan)) # True

Did that os.environ come before importing anything from jax? That might be necessary.

Great idea re: mentioning it in the readme. I'll add it now.

Yes, I did the os.environ thing first. I am running inside the Spyder IDE.
Full script:

import os
os.environ["XLA_FLAGS"]="--xla_cpu_enable_fast_math=false"

import numpy as onp # original numpy
import jax.numpy as np

print(np.isnan(np.nan)) # False (wrong)
print(onp.isnan(np.nan)) # True
print(np.isnan(onp.nan)) # False (wrong)
print(onp.isnan(onp.nan)) # True

Thanks. Hrm I was unable to repro in my local environment (which I tried before my previous guess about os.environ going first):


In [1]: import os

In [2]: os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"

In [3]: import jax.numpy as np

In [4]: print(np.isnan(np.nan))
True

Not sure how to chase that down further. In any case, we'll fix CPU nan issues ASAP.

Weird. I am using Python 3.7 and jax 0.1.23 (latest pip binary).


We just pushed out jaxlib 0.1.13, which should fix this problem.

Parts of fast math are still enabled by default for performance, but the semantics of NaNs and Infs should now be honored. Please file a new issue if you see any further problems!
