With multiple GPUs, there should be a way to direct specific computations to run on a specific GPU.
Context: I want to run as many models in parallel as there are GPUs, each model on its own GPU.
The current best solution that comes to mind is to RPC between different machines, each with 1 GPU -- I'd like to avoid doing this.
@hawkinsp @mattjj
Besides, is there a way to set the default device to CPU?
These features are already implemented, though they're somewhat experimental and the API is likely to change. @levskaya and @skye have been working on different parts of it.
Here's a way to do it now:
import jax
gpus = jax.devices('gpu')
model1 = jax.jit(model1, device=gpus[0])
model2 = jax.jit(model2, device=gpus[1])
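For instance, here's a minimal runnable sketch of that pattern (assuming at least two visible GPUs; model1, model2, and the input shape are hypothetical placeholders). Because JAX dispatches computations asynchronously, calling the two jitted functions back-to-back lets them run concurrently, one per GPU:

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for two different models.
def model1(x):
    return jnp.tanh(x).sum()

def model2(x):
    return (x ** 2).sum()

gpus = jax.devices('gpu')
run1 = jax.jit(model1, device=gpus[0])  # compiled for and run on GPU 0
run2 = jax.jit(model2, device=gpus[1])  # compiled for and run on GPU 1

x = jnp.ones((1024,))
out1 = run1(x)   # dispatch is asynchronous...
out2 = run2(x)   # ...so the two models execute concurrently
out1.block_until_ready()
out2.block_until_ready()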
Let's leave this issue open until we've finalized the API, added tests, and also provided the right documentation.
We chatted a bit further about this.
The current API (jit(..., device=...)) doesn't extend to op-by-op mode. In op-by-op mode, it would be tedious or impossible to plumb a device argument to each primitive function.
We are tempted to follow PyTorch's design here: all inputs to an operator must be on the same device (with the likely exceptions of classic NumPy arrays and Python scalars), and the operator runs on that common device. If different inputs are on different devices, that is an error.
This suggests by extension that the default policy of jit should be: run on the device on which all the arguments are placed, and error on disagreement. We could possibly still allow a device argument on jit which either checks for a particular choice of device, or copies to a particular device, but it seems optional given the more general policy.
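As a concrete illustration of the proposed policy (a sketch only; this describes the design being discussed, not necessarily the behavior of any released version):

import jax
import jax.numpy as jnp

gpu0, gpu1 = jax.devices('gpu')[:2]

a = jax.device_put(jnp.ones(3), gpu1)
b = jax.device_put(jnp.ones(3), gpu1)

c = a + b    # both inputs live on gpu1, so the op-by-op add runs on gpu1
d = a + 1.0  # Python scalars are exempt and follow a's device
# a + jax.device_put(b, gpu0)  # mixed devices: an error under the proposal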
I think we should always allow explicit annotation on a jit, and just have a more sophisticated default policy (like follow-the-arguments-when-clear-and-error-otherwise).
I wonder if the same approach works for TPU?
Yes, it does. On a Cloud TPU (e.g. in Colab), jax.devices() returns 8 TpuDevice objects (corresponding to the 8 cores).
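The same pattern from earlier carries over, just indexing into the TPU cores instead (a sketch):

import jax
import jax.numpy as jnp

tpus = jax.devices()                          # 8 TpuDevice objects on a Cloud TPU
f = jax.jit(lambda x: x * 2, device=tpus[3])  # pin this computation to core 3
f(jnp.arange(4))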
Very nice. Thanks for the response!
As a stopgap solution for setting the default device, I've found I can just set
CUDA_VISIBLE_DEVICES=0,1
before importing JAX. Setting it to an empty string will force JAX to use the CPU.
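One way to do that from Python rather than the shell (a sketch; the key point is that the variable must be set before JAX initializes its backends, i.e. before the import):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""    # hide all GPUs; JAX falls back to CPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # or: expose only GPU 1

import jax
print(jax.devices())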
That only works if you want the entire program on a single GPU, mind you; I'd still like to see a PyTorch-style with torch.device(i) context manager.
As of https://github.com/google/jax/pull/1916 or so, computation will "follow" data, meaning if some values are already placed on a specific device, downstream operations will also be placed on that device. For example, something like:
x = device_put(1, jax.devices()[1])
y = x + 1
z = x + y
x, y, and z will all be on device 1 (prior to this new behavior, y and z would go back to the default device 0).
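One way to confirm where each value landed (a sketch assuming the DeviceArray introspection available around that time, i.e. device_buffer.device(); newer releases expose this differently):

import jax
from jax import device_put

x = device_put(1, jax.devices()[1])
y = x + 1
z = x + y

# Each of these should report jax.devices()[1].
print(x.device_buffer.device())
print(y.device_buffer.device())
print(z.device_buffer.device())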
Combined with the device argument to jit, I believe this should give full control over where computations are executed. I'm gonna close this issue, but please comment if I'm missing something and I'll reopen!
(Apologies for not mentioning this earlier!)
I wonder if this works for flax optimizer?
@BoyuanJackChen - there's nothing special about flax optimizers, they just operate on jax arrays like any other computation, so you should be able to target them if needed.
@n2cholas clearly we should improve our documentation around this (PRs welcome, though I'll open an issue so we don't drop it).
Is there any way to create jax.numpy.arrays on a specific device, instead of creating them then moving them via device_put?
If you use jax.device_put(jnp.zeros(...), jax.devices()[1]) or similar, it'll do the creation on the device thanks to #1668. You can also use jit, as in
@partial(jit, device=jax.devices()[1])
def foo():
    return jnp.zeros(...)
(That second approach will handle more cases once #4038 is in, i.e. so that #3370 is on by default, but for now it should handle simple constant creation.)
Other than zeros, ones, eye, etc. did you have some other kind of array creation in mind?
Besides, is there a way to set the default device to CPU?
You can use the JAX_PLATFORM_NAME env var, or the jax_platform_name absl flag, to set the default device. After a very quick glance I didn't see this in our documentation anywhere...
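For example (a sketch; my_script.py is just a placeholder name):

# From the shell:
#   JAX_PLATFORM_NAME=cpu python my_script.py

# Or from Python, before importing jax:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
print(jax.devices())  # should now list CPU device(s)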
Thanks @mattjj for the quick reply (and sorry about deleting my comment here then opening the issue haha). #1668 has some really really cool stuff.
I would like to specify the device where a PRNG generates numbers. Specifically, in my JIT'd function, I have some computation running on the GPU and would like to sample random numbers on the CPU. Based on your response, it seems like JIT will take care of this for me and do the sampling on the CPU if I do something like jax.device_put(jax.random.bernoulli(key), jax.devices('cpu')[0]).
Outside of a JIT context, will the lazy sub-language allow JAX to execute jax.device_put(jax.random.bernoulli(key), jax.devices('cpu')[0]) directly on the CPU?
I suggest something like jit(jax.random.bernoulli, static_argnums=2, device=jax.devices('cpu')[0]) to get a version of bernoulli that will run on the CPU. (You can also use jit(..., backend='cpu'); we need to de-duplicate the API.)
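Spelled out as a runnable sketch (the shape parameter is positional argument 2, hence static_argnums=2; the probability 0.5 and shape (4,) are just illustrative):

import jax
from jax import jit

cpu = jax.devices('cpu')[0]
cpu_bernoulli = jit(jax.random.bernoulli, static_argnums=2, device=cpu)

key = jax.random.PRNGKey(0)
samples = cpu_bernoulli(key, 0.5, (4,))  # sampling runs on the CPU
print(samples)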
The lazy sublanguage stuff really only applies to special cases like jnp.zeros, so jax.device_put(jax.random.bernoulli(key), jax.devices('cpu')[0]) will run the sampling on the default device (i.e. not the CPU in this case) then transfer.
Actually, maybe the most relevant part of the docs is Controlling data and computation placement on devices in the FAQ. That explains the computation-follows-data policy, which means that, as an alternative to giving a device argument to jit, you can write
key = jax.device_put(key, jax.devices('cpu')[0])
jax.random.bernoulli(key)
I think that computation-follows-data control is better than using jit(..., device=...).
Makes sense, thanks again!