Jax: Support multinode training on GPU

Created on 16 Apr 2020 · 9 Comments · Source: google/jax

I don't have a node with 8 gpus. I have two nodes each with 4 gpus. So is it possible to train a model on multiple nodes?

enhancement

All 9 comments

This is actually something that does work right now but it's still experimental. There's also no real public-facing API for it yet; you have to type in some obscure and fairly magical things to set it all up correctly.

We should polish it off and document it!

Can you say a bit more about your model, though? Would gradient all-reductions across multiple nodes suffice?

@hawkinsp Technically, I'm training a Reformer model using the Trax library.

And I assume you're just looking for data parallelism, i.e., partitioning a minibatch across GPUs, not partitioning in any other way (e.g., model parallelism)?

@hawkinsp Yes, my concern is data parallelism.
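
For readers following along, the pattern being discussed (split a minibatch across GPUs, then all-reduce the gradients) can be sketched with `jax.pmap` and `jax.lax.pmean`. This is only a minimal illustration under stated assumptions: the linear model, the `loss_fn` and `train_step` names, the learning rate, and the batch shapes are all hypothetical, and the code covers only the devices visible to a single process. The multi-node setup mentioned in this thread is not shown, since no details of it are given here.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical toy model: linear regression with a squared-error loss.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes gradients on its own shard of the minibatch.
    grads = jax.grad(loss_fn)(params, x, y)
    # All-reduce: average the gradients across all devices on the axis.
    grads = jax.lax.pmean(grads, axis_name="devices")
    # Plain SGD update (learning rate 0.1 is an arbitrary choice).
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


n_dev = jax.local_device_count()
per_device_batch = 8

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
# Replicate the parameters: one copy per local device on the leading axis.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

# Synthetic data, already shaped as (devices, per-device batch, features).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_dev, per_device_batch, 3))
y = jnp.sum(x, axis=-1)

new_params = train_step(replicated, x, y)
```

Because every device applies the same averaged gradient, the parameter copies stay identical after each step, which is what makes gradient all-reduction sufficient for pure data parallelism.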

@hawkinsp Can you please share your notes on this (we don't need a stable API)? We are trying some hybrid data/model/pipeline parallelism, so our case is a little different from @py4's, but we would love to get started with data parallelism.

Data parallelism would be of value to other projects that use XLA as well (e.g., https://www.tensorflow.org/swift). Exposing this functionality in a standardized way would help drive progress in the broader ecosystem!

> I don't have a node with 8 gpus. I have two nodes each with 4 gpus. So is it possible to train a model on multiple nodes?

Hello py4, I am running into the same problem. Have you found a solution?

> This _is_ actually something that does work right now but it's still experimental. There's also no real public-facing API for it yet; you have to type in some obscure and fairly magical things to set it all up correctly.
>
> We should polish it off and document it!

Hello hawkinsp, could you please provide more details about how to run data-parallel training across multiple GPU nodes?
