For example, np.exp(100) will return 100 while it should return inf.
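To see why `inf` is the correct answer, compare e^100 with the largest finite float32 value. The sketch below uses plain NumPy as a reference rather than JAX, on the assumption that the default JAX dtype here is float32 (as the `dtype=float32` output later in this thread shows):

```python
import numpy as np

# Largest finite float32 value, roughly 3.4e38.
FLOAT32_MAX = np.finfo(np.float32).max

# e**100 is roughly 2.69e43. float64 can hold it, float32 cannot.
exact = np.exp(np.float64(100.0))

# A correct float32 exp must therefore overflow to +inf.
with np.errstate(over="ignore"):  # silence NumPy's overflow RuntimeWarning
    result32 = np.exp(np.float32(100.0))

print(exact > FLOAT32_MAX)  # True
print(result32)             # inf
```

Any finite result for a float32 `exp(100.)`, including the `100` reported above, means the overflow was lost somewhere.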
That's a bit awkward. I'll take a look.
Sorry @hawkinsp! It seems the problem is resolved after I rebuilt jax and switched to the master version.
I can actually reproduce this, though. This is fast-math mode in the CPU backend, rearing its ugly head once again. It works fine with fast math mode disabled:
```
$ XLA_FLAGS=--xla_cpu_enable_fast_math=false ipython
Python 3.7.2 (default, Jan 16 2019, 11:36:28)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.2.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: from jax import numpy as np
In [2]: np.exp(100.)
/Users/phawkins/p/jax/jax/lib/xla_bridge.py:167: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
Out[2]: array(inf, dtype=float32)
```
So I think the right fix is to disable fast math mode by default, something also tracked in #276.
Thanks for confirming! At first I thought it was a dtype mix-up in my previous build. I just hit this issue again. :(
Note that you can work around this by setting the environment variable XLA_FLAGS=--xla_cpu_enable_fast_math=false.
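If setting the variable on the command line is inconvenient, the same workaround can be applied from inside a script. This is a minimal sketch, assuming XLA reads `XLA_FLAGS` when JAX initializes its backend, so the variable has to be set before `jax` is imported. `add_xla_flag` is a hypothetical helper, not a JAX API; it appends the flag while keeping any other flags already present. Whether newer XLA releases still accept this particular flag is not established by this thread.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Hypothetical helper: add `flag` to XLA_FLAGS without dropping other flags.

    Any earlier value for the same flag name is replaced, so calling this
    twice does not leave conflicting settings behind.
    """
    name = flag.split("=", 1)[0]
    # Keep every existing flag except previous settings of this one.
    kept = [f for f in os.environ.get("XLA_FLAGS", "").split()
            if f.split("=", 1)[0] != name]
    os.environ["XLA_FLAGS"] = " ".join(kept + [flag])
    return os.environ["XLA_FLAGS"]


# Must run before `import jax` (assumption: flags are read at backend init).
add_xla_flag("--xla_cpu_enable_fast_math=false")
# from jax import numpy as np
```

Replacing rather than duplicating the flag keeps the final setting unambiguous if some other tool has already put `--xla_cpu_enable_fast_math=true` into the environment.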
Got it! Thanks @hawkinsp!
This was fixed as part of a recent Jaxlib update (0.1.13). Let us know if you see more problems.
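To check whether an installed jaxlib behaves correctly, a small regression check along these lines may help. `exp_overflows_correctly` is a hypothetical helper, not part of JAX. It assumes float32 inputs (JAX's default precision) and also checks an in-range value, so a backend that returns `inf` everywhere does not pass. NumPy serves as the reference implementation:

```python
import math
from typing import Callable

import numpy as np


def exp_overflows_correctly(exp_fn: Callable) -> bool:
    """Hypothetical check: True if exp_fn overflows to +inf for float32 100.

    Also requires exp(1) to be close to e, so a function that returns inf
    for every input is rejected.
    """
    with np.errstate(over="ignore"):  # NumPy warns on overflow; harmless otherwise
        big = float(exp_fn(np.float32(100.0)))
    small = float(exp_fn(np.float32(1.0)))
    return math.isinf(big) and big > 0 and math.isclose(small, math.e, rel_tol=1e-6)


print(exp_overflows_correctly(np.exp))  # True for the NumPy reference
# With JAX installed: exp_overflows_correctly(jax.numpy.exp)
```

A backend with the bug described at the top of this thread, returning `100` for `exp(100.)`, would make this return `False`.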
Thanks for letting me know, @hawkinsp!