Jax: add complex128 support (and match NumPy behavior)

Created on 10 Jan 2019 · 9 comments · Source: google/jax

In the following example, the input array is float64, but the output of jax.numpy.exp is complex64, whereas NumPy produces complex128.

import numpy as onp
import jax.numpy as np
from jax import jacrev, jit

R = onp.random.RandomState(0).randn

x = R(128)            # float64 array
arg_dtype = x.dtype


def np_fn(x):
    assert x.dtype == arg_dtype
    return onp.exp(x + 1j)


def jax_fn(x):
    assert x.dtype == arg_dtype
    return np.exp(x + 1j)


assert x.dtype == np.float64
assert np_fn(x).dtype == np.complex128     # NumPy promotes to complex128
result = jax_fn(x)
assert result.dtype == np.complex128, result.dtype   # fails: JAX returns complex64


All 9 comments

There are two things going on here.

a) JAX doesn't support complex128, because XLA doesn't support complex128. It wouldn't be hard to add complex128 support to XLA, it's just no-one has ever asked for it yet. It would be a fairly mechanical change. I'll look into doing this.

b) JAX has a (slightly hacky) mechanism (jax_enable_x64) to squash 64-bit types down to 32-bit types so computations can run on hardware that doesn't have 64-bit support (e.g., TPU). We should probably enable x64 by default on CPU and GPU.

To enable x64 mode, run the following at the start of execution (I don't think it works if run later):

import jax.config
jax.config.config.update("jax_enable_x64", True)

My personal preference would be to write

from jax.config import config
config.update("jax_enable_x64", True)

You can also set things up to configure via command-line flags like --jax_enable_x64=True, parsed with absl-py, by putting this at the top of your file:

from jax.config import config
config.config_with_absl()

You can also set the environment variable JAX_ENABLE_X64 to True, or to any other non-falsey value.
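For reference, a minimal check of what the flag changes (a sketch, assuming a JAX version where jax_enable_x64 is available; the dtype printed would be float32 if x64 were left off):

from jax.config import config
config.update("jax_enable_x64", True)   # must run before any other JAX work

import numpy as onp
import jax.numpy as np

x = onp.random.RandomState(0).randn(4)  # float64 on the NumPy side
print(np.exp(x).dtype)                  # float64 with x64 enabled, float32 otherwise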

We should probably enable x64 by default on CPU and GPU.

Maybe. We should set the default based on whether most users want to use JAX for 64-bit computations (following NumPy) or 32-bit computations.

NumPy is really happy to promote to 64-bit dtypes, and that can actually make it annoying to write e.g. conv nets where you want to keep everything in float32. You don't want one infix multiplication operator to promote everything into float64. We got frustrated with that ourselves at one point, and so we decided to add jax_enable_x64=False as a big hammer that crushes all 64-bit values out of JAX. That's a reason to have the flag, not an argument about what its default should be, but if it ends up that most users are using JAX for deep learning and want to stay in 32-bit, then a default of False would make sense to me.
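As a concrete illustration of the promotion being described (plain NumPy, nothing JAX-specific): one infix multiply between a float32 array and a default-dtype integer array silently yields float64.

import numpy as onp

x = onp.ones(3, dtype=onp.float32)   # the dtype you want to keep
w = onp.arange(3)                    # default integer dtype
print((x * w).dtype)                 # float64 -- the result was promoted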

What do you all think?

NumPy is really happy to promote to 64-bit dtypes, and that can actually make it annoying to write e.g. conv nets where you want to keep everything in float32. You don't want one infix multiplication operator to promote everything into float64. We got frustrated with that ourselves at one point, and so we decided to add jax_enable_x64=False as a big hammer that crushes all 64-bit values out of JAX. That's a reason to have the flag, not an argument about what its default should be, but if it ends up that most users are using JAX for deep learning and want to stay in 32-bit, then a default of False would make sense to me.

Hear hear! I also wonder sometimes whether float64 and higher precision are really necessary for DL applications. Perhaps I'm uninformed somewhere?

In other frameworks, mixing float32 and float64 bites me if I'm not careful. I like the idea of having the "BFH" (big fat hammer) set everything to a sane setting so that end users can focus on the more important modelling parts.

One thing, though: why is the front-facing API:

from jax.config import config

Rather than

from jax import config

?

Having nested repeated names adds potential confusion, I think. Perhaps it'd be nice if the config object were imported into the jax namespace directly? PyMC3 does this with all of its distributions; they are imported under the pymc3 namespace. E.g.

# jax/__init__.py
from .config import config

Maybe I'm missing something here, though, as I'm not privy to the architectural decisions that are being made.

I understand that ML is probably the primary driver behind the development of jax, so I'm fine with the x64 workaround. However, by providing a numpy interface there's likely always going to be an implicit assumption that jax will behave as numpy does.

(I attached Peter's name to this issue because he's already done a lot of follow-up work behind the scenes.)

Good point re: assumptions about numpy behavior. We don't want to be surprising. Whichever default we settle on, we should have this documented clearly.

While deep learning is a driver on all of our minds, we would like JAX to be broader in scope and to help bring accelerator power to a broader spectrum of Python numerical computing work. @sjperkins I'd be interested to hear what you're working on (apologies if you wrote it in another thread already).

Not directly related, but #225 is upgrading a lot of complex number support.

@sjperkins I'd be interested to hear what you're working on (apologies if you wrote it in another thread already).

@mattj Radio astronomy, in particular the MeerKAT telescope, a precursor to the much larger SKA. As the raw data of a radio telescope are complex voltages, double-precision complex support would be useful!

We're putting together a library of radio astronomy algorithms (https://github.com/ska-sa/codex-africanus) with a focus on jitted compilation (for both CPU and GPU). The combination of

  • jit compilation to CPU and GPU backends
  • autodiff

makes jax very promising so I've started experimenting with it.

I checked in some preliminary complex128 support. You'll need an up-to-date jaxlib to use it.

As far as I am aware, the complex128 support should be fairly complete on CPU; a couple of reduction cases are known not to work on GPU, but they will fail with a loud error from LLVM if you hit them.

Please try it out and let us know how it goes!
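For anyone trying it out, a sketch of the check from the original report, assuming x64 mode is enabled and a jaxlib build that includes the new complex128 support:

from jax.config import config
config.update("jax_enable_x64", True)

import numpy as onp
import jax.numpy as np

x = onp.random.RandomState(0).randn(128)   # float64, as in the original report
result = np.exp(x + 1j)
assert result.dtype == np.complex128, result.dtype   # should now pass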

Regarding an up-to-date jaxlib, the Linux wheels should all be updated now (to version 0.1.4).
