Jax: JAX_ENABLE_X64 norm.cdf inconsistent with scipy

Created on 8 Jul 2019 · 8 Comments · Source: google/jax

With JAX_ENABLE_X64=1 on CPU, running

import numpy as np

from scipy.stats import norm
print(norm.cdf(np.array(1, 'float64')))

from jax.scipy.stats import norm
print(norm.cdf(np.array(1, 'float64')))

gives

0.8413447460685429
0.8413447251486219

The accuracy seems to be that of float32, which is causing problems for me. Shouldn't the results be much closer, i.e. agree within a tolerance of 1e-12?
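As a rough sanity check of the "float32 accuracy" observation, the snippet below uses only the two numbers printed above and compares their relative difference (about 2.5e-8) with the machine epsilons of float32 and float64. It is a sketch that assumes the scipy value is the correct one; it does not call either library.

```python
# Values printed in the report above: scipy first, JAX (JAX_ENABLE_X64=1) second.
scipy_val = 0.8413447460685429
jax_val = 0.8413447251486219

abs_err = abs(scipy_val - jax_val)
rel_err = abs_err / abs(scipy_val)

float32_eps = 2.0 ** -23  # machine epsilon of IEEE 754 single precision
float64_eps = 2.0 ** -52  # machine epsilon of IEEE 754 double precision

print(f"absolute error: {abs_err:.3e}")
print(f"relative error: {rel_err:.3e}")
print(f"relative error in units of float32 eps: {rel_err / float32_eps:.2f}")
print(f"relative error in units of float64 eps: {rel_err / float64_eps:.3e}")
```

The discrepancy is a fraction of one float32 epsilon but on the order of 1e8 float64 epsilons, which is what you would expect from an approximation tuned for single precision.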

bug

All 8 comments

The underlying issue is that `erf` and `erfc` from `lax` are inaccurate with JAX_ENABLE_X64=1:

import numpy as np

from jax.lax import erf, erfc
print(erf(np.array(1, 'float64')), erfc(np.array(1, 'float64')))

from scipy.special import erf, erfc
print(erf(np.array(1, 'float64')), erfc(np.array(1, 'float64')))

gives

0.8427007686151753 0.15729926461790456
0.8427007929497148 0.15729920705028516

The scipy values can easily be verified.
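One way to verify them, assuming only the Python standard library, is to sum the Maclaurin series erf(x) = 2/sqrt(pi) * sum_n (-1)^n x^(2n+1) / (n! (2n+1)), which converges quickly at x = 1. The helper name `erf_series` is hypothetical; `math.erf` is used as a second reference.

```python
import math

def erf_series(x, n_terms=30):
    """Maclaurin series for erf; accurate in double precision for small |x| (e.g. x = 1)."""
    total = 0.0
    for n in range(n_terms):
        total += (-1) ** n * x ** (2 * n + 1) / (math.factorial(n) * (2 * n + 1))
    return 2.0 / math.sqrt(math.pi) * total

reference = erf_series(1.0)

scipy_erf = 0.8427007929497148  # scipy.special.erf(1.0), as printed above
jax_erf = 0.8427007686151753    # jax.lax.erf(1.0), as printed above (before the fix)

print("series:", reference, "math.erf:", math.erf(1.0))
print("scipy erf error:", abs(scipy_erf - reference))
print("jax erf error:  ", abs(jax_erf - reference))
print("scipy erfc error:", abs(0.15729920705028516 - (1.0 - reference)))
```

The scipy values agree with the series to within double-precision rounding, while the JAX values are off by roughly 2.4e-8.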

I think XLA's implementation of erf uses a polynomial that isn't really appropriate for doubles.

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/math.cc#L192

If we had a better expansion, we could either improve that XLA implementation or write our own expansion in JAX, in Python (there's nothing special about the C++ implementation).
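To illustrate the "write our own expansion in JAX" option, here is a minimal sketch of a float64 `erf` built from the Maclaurin series. It is not XLA's fix and not a production routine: the function name `erf_taylor` and the term count are assumptions, and the series is only accurate for small |x| (roughly |x| <= 2). A real implementation would switch to an erfc-style rational or asymptotic approximation for larger arguments.

```python
import jax

# Enable float64 before any arrays are created (equivalent to JAX_ENABLE_X64=1).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def erf_taylor(x, n_terms=40):
    """Hypothetical float64 erf via erf(x) = 2/sqrt(pi) * sum_n (-1)^n x^(2n+1) / (n! (2n+1)).

    Only intended for small |x|; accuracy degrades for large arguments.
    """
    x = jnp.asarray(x, dtype=jnp.float64)
    term = x    # n = 0 term of (-1)^n x^(2n+1) / n!
    total = x   # running sum of term / (2n + 1)
    for n in range(1, n_terms):
        term = term * (-x * x) / n  # update to (-1)^n x^(2n+1) / n!
        total = total + term / (2 * n + 1)
    return 2.0 / jnp.sqrt(jnp.pi) * total


print(erf_taylor(jnp.array([0.0, 0.5, 1.0])))
```

Because the loop is unrolled at trace time, the function also works under `jax.jit` and with array inputs.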

@JuliusKunze The bug was fixed on XLA's side.
@hawkinsp I think this issue can be closed.

I just tested it with the stable version of jax/jaxlib, and the results are unchanged. Is the fix still to be deployed?

The code for the XLA compiler lives in the TensorFlow repo. The change has already been submitted (https://github.com/tensorflow/tensorflow/commit/5aefea138fcb992221e57f5400a39b578f2b2b87).

@JuliusKunze https://github.com/google/jax/pull/1182 this PR bumps the version of XLA. Try it.

Ah cool! As far as I understand, I would need to build jaxlib locally? Since it's not urgent, I will try once the PR is merged and deployed!

Yes, that's right. You either need to build jaxlib yourself or wait for us to make a binary wheel release. I merged PR #1182, so closing this bug.
