Jax: TPU not detected by jax in Colab

Created on 24 Jan 2020 · 20 comments · Source: google/jax

I am attempting to use the Google-provided Reformer notebook with a TPU runtime, and I get:

usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:118: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

It works fine with GPU on the other hand.
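For context, the warning is emitted from JAX's backend selection in jax/lib/xla_bridge.py. When no accelerator backend can be initialized, JAX falls back to CPU and warns. The sketch below is a hypothetical, heavily simplified model of that fallback. The function `pick_platform` and its arguments are illustrative only; they are not JAX's code or API.

```python
import warnings


def pick_platform(available, preferred=("tpu", "gpu")):
    """Return the first preferred platform that is available, else "cpu".

    Hypothetical simplification of JAX's backend fallback; not its real code.
    `available` is any container of platform names, e.g. {"cpu", "gpu"}.
    """
    for platform in preferred:
        if platform in available:
            return platform
    # No accelerator found: warn and use the CPU backend instead.
    warnings.warn("No GPU/TPU found, falling back to CPU.")
    return "cpu"
```

In this model, a Colab TPU runtime whose TPU backend was never configured looks like `{"cpu"}`, which is why the warning appears even though a TPU is attached.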

I've not run Reformer but this may work: try running the following cell first, so that the Colab runtime is set to TPU acceleration:

# get the latest JAX and jaxlib
!pip install --upgrade -q jax jaxlib

# Colab runtime set to TPU accel: ask the Colab TPU host to switch to the
# nightly TPU driver. The globals() guard makes this request only once per
# runtime, even if the cell is re-executed.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# TPU driver as backend for JAX (set these before running any JAX computation)
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
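The two addresses the setup cell builds can be written as a small pure function. This is a sketch that assumes a "host:port" address like the value of COLAB_TPU_ADDR. The helper name `tpu_driver_endpoints` and the example address in the usage note are hypothetical. Port 8475 and the /requestversion/tpu_driver_nightly path are copied from the cell.

```python
def tpu_driver_endpoints(tpu_addr):
    """Build (driver_request_url, jax_backend_target) from a "host:port" TPU address.

    Hypothetical helper mirroring the string handling in the setup cell.
    """
    host = tpu_addr.split(":")[0]
    # The cell POSTs to this URL to request the nightly TPU driver on that host.
    driver_url = "http://" + host + ":8475/requestversion/tpu_driver_nightly"
    # JAX's backend target points at the original host:port over gRPC.
    backend_target = "grpc://" + tpu_addr
    return driver_url, backend_target
```

For a made-up address such as "10.0.0.2:8470", this yields the driver URL on port 8475 of the same host and the target "grpc://10.0.0.2:8470". Taking the address as a parameter, rather than reading the environment variable inside the function, keeps the logic easy to reuse with a different TPU address.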

Source: Cloud TPU NeurIPS 2019 Colab.

I believe the notebook was made by @skye and @mattjj and demoed by @skye at NeurIPS, when TPU support was unveiled. (Yours truly was lucky to attend that demo.)

Thanks, I'm currently testing this; I had also found these notebooks. However, those lines seem to only suppress the warning: when I time the cells, I get exactly the same time for every operation on a CPU instance and a TPU instance, which suggests this doesn't really help.

Update: I just ran the Trax Transformer/Reformer Colab after adding and running an extra cell with the code above, and it worked.

(Don't forget to set the accelerator to TPU under Edit > Notebook settings).

@8bitmp3 it works in the sense that it removes the warning. Try running it on CPU vs TPU with this snippet: you'll get the same times for everything. (I checked by adding %%time at the top of the cells.)

I suspect the "No GPU/TPU found" message is not so much an issue per se as a useful warning, given that (Cloud) TPU support in JAX has probably only just been green-lit for public use on Google Colab (or maybe I'm wrong?).

That said, raising this as an issue here has helped us figure out a quicker way to train this Reformer.

Then why is it not any faster on TPU than on CPU? It doesn't seem green-lit on Colab.

I think the TPU support might be fake, at least in the public implementation. If you look through the code, it relies on a try/except block around a tpu_client import that doesn't exist anywhere in the repository.

https://github.com/google/jax/blob/dcc882cf6b2cf980783cb2221beac1318bcbb412/jax/lib/xla_bridge.py#L39
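Whether tpu_client is "fake" hinges on a common Python idiom. A module can be imported inside try/except so that it may live in a separately installed package rather than in the repository being read. According to a later comment in this thread, that package is jaxlib. Below is a generic sketch of the idiom; `optional_import` is a hypothetical helper, not JAX code.

```python
import importlib


def optional_import(module_name):
    """Import and return a module if it is installed; return None otherwise.

    Generic try/except optional-import idiom (hypothetical helper).
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # Also covers ModuleNotFoundError, which subclasses ImportError.
        return None
```

With this idiom, a missing module silently yields None. The caller then decides whether to fall back to another backend, which is why the module does not need to appear in the jax source tree.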

For what it's worth, GPU works and would be faster than TPU anyway if I only cared about Colab. But I was hoping to use JAX on TFRC for my next project.

If you're interested in running Reformer on TPU, you can check out our image generation and text generation training examples. These are likely a better starting point for TPU usage specifically, because they were written with the TPU runtime in mind.

My experience has been positive, and let's not forget JAX is still at v0.1.57.

Also, I ran the trainer again on CPU vs TPU; the results are below (I added %time to get more insight):

CPU:

Step    500: Ran 500 train steps in 160.54 secs
...
Step   1000: Ran 500 train steps in 3.91 secs
...
Step   1500: Ran 500 train steps in 3.75 secs
...

CPU times: user 4 µs, sys: 1e+03 ns, total: 5 µs
Wall time: 8.82 µs

TPU:

Step    500: Ran 500 train steps in 41.07 secs
...
Step   1000: Ran 500 train steps in 28.38 secs
...
Step   1500: Ran 500 train steps in 28.47 secs
...
CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 8.58 µs

Update: GPU:

Step    500: Ran 500 train steps in 33.22 secs
...
Step   1000: Ran 500 train steps in 2.08 secs
...
Step   1500: Ran 500 train steps in 2.20 secs

CPU times: user 4 µs, sys: 1 µs, total: 5 µs
Wall time: 7.39 µs
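The first logged interval includes JIT compilation, so comparing the later intervals gives a fairer throughput figure. Below is a sketch of a parser for log lines in the format shown above. The regular expression is an assumption based only on these lines, and the helper names are hypothetical.

```python
import re

# Matches lines such as "Step   1000: Ran 500 train steps in 3.91 secs".
STEP_LINE = re.compile(r"Step\s+(\d+): Ran (\d+) train steps in ([\d.]+) secs")


def parse_train_log(text):
    """Return (step, n_steps, seconds) tuples for every matching log line."""
    return [(int(s), int(n), float(t)) for s, n, t in STEP_LINE.findall(text)]


def steady_state_steps_per_sec(records):
    """Steps per second over all intervals except the first (compile-heavy) one."""
    if len(records) < 2:
        raise ValueError("need at least two logged intervals")
    later = records[1:]
    total_steps = sum(n for _, n, _ in later)
    total_secs = sum(t for _, _, t in later)
    return total_steps / total_secs
```

Applied to the three runs above, the steady-state rates work out to roughly 131 steps/s on CPU, 18 steps/s on TPU and 234 steps/s on GPU. For this tiny demo model, the TPU therefore trails both.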

The intro Colab you're running uses a tiny model; I think it might even be 100x smaller than anything you'd see in practice:

import trax

def tiny_transformer_lm(mode):
  return trax.models.TransformerLM(   # You can try trax.models.ReformerLM too.
    d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

The TPU should pull out ahead of CPU once you have a larger batch size (up from the demo of 20 tokens per batch, per core) and a larger model size.
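A toy cost model makes this argument concrete. Suppose each step pays a fixed overhead plus work proportional to the batch size. A device with a higher overhead but much cheaper per-token work then wins only once batches are large enough. For scale, 20 tokens per core across 8 cores is 160 tokens per step. All parameter values below are hypothetical and chosen purely for illustration; they are not measurements of any real CPU or TPU.

```python
def step_time(batch_tokens, overhead_s, secs_per_token):
    """Toy per-step time: fixed overhead plus work proportional to batch size."""
    return overhead_s + batch_tokens * secs_per_token


def speedup(batch_tokens, slow, fast):
    """How many times faster `fast` is than `slow` at this batch size.

    `slow` and `fast` are (overhead_s, secs_per_token) pairs.
    """
    return step_time(batch_tokens, *slow) / step_time(batch_tokens, *fast)


# Hypothetical parameters for illustration only (not measured): the
# "accelerator" pays a larger fixed overhead per step but does each
# token's work 1000x more cheaply than the "host".
HOST = (0.001, 1e-4)
ACCEL = (0.050, 1e-7)

for tokens in (160, 10_000, 1_000_000):
    print(tokens, round(speedup(tokens, HOST, ACCEL), 2))
```

In this model, the accelerator is slower than the host at the demo-sized batch. It only pulls ahead as the per-token work starts to dominate the fixed overhead.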

@8bitmp3 you need to add %%time, not %time. A single % only times the statement on the same line (which is empty), which is why you get µs.
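In plain Python, the difference comes down to whether the timer wraps the whole block or a single statement. Below is a rough analogue of %%time written as a context manager. The name `wall_timer` is hypothetical, and it measures wall-clock time only, not the CPU times IPython also reports.

```python
import time
from contextlib import contextmanager


@contextmanager
def wall_timer(results, key):
    """Store the wall-clock duration of the whole `with` body in results[key].

    Rough plain-Python analogue of the %%time cell magic (hypothetical helper).
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        results[key] = time.perf_counter() - start
```

Wrapping an empty body gives a near-zero time, much like %time on an otherwise empty line. Wrapping the training loop measures the work itself.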

I'm not sure why the Step 1500 time is way lower on CPU than on TPU though.

I'm rerunning them now.

TPU timing Colab:
TPU wall time: 2min 2s

CPU timing Colab:
CPU wall time: 2min 43s

Note: the time varies, so sometimes CPU ends up faster.
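Single runs vary, so a common remedy is to repeat the measurement and report the median. Below is a minimal sketch; `median_runtime` is a hypothetical helper name.

```python
import statistics
import time


def median_runtime(fn, repeats=5):
    """Call fn() `repeats` times and return the median wall-clock time in seconds.

    The median is less sensitive to one unusually slow or fast run than a
    single measurement is.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```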

@nkitaev thanks! I'll try those instead. I'm still not sure the TPU gets used at all, though, given that it performs about as fast as the CPU instance on the simpler example.

@Tenoke if you're not getting that "No GPU/TPU found" warning, it should be running on TPU (as another way to check, jax.devices() should return 8 TpuDevice objects). Like @nkitaev says, TPUs really shine with large inputs. If you're only performing tiny quick computations, non-TPU overheads will dominate the overall time and you won't see a benefit from hardware acceleration. If you try the demo colab, the microbenchmark at the end of "The basics: interactive NumPy on GPU and TPU" section should show a difference between TPU and CPU. (Also that tpu_client import comes from jaxlib here).
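The jax.devices() check can be wrapped in a small helper. This sketch assumes only that each device object has a `.platform` string attribute; JAX devices report values such as "cpu", "gpu" or "tpu". The helper names are hypothetical, and the default of 8 cores comes from the comment above.

```python
from collections import Counter


def platform_counts(devices):
    """Count devices by their `.platform` attribute, e.g. {"tpu": 8}."""
    return dict(Counter(d.platform for d in devices))


def looks_like_colab_tpu(devices, expected_cores=8):
    """True if every device is a TPU and there are exactly `expected_cores` of them."""
    return platform_counts(devices) == {"tpu": expected_cores}
```

In a Colab TPU runtime with the backend configured, `platform_counts(jax.devices())` should then report eight TPU devices. A result of `{"cpu": 1}` would match the fallback warning instead.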

Thanks for reporting this, by the way. Even if it turns out that everything's working as expected, this is still useful feedback on the initial experience. We're still working to make JAX + Cloud TPU better!

OK, yes: I get a 2x speedup over CPU in the text generation notebook (304 s vs 123 s for the actual training, probably a bit less in wall-clock time).

"Like @nkitaev says, TPUs really shine with large inputs. If you're only performing tiny quick computations, non-TPU overheads will dominate the overall time and you won't see a benefit from hardware acceleration."

That was likely what was happening, though the speedup over CPU (CPU, not GPU) still seems rather small: only about 2x.

@Tenoke have you trained the text generation notebook to completion?

The first cell takes two minutes to run (Step 1: Ran 1 train steps in 124.84 secs), but almost all of this time is spent JIT-compiling the python code to run on TPU. Steady-state speed is closer to 3.8 seconds/step (Step 30: Ran 10 train steps in 38.16 secs), you just need to wait a few minutes for the first few steps to run, as the JIT and caches get warmed up.
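To measure steady-state speed separately from compilation, time the first call on its own. This is a sketch that assumes a zero-argument callable; `first_vs_steady` is a hypothetical helper. JAX dispatches work asynchronously, so the `sync` hook should block until the result is ready, for example `lambda y: y.block_until_ready()` for a JAX array.

```python
import time


def first_vs_steady(fn, steady_runs=3, sync=lambda out: out):
    """Return (first_call_seconds, mean_seconds_per_later_call).

    With jax.jit, the first call includes tracing and compilation, so it is
    timed separately from the later, already-compiled calls.
    """
    start = time.perf_counter()
    sync(fn())
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(steady_runs):
        sync(fn())
    steady = (time.perf_counter() - start) / steady_runs
    return first, steady
```

A large gap between the two numbers matches the pattern in the Trax logs: a slow first step, then much faster steady-state steps.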

@nkitaev Yes, I ran it to completion and I'm getting good results once I account for the compiling. I played with it some to get a better idea and I'm trying it on a bigger dataset next (wikipedia) since that's where it will shine most.

The only thing I'm unsure of is if I can use the same config to work with TPUs on gcloud (I haven't tried yet), given that the settings seem at least partially Colab-specific.

As for this issue, it's basically resolved for me and can be closed since I have it working. I assume there are already plans either to include the relevant config steps in the docs or to detect the TPU directly in xla_bridge.py.

I think the only change you'll need to make for Cloud is to replace the COLAB_TPU_ADDR with the address of the TPU you've allocated. And yes, at least for Colab we should be able to auto-detect the TPU in the future.

I'm gonna close this issue, but @Tenoke please let us know if you run into any other surprising or annoying behavior! Thanks again for reporting.

Is there any reason why TPUs are not detected in GitHub-stored notebooks, only in Drive-stored ones?

Could it be related to security? If so, how can I get Google Colab to trust a GitHub notebook?

@epignatelli just use File > copy to Drive first.
