Jax: Cholesky behavior mismatch between jax and scipy

Created on 27 May 2019 · 4 Comments · Source: google/jax

JAX's (or XLA's?) implementation of Cholesky returns bad values when applied to a non-PSD matrix. I would expect the JAX implementation to throw an error similar to the SciPy error.

Code:

import jax.numpy as np

x = np.arange(9).reshape((3,3))

from jax.scipy import linalg
print(linalg.cholesky(x))

print("From Scipy")

from scipy import linalg
print(linalg.cholesky(x))

Output:

/home/luke/.local/lib/python3.5/site-packages/jax/scipy/linalg.py:36: UserWarning: scipy.linalg support is experimental and may cause silent failures or wrong outputs
  warnings.warn(_EXPERIMENTAL_WARNING)
[[0. 1. 2.]
 [0. 4. 5.]
 [0. 0. 8.]]
From Scipy
Traceback (most recent call last):
  File "tmp.py", line 11, in <module>
    print(linalg.cholesky(x))
  File "/home/luke/.local/lib/python3.5/site-packages/scipy/linalg/decomp_cholesky.py", line 91, in cholesky
    check_finite=check_finite)
  File "/home/luke/.local/lib/python3.5/site-packages/scipy/linalg/decomp_cholesky.py", line 40, in _cholesky
    "definite" % info)
numpy.linalg.LinAlgError: 1-th leading minor of the array is not positive definite
bug

All 4 comments

The difficulty here is that JAX currently has no way to report an error from the middle of a JIT-ted computation. So it's not easy to exactly mimic the numpy behavior.

What I could easily do is change JAX to return a matrix of NaNs in the event that the input matrix is not PSD. Would that be good enough?

I vote for returning NaN or keeping the current behavior (returning an invalid result) rather than throwing errors. I have some use cases where I want the script to keep running when it gets NaN. For example, it will decrease the learning rate when the initial learning rate leads to a NaN loss.
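
A minimal sketch of the fallback pattern described above, in plain Python. The names `run_with_backoff` and `toy_step` are hypothetical and only stand in for a real training step; the point is that a NaN result can be checked and handled by the caller, whereas an exception would interrupt the run.

```python
import math


def run_with_backoff(train_step, initial_lr, factor=0.5, max_tries=8):
    """Call train_step(lr) until it returns a non-NaN loss, shrinking lr each time.

    train_step is a hypothetical callable that returns a float loss.
    """
    lr = initial_lr
    for _ in range(max_tries):
        loss = train_step(lr)
        if not math.isnan(loss):
            return lr, loss
        # NaN loss: keep the script running and retry with a smaller step.
        lr *= factor
    raise RuntimeError("loss stayed NaN after %d tries" % max_tries)


def toy_step(lr):
    # Stand-in for a real training step: pretend large learning rates diverge.
    return float("nan") if lr > 0.1 else 2.0 * lr


if __name__ == "__main__":
    print(run_with_backoff(toy_step, 1.0))
```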

+1 to nan.

As for my use case, I was doing something dumb and the wrong values were tricky for me to debug. A NaN output would have pointed me in the right direction much faster.

Thanks!

#785 should have mostly fixed this issue on CPU. You may still experience the same problem on other backends.

Note that currently Cholesky returns an upper-triangular NaN matrix if the input is not PSD (the other triangle will be 0).
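
For anyone who wants SciPy-like errors on top of the NaN behavior, here is a hedged sketch of a wrapper that inspects the result outside of any JIT-ted code. The name `checked_cholesky` is hypothetical and not part of JAX. It assumes the backend in use reports failure through NaNs as described above, which may not hold on every backend, and it will also reject inputs that already contain NaN.

```python
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg


def checked_cholesky(a, lower=False):
    """Cholesky factor of `a`, raising ValueError if the result contains NaN.

    Hypothetical helper: the check pulls a concrete boolean back to Python,
    so it must be called outside jax.jit, not from inside a traced function.
    """
    factor = jsp_linalg.cholesky(a, lower=lower)
    if bool(jnp.isnan(factor).any()):
        raise ValueError("Cholesky returned NaN; input is probably not positive definite")
    return factor


if __name__ == "__main__":
    good = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
    print(checked_cholesky(good))

    bad = jnp.array([[1.0, 2.0], [2.0, 1.0]])   # eigenvalues 3 and -1
    try:
        checked_cholesky(bad)
    except ValueError as err:
        print("rejected:", err)
```

Keeping the check in Python rather than inside the compiled computation sidesteps the limitation mentioned earlier in the thread: the JIT-ted part only ever produces values, and the decision to raise happens afterwards.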
