JAX's (or XLA's?) implementation of Cholesky returns bad values when applied to a non-PSD matrix. I would expect the JAX implementation to throw an error similar to the SciPy error.
Code:
import jax.numpy as np
x = np.arange(9).reshape((3, 3))
from jax.scipy import linalg
print(linalg.cholesky(x))
print("From Scipy")
from scipy import linalg  # shadows the jax.scipy import above
print(linalg.cholesky(x))
Output:
/home/luke/.local/lib/python3.5/site-packages/jax/scipy/linalg.py:36: UserWarning: scipy.linalg support is experimental and may cause silent failures or wrong outputs
warnings.warn(_EXPERIMENTAL_WARNING)
[[0. 1. 2.]
[0. 4. 5.]
[0. 0. 8.]]
From Scipy
Traceback (most recent call last):
File "tmp.py", line 11, in <module>
print(linalg.cholesky(x))
File "/home/luke/.local/lib/python3.5/site-packages/scipy/linalg/decomp_cholesky.py", line 91, in cholesky
check_finite=check_finite)
File "/home/luke/.local/lib/python3.5/site-packages/scipy/linalg/decomp_cholesky.py", line 40, in _cholesky
"definite" % info)
numpy.linalg.LinAlgError: 1-th leading minor of the array is not positive definite
The difficulty here is that JAX currently has no way to report an error from the middle of a JIT-ted computation. So it's not easy to exactly mimic the numpy behavior.
What I could easily do is change JAX to return a matrix of NaNs in the event that the input matrix is not PSD. Would that be good enough?
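If the factorization returned NaNs on failure, callers could detect it with a single reduction. A minimal sketch in plain NumPy (the same `isnan`/`any` functions exist in `jax.numpy`; the NaN-filled `bad` matrix below is a hand-built stand-in for the proposed behavior, not actual JAX output):

```python
import numpy as np

def factorization_failed(c):
    # Under the proposed NaN convention, a non-PSD input fills the
    # factor with NaNs, so one isnan/any reduction flags the failure.
    return bool(np.isnan(c).any())

good = np.linalg.cholesky(np.eye(3))  # identity is PSD, factorization succeeds
bad = np.full((3, 3), np.nan)         # simulated failed factor
```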
I vote for returning NaN or keeping the current behavior (returning an invalid result), rather than throwing errors. I have some use cases where I want the script to keep running when it gets NaN. For example, it can decrease the learning rate when the initial learning rate leads to a NaN loss.
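That learning-rate fallback can be sketched as follows (the `nan_guard_step` helper is hypothetical, not part of JAX or any library):

```python
import math

def nan_guard_step(loss, lr, decay=0.5):
    # Hypothetical helper: if a training step produced a NaN loss,
    # shrink the learning rate and signal the caller to retry,
    # instead of aborting the whole run with an exception.
    if math.isnan(loss):
        return lr * decay, True
    return lr, False

lr, retry = nan_guard_step(float("nan"), 0.1)  # halves lr, asks for a retry
```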
+1 to nan.
As per my use case, I was doing something dumb and the wrong values were tricky for me to debug. A NaN output would have pointed me in the right direction much faster.
Thanks!
Note that currently Cholesky returns an upper-triangular NaN matrix if the input is not PSD (the other triangle will be 0).
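That pattern can be checked cheaply without eyeballing the whole matrix. A sketch in plain NumPy that constructs the described shape by hand (the actual JAX output may differ by version):

```python
import numpy as np

# Hand-built stand-in for the behaviour described above: NaNs in the
# upper triangle (including the diagonal), zeros strictly below it.
result = np.triu(np.full((3, 3), np.nan))

upper_is_nan = bool(np.isnan(result[np.triu_indices(3)]).all())
lower_is_zero = bool((result[np.tril_indices(3, k=-1)] == 0).all())
```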