```python
from jax.scipy.stats import multivariate_normal as jmvn
from jax.scipy.linalg import inv, det
import jax.numpy as jnp
import numpy as onp
from scipy.stats import multivariate_normal as omvn

# Case 1: a nearly singular covariance matrix.
x = jnp.array([296.859985, 299.667206])
mean = jnp.array([208.274643, 210.701645])
cov = jnp.array([[2332.056396, 2307.901855],
                 [2307.901855, 2283.997559]])
print("with jax")
print(f"logpdf : {jmvn.logpdf(x, mean, cov)}, cov inverse : {inv(cov)}")
print("without jax")
print(omvn.logpdf(x, mean, cov), onp.linalg.inv(cov))

# Case 2: same inputs, but the off-diagonal entries are reduced by 1.0,
# which moves the matrix well away from singularity.
x = jnp.array([296.859985, 299.667206])
mean = jnp.array([208.274643, 210.701645])
cov = jnp.array([[2332.056396, 2306.901855],
                 [2306.901855, 2283.997559]])
print("with jax")
print(f"logpdf : {jmvn.logpdf(x, mean, cov)}, cov inverse : {inv(cov)}")
print("without jax")
print(omvn.logpdf(x, mean, cov), onp.linalg.inv(cov))
```
The issue here is 32-bit vs. 64-bit floating-point computation. The first matrix is very close to singular: in 32-bit arithmetic its determinant is indistinguishable from zero, while in 64-bit it is small but nonzero.
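As a rough illustration, the sketch below uses NumPy only (so JAX's global precision setting plays no part) and evaluates the textbook 2×2 determinant `a*d - b*c` once in float32 and once in float64. The helper `det_2x2` is hypothetical and uses the naive formula, not the LU-based routines behind `jax.scipy.linalg.det` or `numpy.linalg.det`, so the exact float32 values those functions return may differ; the point is the scale of the cancellation.

```python
import numpy as np

# Covariance from the first example (nearly singular).
COV = [[2332.056396, 2307.901855],
       [2307.901855, 2283.997559]]

def det_2x2(m, dtype):
    """Hypothetical helper: naive 2x2 determinant a*d - b*c, evaluated entirely in `dtype`."""
    m = np.asarray(m, dtype=dtype)
    return m[0, 0] * m[1, 1] - m[0, 1] * m[1, 0]

det32 = det_2x2(COV, np.float32)
det64 = det_2x2(COV, np.float64)

# Condition number (largest / smallest singular value), computed in float64.
cond = np.linalg.cond(np.asarray(COV, dtype=np.float64))

print("float32 det:", det32)
print("float64 det:", det64)
print("condition number:", cond)
print("1 / float32 eps:", 1 / np.finfo(np.float32).eps)
```

The two products `a*d` and `b*c` are both about 5.3 × 10^6 and differ by only about 0.14, which is smaller than the float32 spacing at that magnitude (0.5), so the float32 subtraction loses the difference entirely. Likewise, the condition number (roughly 1.5 × 10^8) exceeds 1/eps for float32 (about 8.4 × 10^6), so a 32-bit inverse of this matrix cannot be trusted.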
JAX uses 32-bit arithmetic by default. If you enable 64-bit precision in JAX (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision), you will find that the JAX and NumPy/SciPy outputs match.
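A minimal sketch of that fix, assuming JAX and SciPy are installed: turn on x64 mode with `jax.config.update("jax_enable_x64", True)` at program startup, before any JAX arrays are created (setting the environment variable `JAX_ENABLE_X64=True` before launching Python is an alternative), then rerun the nearly singular case.

```python
import jax

# Enable double precision. Do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as onp
from jax.scipy.stats import multivariate_normal as jmvn
from scipy.stats import multivariate_normal as omvn

# The nearly singular case from the question.
x = jnp.array([296.859985, 299.667206])
mean = jnp.array([208.274643, 210.701645])
cov = jnp.array([[2332.056396, 2307.901855],
                 [2307.901855, 2283.997559]])

print(x.dtype)  # float64 once x64 is enabled

jax_logpdf = jmvn.logpdf(x, mean, cov)
scipy_logpdf = omvn.logpdf(onp.asarray(x), onp.asarray(mean), onp.asarray(cov))
print("jax:  ", jax_logpdf)
print("scipy:", scipy_logpdf)
```

The flag is global to the process, so every JAX array created afterwards defaults to 64-bit types, not just the ones in this computation.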
Thank you @jakevdp, enabling double precision solved this issue.
@jakevdp thanks for the help here, much appreciated! FYI @nitinkmittal is an intern I'm working with, we're totally digging JAX!