In TF and PyTorch, there is an easy way to tell whether the GPU is being used (see below).
How can we do this with JAX?
import tensorflow as tf
if tf.test.is_gpu_available():
    print(tf.test.gpu_device_name())
else:
    print("TF cannot find GPU")

import torch
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("Torch cannot find GPU")
There's no true public API for this yet, but for right now you can do this:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
and it will print cpu, gpu, or tpu. Note, however, that this API is internal and subject to change.
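For readers on newer JAX releases (not the version discussed in this 2019 thread), a check mirroring the TF and PyTorch snippets above can be sketched with the public `jax.default_backend()` and `jax.devices()` functions. This is a minimal sketch, assuming a JAX version that provides both; the printed message is illustrative only.

```python
import jax

# jax.default_backend() returns the platform name of the default backend,
# e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()

if backend == "gpu":
    # Each Device object exposes a `device_kind` string describing the hardware.
    print(jax.devices()[0].device_kind)
else:
    print("JAX cannot find GPU")
```

Unlike the `xla_bridge` call, these functions are part of JAX's documented top-level API.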
I agree we need a supported API, although exactly how that API will look will almost certainly evolve as we grow support for multiple kinds of devices in the same process.
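Since a single process can expose more than one kind of device, a sketch that lists the visible devices per platform may also help. It assumes a newer JAX where `jax.devices(backend)` accepts a platform name and raises `RuntimeError` when that backend is unavailable; the helper name `devices_by_platform` is hypothetical.

```python
import jax

def devices_by_platform(platforms=("cpu", "gpu", "tpu")):
    """Hypothetical helper: map each platform name to the devices JAX sees."""
    found = {}
    for name in platforms:
        try:
            # jax.devices(name) raises RuntimeError if that backend
            # is not available in this process.
            found[name] = jax.devices(name)
        except RuntimeError:
            found[name] = []
    return found

for name, devs in devices_by_platform().items():
    print(name, [d.device_kind for d in devs])
```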
Also note that if no GPU is found, JAX currently prints a loud warning the first time you run an op:
xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Do those address your immediate need, or do you need something better soon?
Thanks, that works.
Closing this because I think we sorted out the current solution. As Peter mentioned, we'll likely make and document a more general API at some point.