With JAX_ENABLE_X64=1 on CPU,

```python
import numpy as np

from scipy.stats import norm
print(norm.cdf(np.array(1, 'float64')))

from jax.scipy.stats import norm
print(norm.cdf(np.array(1, 'float64')))
```
gives

```
0.8413447460685429
0.8413447251486219
```
The accuracy appears to be only that of float32, which is causing problems for me. Shouldn't the results agree much more closely, i.e. within a 1e-12 tolerance?
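For reference, a quick way to quantify the discrepancy (assuming JAX_ENABLE_X64=1 is set before jax is imported):

```python
import numpy as np
from scipy.stats import norm as scipy_norm
from jax.scipy.stats import norm as jax_norm

x = np.array(1, 'float64')
# This raises AssertionError: the difference is ~2e-8,
# i.e. float32-level accuracy rather than the requested 1e-12.
np.testing.assert_allclose(jax_norm.cdf(x), scipy_norm.cdf(x), rtol=1e-12)
```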
The underlying issue is that `erf` and `erfc` from `jax.lax` are inaccurate for float64 inputs when JAX_ENABLE_X64=1:

```python
import numpy as np

from jax.lax import erf, erfc
print(erf(np.array(1, 'float64')), erfc(np.array(1, 'float64')))

from scipy.special import erf, erfc
print(erf(np.array(1, 'float64')), erfc(np.array(1, 'float64')))
```
gives

```
0.8427007686151753 0.15729926461790456
0.8427007929497148 0.15729920705028516
```
The SciPy values agree with the true values of erf(1) and erfc(1) to full double precision and can easily be verified.
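For example, one verification is to integrate the defining integral, erf(x) = 2/sqrt(pi) * ∫₀ˣ exp(-t²) dt, numerically:

```python
import numpy as np
from scipy.integrate import quad

# erf(1) = 2/sqrt(pi) * integral_0^1 exp(-t^2) dt
val, _ = quad(lambda t: np.exp(-t * t), 0.0, 1.0)
print(2.0 / np.sqrt(np.pi) * val)  # 0.8427007929497149 -> matches scipy, not jax.lax
```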
I think XLA's implementation of erf uses a polynomial approximation that isn't accurate enough for doubles:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/math.cc#L192
Given a better expansion, we could either improve that XLA implementation or write our own expansion in Python inside JAX; there's nothing special about the C++ implementation.
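As an illustration of the Python route, here's a minimal sketch (my own example, not JAX's actual fix; `erf_series` is a hypothetical helper) of a double-precision erf via its Maclaurin series. It only covers modest arguments; a real implementation would switch to a continued-fraction or asymptotic expansion for large |x|:

```python
import numpy as np

def erf_series(x, terms=30):
    # Maclaurin series: erf(x) = 2/sqrt(pi) * sum_{n>=0} (-1)^n x^(2n+1) / (n! (2n+1)).
    # 30 terms suffice for roughly full double precision when |x| <= ~2;
    # larger arguments would need a different expansion.
    acc = 0.0
    term = x  # holds (-1)^n x^(2n+1) / n!
    for n in range(terms):
        acc += term / (2 * n + 1)
        term *= -x * x / (n + 1)
    return 2.0 / np.sqrt(np.pi) * acc

print(erf_series(1.0))  # agrees with scipy.special.erf(1.0) to double precision
```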
@JuliusKunze The bug was fixed on XLA's side.
@hawkinsp I think this issue can be closed.
I just tested it with the stable versions of jax/jaxlib, and the results are unchanged. Has the fix been deployed yet?
The code for the XLA compiler is located in the TensorFlow repo. The change has already been submitted (https://github.com/tensorflow/tensorflow/commit/5aefea138fcb992221e57f5400a39b578f2b2b87).
@JuliusKunze This PR bumps the version of XLA: https://github.com/google/jax/pull/1182. Try it.
Ah, cool! As far as I understand, I would need to build jaxlib locally? Since it's not urgent, I'll try it once the PR is merged and released!
Yes, that's right. You either need to build jaxlib yourself or wait for us to make a binary wheel release. I merged PR #1182, so closing this bug.