# Evaluate the standard normal CDF at +inf (jax.scipy on the CPU backend)
from jax.scipy.stats import norm
import jax.numpy as np
print(norm.cdf(np.array([np.inf])))
print(norm.cdf(np.array([np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf, np.inf, np.inf])))
The code above breaks for arrays of 4 or more elements on CPU:
[1.]
[1. 1.]
[1. 1. 1.]
[nan nan nan nan]
[nan nan nan nan nan]
For comparison,
from scipy.stats import norm
import numpy as np
print(norm.cdf(np.array([np.inf])))
print(norm.cdf(np.array([np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf, np.inf])))
print(norm.cdf(np.array([np.inf, np.inf, np.inf, np.inf, np.inf])))
gives the expected result:
[1.]
[1. 1.]
[1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
Ah, my old friend: fast math mode.
If you set the environment variable XLA_FLAGS=--xla_cpu_enable_fast_math=false, the NaNs go away. I'll need to dig into it a bit further to figure out why, but the size of 4 strongly suggests an issue with vectorization.
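For anyone who wants to apply that workaround from inside Python rather than the shell, here is a minimal sketch. It assumes XLA_FLAGS only takes effect if it is set before jax is imported for the first time, since XLA reads the flag when the CPU backend is initialized:

import os
# Disable XLA's fast-math mode on CPU before importing jax
# (assumption: XLA_FLAGS is read once, at backend initialization).
os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"

import jax.numpy as jnp
from jax.scipy.stats import norm

# The 4-element case that produced NaNs above should now print [1. 1. 1. 1.]
print(norm.cdf(jnp.array([jnp.inf, jnp.inf, jnp.inf, jnp.inf])))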
We updated XLA to a new version that turns off approximate semantics for reciprocals by default. This fixes the problem reported in this issue, although it requires that you either rebuild jaxlib from source or wait for us to make a new release. Hope that helps!