JAX: multi-machine allreduce

Created on 10 Jul 2019 · 14 comments · Source: google/jax

Hi! I am looking to do fast multi-machine allreduce and broadcast operations when using JAX and MPI.

Here is a script that should be similar to my workload which I ran on 8 GCE instances with 8 V100 GPUs each and a 32 Gbit network:

https://gist.github.com/christopherhesse/b5d141a59d9648caab191d9ff6333117
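
Roughly, the loop being timed looks like the sketch below (the real script is in the gist above; the names here, such as compute and num_params, are illustrative, and mpi4py is assumed for the MPI side):

# Hypothetical sketch of the timed steps: JAX compute on the GPU, copy the result
# to the host, then allreduce the host array with mpi4py.
import numpy as np
import jax.numpy as jnp
from jax import jit
from mpi4py import MPI

comm = MPI.COMM_WORLD
num_params = 16_000_000

@jit
def compute(params):
    return params * 2.0 + 1.0                      # stand-in for the real computation

params = jnp.zeros(num_params, dtype=jnp.float32)

out = compute(params)
out.block_until_ready()                            # "compute" timing ends here

host_buf = np.asarray(out)                         # "device_to_host" step

reduced = np.empty_like(host_buf)
comm.Allreduce(host_buf, reduced, op=MPI.SUM)      # "allreduce" step across all ranks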

I ran it using mpich:

mpiexec -f <hosts file> python <path to script>

The output looks like this:

num_params 1
compute             : min_elapsed 0.000424  avg_elapsed 0.026451  max_elapsed 0.259608
device_to_host      : min_elapsed 0.000070  avg_elapsed 0.000106  max_elapsed 0.000298
allreduce           : min_elapsed 0.000209  avg_elapsed 0.002230  max_elapsed 0.018252
num_params 16000000
compute             : min_elapsed 0.006838  avg_elapsed 0.023782  max_elapsed 0.155499
device_to_host      : min_elapsed 0.123953  avg_elapsed 0.135843  max_elapsed 0.163817
allreduce           : min_elapsed 0.505218  avg_elapsed 0.592024  max_elapsed 0.640469

So about 600 ms per allreduce for 16M float32s.

If I use nccl-tests with MPI support (make MPI=1):

mpiexec -f <hosts file> ./nccl-tests/build/all_reduce_perf -b 1M -e 64M -f 2 -g 1 -c 0

The output looks like this:

[0] #                                                     out-of-place                       in-place          
[0] #       size         count    type   redop     time   algbw   busbw  error     time   algbw   busbw  error
[0] #        (B)    (elements)                     (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
[0]      1048576        262144   float     sum    5328.1    0.20    0.39    N/A    3856.1    0.27    0.54    N/A
[0]      2097152        524288   float     sum    6751.2    0.31    0.61    N/A    6132.7    0.34    0.67    N/A
[0]      4194304       1048576   float     sum     11100    0.38    0.74    N/A     10899    0.38    0.76    N/A
[0]      8388608       2097152   float     sum    9818.9    0.85    1.68    N/A    9351.1    0.90    1.77    N/A
[0]     16777216       4194304   float     sum     17219    0.97    1.92    N/A     17121    0.98    1.93    N/A
[0]     33554432       8388608   float     sum     35836    0.94    1.84    N/A     36609    0.92    1.80    N/A
[0]     67108864      16777216   float     sum     73365    0.91    1.80    N/A     78911    0.85    1.67    N/A

That works out to roughly 80 ms for 16M float32s.
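
For context, a quick back-of-the-envelope comparison of the two effective bandwidths, using the numbers reported above:

# Rough effective-bandwidth comparison from the timings above.
jax_bytes = 16_000_000 * 4             # 16M float32 params in the JAX+MPI run
nccl_bytes = 67_108_864                # the 64 MiB row of all_reduce_perf

print(jax_bytes / 0.592 / 1e9)         # ~0.11 GB/s for the 592 ms average allreduce
print(nccl_bytes / 73_365e-6 / 1e9)    # ~0.91 GB/s, matching the algbw column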

For my particular training setup, I am seeing ~600 ms spent in the allreduce out of ~800 ms total per training-loop iteration, so speeding this up could substantially reduce my script's runtime.

The two ways that seem most promising to me would be:

1) Use XLA's existing NCCL support (or extend it) so that the allreduce call happens through XLA

2) Use pointers to GPU memory to call NCCL directly from Python (not sure whether this would run into odd interactions with XLA also using CUDA in the same process)

What do you guys think?

enhancement

Most helpful comment

The unsafe_buffer_pointer() method should be available in jaxlib 0.1.22. Please experiment with it (although I wouldn't consider it a final API).

All 14 comments

Thanks for raising this! I used mpi4py a bit in grad school and loved it, so it will be really satisfying to get this working well.

I don't have much to offer right now but just wanted to collect some of our clues in one place. @hawkinsp just added a way to grab raw XLA:GPU memory pointers to our Python XLA client so that we could explore option 2: here's the TF commit.

We might need an XLA expert to weigh in on option 1, but AIUI XLA:GPU's NCCL support is for the single-host setting. That said, XLA:TPU's multi-replica computations can span accelerators across multiple hosts, so there could be some path forward there.

(By the way, we need to update the TF commit that our repo points to before our build process will build a version of XLA with that update. Also we'll need to update the jaxlib wheels.)

Thanks for the quick response! I'll give the pointers a try once XLA is updated.

@mattjj any idea when XLA will be updated next?

Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?

> @mattjj any idea when XLA will be updated next?

Sorry, we're getting behind because almost all of the team is in London this week.

I just kicked off a build for Linux machines, which should be done in an hour or so; @hawkinsp could you do the macOS build, or will that have to wait for us to return to the colonies?

> Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?

No, not yet. Would you want to place different computations on different devices, or just set a program-global parameter to control this? (The former is more general than the latter, and we plan to add that soon, but maybe the latter would be a quicker fix that covers your use case.)

Thanks for kicking off that build!

A program-global parameter would work for me. The reason is that it looks like I have to set CUDA_VISIBLE_DEVICES to a different value for NCCL; it's possible that I can work around this by setting the env var only during NCCL initialization. Should I file a new issue if it turns out that the global parameter is important to me?

No need for a new issue; several others are keen on having a way to control jit device assignment, so it's on my mind already.

(Haven't uploaded the wheels yet.)


It looks like I can get NCCL to work even when using CUDA_VISIBLE_DEVICES, but performance is noticeably impacted, possibly because it can't see all the GPUs on the machine at once.

As a result, it's not required for this issue, but it would be very nice to be able to control device assignment, even if only globally.

The unsafe_buffer_pointer() method should be available in jaxlib 0.1.22. Please experiment with it (although I wouldn't consider it a final API).
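
For anyone else trying this, here is a hedged sketch of option 2 above built on that pointer. It assumes the pointer is exposed as x.device_buffer.unsafe_buffer_pointer(), drives NCCL through CuPy's bindings (cupy.cuda.nccl), and uses mpi4py only to exchange the NCCL unique id; it is not an official recipe.

# Speculative sketch: in-place NCCL allreduce on a JAX GPU buffer via its raw pointer.
# Assumes jaxlib >= 0.1.22, CuPy built with NCCL support, and one MPI rank per GPU.
import jax.numpy as jnp
from mpi4py import MPI
from cupy.cuda import nccl, Stream

mpi_comm = MPI.COMM_WORLD
rank, size = mpi_comm.Get_rank(), mpi_comm.Get_size()

# Every rank needs the same NCCL unique id; rank 0 creates it and broadcasts it over MPI.
uid = nccl.get_unique_id() if rank == 0 else None
uid = mpi_comm.bcast(uid, root=0)
nccl_comm = nccl.NcclCommunicator(size, uid, rank)

x = jnp.ones(16_000_000, dtype=jnp.float32)
x.block_until_ready()                                # ensure XLA has finished writing x
ptr = x.device_buffer.unsafe_buffer_pointer()        # raw XLA:GPU device pointer

# Sum in place across all ranks on the default stream; x's contents are mutated.
nccl_comm.allReduce(ptr, ptr, x.size, nccl.NCCL_FLOAT32,
                    nccl.NCCL_SUM, Stream.null.ptr)

The main caveat is synchronization: XLA runs its work on its own CUDA stream, so the buffer has to be fully written before NCCL reads it (block_until_ready is the blunt way to guarantee that here).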

The buffer pointer works great! Should I leave this issue open until there's some sort of control over device assignment? It looks like my setup with NCCL should be even faster once that is supported by JAX.

Specifically, I want to confine JAX to using a single GPU (so no allocating a ton of memory on the other GPUs) without setting CUDA_VISIBLE_DEVICES.

@mattjj I wanted to mention that I may have found a workaround for this CUDA_VISIBLE_DEVICES issue that does not require any JAX changes. I will try it out and update this issue with the result.

Specifically, the fix for that is to upgrade to CUDA 10.1: https://github.com/NVIDIA/nccl/blob/master/src/transport/p2p.cc#L75

It looks like JAX does not have a jaxlib build for CUDA 10.1, though; do I need to build that myself?

It looks like not only does TensorFlow not support CUDA 10.1, but neither do our compute clusters, so never mind.

