Hi! I am looking to do fast multi-machine allreduce and broadcast operations when using JAX and MPI.
Here is a script that should be similar to my workload which I ran on 8 GCE instances with 8 V100 GPUs each and a 32 Gbit network:
https://gist.github.com/christopherhesse/b5d141a59d9648caab191d9ff6333117
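The pattern the script times is roughly this (a simplified sketch, not the exact gist contents):

    # Simplified sketch: time a jitted computation, a device-to-host copy,
    # and an mpi4py allreduce for a given parameter count.
    import time
    import numpy as np
    import jax
    import jax.numpy as jnp
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    @jax.jit
    def compute(x):
        return x * 2.0 + 1.0

    x = jnp.ones(16_000_000, jnp.float32)

    start = time.time()
    y = compute(x).block_until_ready()       # "compute"
    compute_elapsed = time.time() - start

    start = time.time()
    host = np.asarray(y)                     # "device_to_host"
    d2h_elapsed = time.time() - start

    start = time.time()
    out = np.empty_like(host)
    comm.Allreduce(host, out, op=MPI.SUM)    # "allreduce"
    allreduce_elapsed = time.time() - start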
I ran it using mpich:
mpiexec -f <hosts file> python <path to script>
The output looks like this:
num_params 1
compute : min_elapsed 0.000424 avg_elapsed 0.026451 max_elapsed 0.259608
device_to_host : min_elapsed 0.000070 avg_elapsed 0.000106 max_elapsed 0.000298
allreduce : min_elapsed 0.000209 avg_elapsed 0.002230 max_elapsed 0.018252
num_params 16000000
compute : min_elapsed 0.006838 avg_elapsed 0.023782 max_elapsed 0.155499
device_to_host : min_elapsed 0.123953 avg_elapsed 0.135843 max_elapsed 0.163817
allreduce : min_elapsed 0.505218 avg_elapsed 0.592024 max_elapsed 0.640469
So about 600 ms per allreduce for 16M float32s.
If I use nccl-tests with MPI support (make MPI=1):
mpiexec -f <hosts file> ./nccl-tests/build/all_reduce_perf -b 1M -e 64M -f 2 -g 1 -c 0
The output looks like this:
#                                                       out-of-place                       in-place
#       size         count    type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                     (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
     1048576        262144   float     sum   5328.1    0.20    0.39    N/A   3856.1    0.27    0.54    N/A
     2097152        524288   float     sum   6751.2    0.31    0.61    N/A   6132.7    0.34    0.67    N/A
     4194304       1048576   float     sum    11100    0.38    0.74    N/A    10899    0.38    0.76    N/A
     8388608       2097152   float     sum   9818.9    0.85    1.68    N/A   9351.1    0.90    1.77    N/A
    16777216       4194304   float     sum    17219    0.97    1.92    N/A    17121    0.98    1.93    N/A
    33554432       8388608   float     sum    35836    0.94    1.84    N/A    36609    0.92    1.80    N/A
    67108864      16777216   float     sum    73365    0.91    1.80    N/A    78911    0.85    1.67    N/A
Which looks like 80 ms for 16M float32s.
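Back of the envelope: 16M float32s is about 64 MB, so that's roughly 64 MB / 592 ms ≈ 0.11 GB/s through mpi4py versus 64 MB / 73 ms ≈ 0.88 GB/s from nccl-tests, close to an 8x gap.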
For my particular training setup, I am seeing ~600 ms spent doing the allreduce, out of ~800 ms total per training loop, so speeding this up could substantially reduce the runtime of my script.
The two ways that seem most promising to me would be:
1) Use XLA's existing NCCL support or extend it to do this call through XLA
2) Use pointers to GPU memory to call NCCL from Python (not sure if this would encounter weird issues with XLA also using CUDA)
What do you guys think?
Thanks for raising this! I used mpi4py a bit in grad school and loved it, so it will be really satisfying to get this working well.
I don't have much to offer right now but just wanted to collect some of our clues in one place. @hawkinsp just added a way to grab raw XLA:GPU memory pointers to our Python XLA client so that we could explore option 2: here's the TF commit.
We might need an XLA expert to weigh in on option 1, but AIUI XLA:GPU's NCCL support is for the single-host setting. That said, XLA:TPU's multi-replica computations can span accelerators across multiple hosts, so there could be some path forward there.
(By the way, we need to update the TF commit that our repo points to before our build process will build a version of XLA with that update. Also we'll need to update the jaxlib wheels.)
Thanks for the quick response! I'll give the pointers a try once XLA is updated.
@mattjj any idea when XLA will be updated next?
Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?
> @mattjj any idea when XLA will be updated next?
Sorry, we're getting behind because almost all of the team is in London this week.
I just kicked off a build for Linux machines, which should be done in an hour or so; @hawkinsp could you do the macOS build, or will that have to wait for us to return to the colonies?
> Related to this issue, is there a way to tell JAX to use a specific GPU without setting CUDA_VISIBLE_DEVICES?
No, not yet. Would you want to place different computations on different devices, or just set a program-global parameter to control this? (The former is more general than the latter, and we plan to add that soon, but maybe the latter would be a quicker fix that covers your use case.)
Thanks for kicking off that build!
Program-global parameter would work for me. The reason is that it looks like I have to set CUDA_VISIBLE_DEVICES to a different value for NCCL; it's possible that I can work around this by setting the env var only during NCCL initialization (roughly the sketch below). Should I file a new issue if it turns out that the global parameter is important to me?
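Something like this sketch (untested; init_nccl_communicator() is a hypothetical stand-in for my actual NCCL setup):

    import os
    from contextlib import contextmanager

    @contextmanager
    def temp_env(key, value):
        # Temporarily override an environment variable, restoring it afterwards.
        old = os.environ.get(key)
        os.environ[key] = value
        try:
            yield
        finally:
            if old is None:
                del os.environ[key]
            else:
                os.environ[key] = old

    # Let NCCL see every GPU while it initializes, without changing what JAX
    # (already initialized elsewhere) sees for the rest of the process.
    with temp_env("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7"):
        init_nccl_communicator()  # hypothetical stand-in for my NCCL setup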
No need for a new issue; several others are keen on having a way to control jit device assignment, so it's on my mind already.
(Haven't uploaded the wheels yet.)
It looks like I can get NCCL to work even when using CUDA_VISIBLE_DEVICES, but performance is noticeably impacted, possibly because it can't see all the GPUs on the machine at once.
As a result, it's not required for this issue, but it would still be very nice to be able to control device assignment, even just in a global manner.
The unsafe_buffer_pointer() method should be available in jaxlib 0.1.22. Please experiment with it (although I wouldn't consider it a final API).
The buffer pointer works great! Should I leave this issue open until there's some sort of control over device assignment? It looks like my setup with NCCL should be even faster once that is supported by JAX.
Specifically, I want to confine JAX to using a single GPU (so it doesn't allocate a ton of memory on the other GPUs) without setting CUDA_VISIBLE_DEVICES.
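For reference, roughly the shape of what I'm doing (a simplified, untested sketch; it assumes cupy's NCCL bindings, and that the pointer comes from the new unsafe_buffer_pointer() on the device buffer):

    import jax
    import jax.numpy as jnp
    from mpi4py import MPI
    from cupy.cuda import nccl
    import cupy

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    # Bootstrap NCCL over MPI: rank 0 makes the unique id, everyone else receives it.
    uid = nccl.get_unique_id() if rank == 0 else None
    uid = comm.bcast(uid, root=0)
    nccl_comm = nccl.NcclCommunicator(size, uid, rank)

    x = jax.device_put(jnp.ones(16_000_000, jnp.float32))
    ptr = x.device_buffer.unsafe_buffer_pointer()   # raw XLA:GPU device address

    # In-place sum allreduce directly on the XLA buffer, no device<->host copies.
    nccl_comm.allReduce(ptr, ptr, x.size, nccl.NCCL_FLOAT,
                        nccl.NCCL_SUM, cupy.cuda.Stream.null.ptr)
    cupy.cuda.Stream.null.synchronize()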
@mattjj I wanted to mention that I may have found a workaround for this CUDA_VISIBLE_DEVICES issue that does not require any JAX changes. I will try it out and update this issue with the result.
Specifically, the fix for that is to upgrade to CUDA 10.1: https://github.com/NVIDIA/nccl/blob/master/src/transport/p2p.cc#L75
It looks like JAX does not have a jaxlib for CUDA 10.1 though; do I need to build that myself?
Looks like not only does TensorFlow not support CUDA 10.1, but neither do our compute clusters, so never mind.