Is there any way to use JAX with TPUs in coreless mode?
In TensorFlow you can just use tf.device(None) to run bigger operations on the TPU's CPU and its 300 GB of RAM. But after looking at XLA, the bridge, trax (which is where I'm using jaxlib), and JAX itself, I only seem to run into errors like this: 'JAX cannot work yet with n_devices != all devices: 1 != 8'.
Can you say a bit more about what you're trying to do? I don't know what you mean by coreless mode.
I believe with tf.device(None): simply drops any device annotations, allowing the TF placer to use whichever device it wants. In practice, it usually makes a simple default choice.
What behavior do you want from JAX?
+1 to Peter's questions, but note jax.devices('cpu')[0] gives you a CpuDevice, which you can pass to jit, device_put, etc. to run on host. I'm not sure if/how trax plumbs this through though.
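For example, something like this (a rough sketch; the exact placement API, like the device argument to jit, may vary across JAX versions):

```python
import jax
import jax.numpy as jnp

# The host CPU device, as opposed to one of the 8 TPU cores.
cpu = jax.devices('cpu')[0]

# Explicitly place an array in host memory.
x = jax.device_put(jnp.ones((1024, 1024)), cpu)

# Compile the computation for the CPU device instead of a TPU core.
f = jax.jit(lambda a: a @ a.T, device=cpu)
y = f(x)
```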
I think the question is predicated somewhat on the current TensorFlow TPU setup, where you have a TF distributed system that spans the user VM and the host CPU on the Cloud TPU. If the user VM is small, you can use the CPUs on the TPU machine for TensorFlow ops instead of running them locally. But JAX doesn't have a distributed system; if you want to run things on a remote CPU you'll need to use something like Ray or Dask. The JAX TPU integration is currently only able to use the TPU cores on the Cloud TPU.
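(For what it's worth, the Ray route is just the generic remote-CPU pattern sketched below; note that you can't run a Ray worker on the Cloud TPU host itself, so this only illustrates offloading to some other machine with more memory:)

```python
import ray
import numpy as np

ray.init(address='auto')  # assumes a Ray cluster is already running

@ray.remote
def big_cpu_op(n):
    # Runs in a worker process on whichever machine Ray schedules it to.
    x = np.ones((n, n), dtype=np.float32)
    return float((x @ x.T).sum())

print(ray.get(big_cpu_op.remote(8192)))
```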
Thanks for the answer, @jekbradbury. Are there any plans for adding it in? It seems odd to completely ignore such a significant chunk of a TPU's resources.
I think there's nothing we can do along those lines at the moment, because the way this works at present is specific to TensorFlow. However, it's possible future evolutions of the cloud TPU product might make it possible for JAX to make more use of the TPU VM, as TensorFlow does, in addition to the TPU devices themselves. Watch this space!
However, since there's no action we can take at the moment, I'm going to close this issue.
Hi everyone,
I think there might be some confusion in this thread.
A TPU isn't just a piece of hardware with 8 cores. It's a piece of hardware that has a CPU, RAM, and 8 cores.
You can run code on a TPU's CPU. I do this all the time for fine-tuning GPT-2 1.5B. It's as easy as running with tf.device(None): # ops go here.
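Spelled out, the pattern looks roughly like this (a TF1-style sketch; the gRPC address is a placeholder for your TPU worker):

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# 'grpc://10.0.0.2:8470' is a placeholder TPU worker address.
with tf.Session('grpc://10.0.0.2:8470') as sess:
    with tf.device(None):  # let the placer fall back to the TPU host's CPU
        # This ~1 GB tensor lives in the host's RAM, not in a core's HBM.
        x = tf.random_normal([16384, 16384])
        total = tf.reduce_sum(tf.matmul(x, x, transpose_b=True))
    print(sess.run(total))
```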
When you run ops on the TPU's CPU, you have access to up to 300 GB of memory(!) without running into OOM errors. 300 GB is far, far higher than the normal per-core HBM limit. As far as I'm concerned, it's one of the best features of TPUs.
In fact, the 300 GB limit is so high that people often refuse to believe that this is even possible. It's not advertised anywhere. I myself discovered the feature by accident.
Here's an HN thread where I illustrate how the TPU can use 300GB of memory: https://news.ycombinator.com/item?id=22196855
And a simple notebook that fine-tunes GPT-2 1.5B using a TPUv2 (which is quite impossible if you were limited to only 8GB): https://colab.research.google.com/drive/1ohuxvB7nuvcjpLLIF1L3WR7SSzFENwQY
So, given that using 300 GB of memory is one of the best features that TPUs have to offer, is there anything that can be done to support this feature in Jax?
All that needs to be done to support it is to be able to execute ops on the TPU's CPU. This corresponds to the TPU's /device:CPU:0 device, which doesn't seem special. It's just like running ops on one of the TPU's cores, except the ops run on the TPU's host CPU instead.
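In TF terms it's just ordinary explicit placement (an illustrative sketch; on a normal machine this allocation would OOM, which is the point):

```python
import tensorflow as tf

# Ordinary explicit placement onto the TPU worker's host CPU.
with tf.device('/device:CPU:0'):
    # A ~40 GB float32 tensor; fits in host RAM, never in a core's HBM.
    big = tf.zeros([100000, 100000])
```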
Note that Google's official MLPerf benchmarks use this technique for ResNet training: https://github.com/mlperf/training_results_v0.6/blob/8f510835d9afc68ba3c9608329730a66f6de50d8/Google/benchmarks/resnet/implementations/tpu-v3-512-resnet/resnet/train_and_eval_runner.py#L57
@shawwn Unfortunately that's a capability that is only available to TensorFlow at the moment, and not to other users of TPUs. It's possible that might change in the future, but we can't make any promises at this time.
It seems strange that this feature is only available in TensorFlow. Being able to run ops on the TPU's CPU is necessary for infeed processing.
For example, in MLPerf's TPU imagenet benchmark, the codebase runs image processing operations on the TPU's CPU: https://github.com/mlperf/training_results_v0.6/blob/8f510835d9afc68ba3c9608329730a66f6de50d8/Google/benchmarks/resnet/implementations/tpu-v3-512-resnet/resnet/train_and_eval_runner.py#L179
Without this feature, there's no way that TPUs could possibly achieve an ImageNet benchmark time of 3.5 minutes on a TPUv3-512 (https://mlperf.org/training-results-0-6).
What is the correct way to do infeed processing with JAX at the moment? For example, if you wanted the TPU to decode a JPEG and extract a random crop, how would you do it?
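For concreteness, the kind of host-side preprocessing I mean looks something like this tf.data sketch feeding JAX (the filename is a placeholder, and it assumes each record holds raw JPEG bytes); today this would run on the user VM, whereas the TF TPU stack can place it on the TPU's CPU:

```python
import tensorflow as tf
import jax.numpy as jnp

def parse(serialized_jpeg):
    # Decode a JPEG and take a random 224x224 crop.
    image = tf.io.decode_jpeg(serialized_jpeg, channels=3)
    image = tf.image.random_crop(image, [224, 224, 3])
    return tf.cast(image, tf.float32) / 255.0

# 'images.tfrecord' is a placeholder path to records of raw JPEG bytes.
ds = (tf.data.TFRecordDataset('images.tfrecord')
      .map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .batch(128)
      .prefetch(2))

for batch in ds.as_numpy_iterator():
    x = jnp.asarray(batch)  # hand each batch to a jitted JAX step
```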