I am losing precision with float32/float64 conversions, and I am also confused about how to get the level of precision I want. In the example below, it seems like I lose precision right out of the gate.
import jax
import jax.numpy as np

A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]], dtype=np.float64)
print(A.dtype)
:results:
float32
:end:
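For what it's worth, the same thing seems to happen with any array I create, not just this matrix. A quick check (same session and imports as above):

import jax.numpy as np

# Even a fresh array that explicitly asks for float64 comes back as float32.
B = np.zeros(3, dtype=np.float64)
print(B.dtype)  # float32 on my machine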
If I start with a plain numpy float64 array, I can get a jax array of that type, but the lu array that comes out of the decomposition is downgraded to float32.
import numpy as onp

oA = onp.array([[0, 1, 1],
                [1, 0, 1],
                [1, 1, 0]], dtype=np.float64)
oA.dtype
:results:
: dtype('float64')
:end:
This shows that A starts out as float64, but the lu array is float32.
A = np.array(oA)
print(A.dtype)
lu, pivots = jax.lax_linalg.lu(A)
lu.dtype
:results:
float64
: dtype('float32')
:end:
Is this a bug? Or am I doing something wrong here?
In case it matters, I am running all of this on the CPU of a MacBook Air.
from jax.config import config; config.update("jax_enable_x64", True)
We need to document this. It's touched on briefly in the sharp edges notebook at the bottom, but there's no excuse not to be super clear about enabling 64-bit mode.
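For completeness, here is a minimal sketch of the fix. It assumes the flag is set before any JAX arrays are created, and it uses jax.scipy.linalg.lu_factor as a public LU routine in place of the internal jax.lax_linalg.lu:

from jax.config import config
config.update("jax_enable_x64", True)  # set before creating any JAX arrays

import jax.numpy as np
from jax.scipy.linalg import lu_factor

A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]], dtype=np.float64)
print(A.dtype)   # float64 once 64-bit mode is enabled

lu, piv = lu_factor(A)
print(lu.dtype)  # float64 as well

Setting the JAX_ENABLE_X64 environment variable before starting Python should have the same effect.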