This is a continuation of https://github.com/google/jax/issues/304 , in which @joschu observed that TensorFlow GPU outperforms JAX GPU for a Transformer benchmark.
I did a little digging into the performance difference, and the thing that sticks out most is the memory allocator. The XLA Python client simply calls cudaMalloc/cudaFree on GPU, but NVIDIA's allocator synchronizes the host and device and is apparently slow, so allocating dozens of small tensors per step adds up.
The obvious fix would be to use a suballocator, as TF and PyTorch both do. TF allocates a large chunk of GPU memory up front and suballocates it using a best-fit with coalescing (BFC) allocator, whereas PyTorch uses a caching allocator that holds on to freed blocks and only returns them to the CUDA allocator when an allocation would otherwise fail (OOM).
The easiest thing to do would be to reuse TF's allocator on GPU, although there is merit to the PyTorch approach too (it plays better with other processes sharing the GPU, since it doesn't reserve one large chunk up front).
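For intuition, here is a minimal sketch of the caching idea (not TF's or PyTorch's actual implementation); the `CachingAllocator`, `raw_alloc`, and `raw_free` names are hypothetical stand-ins for wrappers around the underlying device calls:

```python
from collections import defaultdict

class CachingAllocator:
    """Toy caching suballocator: reuse freed blocks of the same size
    instead of calling the slow, synchronizing device allocator."""

    def __init__(self, raw_alloc, raw_free):
        self._raw_alloc = raw_alloc            # e.g. a wrapper around cudaMalloc
        self._raw_free = raw_free              # e.g. a wrapper around cudaFree
        self._free_blocks = defaultdict(list)  # size -> cached blocks

    def allocate(self, size):
        cached = self._free_blocks[size]
        if cached:
            return cached.pop()                # fast path: no device call at all
        try:
            return self._raw_alloc(size)
        except MemoryError:
            self.release_cache()               # on OOM, give cached blocks back...
            return self._raw_alloc(size)       # ...and retry, roughly as PyTorch does

    def free(self, block, size):
        self._free_blocks[size].append(block)  # keep the block around for reuse

    def release_cache(self):
        for blocks in self._free_blocks.values():
            for block in blocks:
                self._raw_free(block)
        self._free_blocks.clear()
```

A real allocator would also deal with alignment, block splitting and coalescing (which is what TF's BFC allocator adds), and thread safety, but the fast path is the same: most allocations never touch cudaMalloc.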
It would be cool if there were an option to use NVIDIA's native allocator as well as a custom allocator. The native allocator makes it possible to detect memory leaks in kernel development (with cuda-memcheck).
Resolved via https://github.com/tensorflow/tensorflow/commit/73ae0b9d065e6fe39b90ef0db5b96adc220990fb (included in jaxlib 0.1.15). We expect multi-GPU workloads to be performant enough for real usage now!
@scott-gray you can set the environment variable XLA_PYTHON_CLIENT_ALLOCATOR=platform to use the nvidia allocator (I forgot to update the commit message, so this is incorrectly documented there :disappointed:).
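The variable needs to be picked up before JAX initializes its GPU backend, so the simplest thing is to set it before importing jax, e.g.:

```python
import os
# Select the allocator before JAX creates its GPU client.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # raw cudaMalloc/cudaFree

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # device allocations now go through the platform allocator
```

With the platform allocator each array corresponds to its own cudaMalloc call, which is what makes the cuda-memcheck leak-detection workflow mentioned above possible.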
@hawkinsp @skye should we close this now, or wait for full GPU async features to be switched on too?
Ah meant to close this.