Jax: np.exp does not return inf on overflow in fast math mode

Created on 2 Mar 2019 · 8 Comments · Source: google/jax

For example, np.exp(100) will return 100 while it should return inf.
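For context on why inf is the right answer: JAX uses 32-bit floats by default (note dtype=float32 in the session below), and exp(x) exceeds the largest finite float32 once x passes roughly 88.72, so exp(100) cannot be represented. A minimal stdlib-only sketch of that arithmetic follows; the names FLOAT32_MAX, overflow_threshold and exp_overflows_float32 are illustrative, not part of any library.

```python
import math

# Largest finite IEEE 754 single-precision value: (2 - 2**-23) * 2**127, about 3.4028235e38
FLOAT32_MAX = (2 - 2**-23) * 2**127

# exp(x) exceeds FLOAT32_MAX once x is larger than ln(FLOAT32_MAX), about 88.72
overflow_threshold = math.log(FLOAT32_MAX)


def exp_overflows_float32(x: float) -> bool:
    """True if exp(x) is larger than the largest finite float32.

    Approximate: ignores the half-ulp margin that round-to-nearest allows
    just above FLOAT32_MAX before a result actually rounds to inf.
    """
    return x > overflow_threshold


print(f"threshold = {overflow_threshold:.4f}")
print(exp_overflows_float32(100.0))  # exp(100) is about 2.69e43, far beyond float32 range
```

Under this reasoning a correct float32 exp(100.) should come back as inf, which is what the fast-math-disabled session shows.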

bug


That's a bit awkward. I'll take a look.

Sorry @hawkinsp! It seems the problem went away after I rebuilt jax and switched to the master version.

I can actually reproduce this, though. This is fast-math mode in the CPU backend, rearing its ugly head once again. It works fine with fast math mode disabled:

$ XLA_FLAGS=--xla_cpu_enable_fast_math=false ipython
Python 3.7.2 (default, Jan 16 2019, 11:36:28)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.2.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: from jax import numpy as np

In [2]: np.exp(100.)
/Users/phawkins/p/jax/jax/lib/xla_bridge.py:167: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
Out[2]: array(inf, dtype=float32)

So I think the right fix is to disable fast math mode by default, something also tracked in #276.

Thanks for confirming! At first I thought it was a dtype mix-up in my previous build. I just hit this issue again. :(

Note that you can work around this by setting the environment variable XLA_FLAGS=--xla_cpu_enable_fast_math=false.
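If you would rather apply the workaround from inside a Python script than from the shell, one possible approach (a sketch, not an official JAX API) is to put the flag into os.environ before jax is first imported, on the assumption that XLA reads XLA_FLAGS from the environment when its backend initializes. The helper name disable_cpu_fast_math is hypothetical; it appends the flag while keeping any XLA_FLAGS already set. Per the comment below, this should only matter on jaxlib releases before the 0.1.13 fix.

```python
import os

# The flag from the workaround above.
FAST_MATH_OFF = "--xla_cpu_enable_fast_math=false"


def disable_cpu_fast_math(env=os.environ):
    """Hypothetical helper: add FAST_MATH_OFF to XLA_FLAGS, preserving existing flags."""
    flags = env.get("XLA_FLAGS", "").split()
    if FAST_MATH_OFF not in flags:  # avoid duplicating the flag on repeated calls
        flags.append(FAST_MATH_OFF)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]


# Assumption: this must run before jax is imported for the flag to take effect.
disable_cpu_fast_math()
# from jax import numpy as np
# np.exp(100.)  # expected to give inf with fast math disabled
```

Running the script with the variable set in the shell, as in the session above, is the simpler route when you control how Python is launched.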

Got it! Thanks @hawkinsp!

This was fixed as part of a recent Jaxlib update (0.1.13). Let us know if you see more problems.

Thanks for letting me know, @hawkinsp!
