Jax: controlling float precision?

Created on 29 Apr 2019 · 2 Comments · Source: google/jax

I am losing precision in float32/float64 conversions, and I am also confused about how to get the level of precision I want. In this example, it seems like I lose precision right out of the gate:

import jax
import jax.numpy as np

A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]], dtype=np.float64)

print(A.dtype)

Output:

float32
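A minimal sketch, assuming a recent JAX release with the default configuration, of where the downcast comes from: unless 64-bit mode is enabled, JAX canonicalizes 64-bit dtypes to their 32-bit counterparts, and `jax.dtypes.canonicalize_dtype` reports the dtype an array will actually get. In recent releases, converting a float64 NumPy array with `jnp.asarray` is also downcast to float32.

```python
import jax
import jax.numpy as jnp
import numpy as np

# With jax_enable_x64 off (the default), float64 is mapped to float32.
print(jax.dtypes.canonicalize_dtype(np.float64))

# The NumPy array keeps float64 on the host; the downcast happens on conversion.
oA = np.array([[0, 1, 1],
               [1, 0, 1],
               [1, 1, 0]], dtype=np.float64)
A = jnp.asarray(oA)
print(oA.dtype, A.dtype)
```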

If I start from a float64 NumPy array, I can get a JAX array of that dtype, but the lu array returned by the LU decomposition is downgraded to float32.

import numpy as onp
oA = onp.array([[0, 1, 1],
                [1, 0, 1],
                [1, 1, 0]], dtype=onp.float64)
oA.dtype

Output:

dtype('float64')

The following shows that A starts out as float64, but the lu array is float32.

A = np.array(oA)
print(A.dtype)
lu, pivots = jax.lax_linalg.lu(A)
lu.dtype

Output:

float64
dtype('float32')
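For readers on current versions: the `jax.lax_linalg` module used above now lives at `jax.lax.linalg`, and its `lu` returns three arrays, `(lu, pivots, permutation)`, rather than two. A minimal sketch, assuming a recent JAX release, that repeats the decomposition and checks that `lu` packs a unit-lower-triangular L and an upper-triangular U of the row-permuted input. The output dtype simply follows the (canonicalized) input dtype, so with 64-bit mode off it is float32.

```python
import jax.numpy as jnp
import numpy as np
from jax.lax import linalg as lax_linalg

oA = np.array([[0, 1, 1],
               [1, 0, 1],
               [1, 1, 0]], dtype=np.float64)
A = jnp.asarray(oA)  # float32 unless 64-bit mode is enabled

# Current API: (lu, pivots, permutation)
lu, pivots, permutation = lax_linalg.lu(A)
print(A.dtype, lu.dtype)

# Unpack the factors: L has an implicit unit diagonal, U is the upper triangle.
L = jnp.tril(lu, k=-1) + jnp.eye(3, dtype=lu.dtype)
U = jnp.triu(lu)
print(jnp.allclose(A[permutation], L @ U))  # rows of A permuted equal L @ U
```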

Is this a bug? Or am I doing something wrong here?

In case it matters, I am running all this on a CPU on a MacBook Air.

Labels: documentation

Comments:

from jax.config import config
config.update("jax_enable_x64", True)  # enable 64-bit (double precision) dtypes

We need to document this. It's touched on briefly at the bottom of the Sharp Edges notebook, but there's no excuse not to be super clear about how to enable 64-bit mode.
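Applying the suggested setting to the original LU example: a minimal sketch, assuming a recent JAX release. It sets the same `jax_enable_x64` flag through `jax.config.update`, since the `from jax.config import config` import path may not be available in newer versions, and it uses `jax.scipy.linalg.lu_factor`, which returns `(lu, piv)` like SciPy's function, in place of the older `jax.lax_linalg.lu`. The flag is meant to be set at startup, before any JAX arrays are created.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import numpy as np

# Enable 64-bit mode before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

oA = np.array([[0, 1, 1],
               [1, 0, 1],
               [1, 1, 0]], dtype=np.float64)
A = jnp.asarray(oA)
lu, piv = jax.scipy.linalg.lu_factor(A)
print(A.dtype, lu.dtype)  # float64 float64

# Machine epsilon of the result dtype: about 2.2e-16 for float64
print(jnp.finfo(lu.dtype).eps)
```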
