Jax: Add way to run jax program on CPU

Created on 30 Oct 2019 · 6 comments · Source: google/jax

In my case, I have a small NumPy array and a relatively simple computation.

Is there a way to place the computation on the CPU, so I can benchmark the overhead of transferring the computation to the GPU?

Related to #957.


All 6 comments

Indeed, @levskaya added this functionality in #1211! In fact, #957 is fixed too; we just haven't documented it yet.

The APIs are a bit in flux, though, as we're trying to simplify a couple things.

For placing on CPU:

```python
from jax import jit

def f(x): return x**2

f_cpu = jit(f, backend='cpu')
f_gpu = jit(f, backend='gpu')  # requires a GPU-enabled jaxlib install
```
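To benchmark the original question directly, here is a minimal timing sketch. It assumes a recent JAX release in which a jitted function runs on the device its inputs were explicitly placed on with `jax.device_put`, so it does not rely on the `backend=` argument above; the helper name `time_compute` is hypothetical. The host-to-device copy is done before the timer starts, and every call is followed by `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import numpy as np


def f(x):
    return x ** 2


f_jit = jax.jit(f)


def time_compute(device, x_host, n_iter=100):
    """Hypothetical helper: mean seconds per call of f_jit on `device`.

    The input is copied to `device` once, before timing, so this measures
    compute plus dispatch overhead only, not the transfer.
    """
    x_dev = jax.device_put(x_host, device)  # commit the input to `device`
    f_jit(x_dev).block_until_ready()        # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        # Dispatch is asynchronous; block so the timer covers the real work.
        f_jit(x_dev).block_until_ready()
    return (time.perf_counter() - start) / n_iter


x_host = np.arange(1000, dtype=np.float32)
cpu = jax.devices("cpu")[0]
default = jax.devices()[0]  # a GPU if one is available, otherwise the CPU
print("cpu:    ", time_compute(cpu, x_host))
print("default:", time_compute(default, x_host))
```

Comparing the two numbers, and then adding the transfer cost measured separately, gives a rough picture of whether the GPU is worth it for small arrays.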

cc @skye who helped with this example code

I'm going to leave this issue open as a documentation request.

@skye pointed out that this is already mentioned in the jit docstring:

device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
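The available devices mentioned in that docstring can be inspected directly. A short sketch, assuming a JAX version that provides `jax.devices(backend)` and `jax.default_backend()` (both exist in current releases):

```python
import jax

# Devices on the default (highest-priority) platform, e.g. GPUs if present.
print(jax.devices())

# CPU devices are normally available even when a GPU is the default platform.
cpu_devices = jax.devices("cpu")
print(cpu_devices)

# Name of the platform JAX uses by default, such as 'cpu', 'gpu' or 'tpu'.
print(jax.default_backend())
```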

So let's actually close this issue, but leave up #957 as a request to finalize the API and document device placement of computations in a "Note" at jax.readthedocs.io. (I think one of the changes might be to unify "device" and "backend" arguments.)

By the way, if you want to measure just the transfer time, you can also use jax.device_put, which also takes device and backend arguments.
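A minimal sketch of timing only the transfer with `jax.device_put`. It assumes a recent JAX release and passes a `Device` object rather than the `backend=` argument mentioned above; `time_transfer` is a hypothetical helper name. Since `device_put` is also asynchronous, each copy is waited on with `block_until_ready()`.

```python
import time

import jax
import numpy as np


def time_transfer(x_host, device, n_iter=20):
    """Hypothetical helper: mean seconds to copy `x_host` onto `device`."""
    jax.device_put(x_host, device).block_until_ready()  # warm-up
    start = time.perf_counter()
    for _ in range(n_iter):
        # Wait for each copy to finish before starting the next one.
        jax.device_put(x_host, device).block_until_ready()
    return (time.perf_counter() - start) / n_iter


x_host = np.ones((1024, 1024), dtype=np.float32)
print(time_transfer(x_host, jax.devices()[0]))
```

On a CPU-only machine this mostly measures a host-side copy, so the number is only meaningful when `jax.devices()[0]` is an accelerator.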

Please reopen if the above answers aren't satisfactory!

Thanks for the quick replies!

If I understand correctly, there is no way to specify that _all_ computations should run on the CPU, right? It seems we would need to change all the code to put every jit on the CPU backend. It would be useful to have something similar to `config.update("jax_disable_jit", True)` for testing code that won't run properly on the GPU. Is such a config option available?

You can set the default platform globally with:

```python
import jax

platform = 'cpu'  # one of 'cpu', 'gpu' or 'tpu'
jax.config.update('jax_platform_name', platform)
```

For example to use CPU for all computations (even if other platforms like GPU are available):

```python
import jax
import jax.numpy as jnp

# Global flag selecting the default platform; set it at startup,
# before any JAX computation runs.
jax.config.update('jax_platform_name', 'cpu')

x = jnp.square(2)
print(x.devices())  # {CpuDevice(id=0)}
# Older JAX releases (as in this thread) used: repr(x.device_buffer.device())
```
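If only part of a program should run on the CPU, recent JAX releases also provide the `jax.default_device` context manager, which I believe was added well after this thread. A minimal sketch, assuming it is available in your version:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Inside the block, computations on uncommitted inputs default to the CPU,
# even if a GPU is the default platform.
with jax.default_device(cpu):
    y = jnp.square(2.0)

print(y.devices())  # {CpuDevice(id=0)}
```

Unlike the global `jax_platform_name` flag, this does not need to be set at startup and only affects the code inside the `with` block.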