There's a separate pmap cookbook notebook, but adding pmap (and maybe even remat) to the JAX Quickstart would showcase more of what JAX can offer.
This would require switching the notebook to a Cloud TPU Colab. @skye, you shared this idea earlier this year offline (@mattjj was cc'ed); getting back to you now (at last). Let me know what you think and I can start working on a PR. Cheers
That sounds like a great idea! In the spirit of making smaller easier-to-review changes, how about starting with just pmap?
One thing I'm less sure of is if there are any downsides to using a Cloud TPU colab for the quickstart. @jekbradbury @skye WDYT?
Thanks @mattjj, adding only pmap for now makes sense. Will work on it.
I think it would be best if at least the majority of the quickstart notebook can run on any platform. It's great to show off Cloud TPUs, but most people getting started with JAX probably won't be using them, and I wouldn't want them to think they _need_ to be using TPUs for JAX to be worthwhile.
I like the idea of switching to Cloud TPU + adding pmap. I'm imagining putting it at the end so the rest of the notebook runs on any colab platform, with a little blurb explaining this will work on any platform if you have enough devices. You can even put an assert like:
```python
import jax

assert jax.device_count() >= 4  # or however many devices the example needs
```
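For illustration, here is a minimal sketch of what a short pmap section at the end of the notebook could look like. It is not the notebook's actual content; the axis name `"devices"` and the array shape are arbitrary choices. It maps one row of `x` to each local device and uses `jax.lax.psum` to sum the rows across devices. It sizes the input with `jax.local_device_count()`, because the leading axis passed to `pmap` can't exceed the number of local devices. The same code also runs on a single-device CPU or GPU runtime; it just has nothing to parallelize there.

```python
import jax
import jax.numpy as jnp

# One row per local device; pmap requires the leading axis <= local device count.
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

# Each device receives one row of shape (3,); psum adds that row across all
# devices participating in the named axis, so every device ends up with the total.
summed = jax.pmap(lambda row: jax.lax.psum(row, axis_name="devices"),
                  axis_name="devices")(x)

print(summed.shape)  # (n_devices, 3): one copy of the cross-device sum per device
```

Using a collective like `psum` rather than a purely elementwise function shows what makes `pmap` different from `vmap`: the mapped computations can communicate with each other.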
The main downside I see is that Cloud TPU requires putting the preamble at the top to switch to the new TPU driver, which I think will break on other platforms. There might be a colab env var or something we can check to conditionally run the preamble only on a TPU colab, I can look into this if it'd be helpful. I'd also be ok just putting the preamble collapsed by default and explaining it should be removed for non-TPU platforms.
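One way the conditional preamble could work is sketched below. It assumes that Colab's TPU runtimes set the `COLAB_TPU_ADDR` environment variable and other runtimes do not; that assumption should be verified before relying on it. The helpers `on_colab_tpu` and `maybe_run_tpu_preamble` are hypothetical names, and the actual TPU-driver setup is passed in as a callable rather than reproduced here.

```python
import os

def on_colab_tpu() -> bool:
    # Assumption: only Colab TPU runtimes define COLAB_TPU_ADDR.
    return "COLAB_TPU_ADDR" in os.environ

def maybe_run_tpu_preamble(preamble) -> bool:
    """Hypothetical helper: run the TPU-driver setup only on a TPU Colab.

    `preamble` is a zero-argument callable holding the real setup code.
    Returns True if the preamble ran, False otherwise.
    """
    if on_colab_tpu():
        preamble()
        return True
    return False
```

With a guard like this, the same first cell could be left in place on CPU and GPU runtimes instead of asking readers to delete it.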