Jax: Very slow CPU + GPU compilation for large-ish model

Created on 26 Mar 2019 · 6 Comments · Source: google/jax

Code to reproduce is in this gist.

The slowness also seems to occur on GPU; you can reproduce it by running this Colab.

Massively appreciate any help fixing this!

Most helpful comment

I looked into this a bit last night:

  • This model is reasonably complex, with O(hundreds) of layers each composed of O(dozens) of primitive ops, but not any more complex than e.g. a plain deep ResNet (as context, IIRC XLA compilation of ResNet-50 from any frontend takes 30s to a minute)
  • Without dropout, the model jaxpr is about 11k lines; dropout adds an enormous amount of inlined+unrolled PRNG code and makes the jaxpr 94k lines
  • This 94k-line jaxpr then becomes HLO that's 180k lines in textual format or 50MB uncompressed as a proto
  • The version of the XLA compiler (jaxlib) currently on pypi takes several hours to compile this HLO, but the latest internal version takes 2-3 minutes (still not great, but much more understandable)
  • Let's update jaxlib!
  • In the long run it might be worthwhile to represent PRNG code in a way that doesn't always get fully inlined+unrolled, when mechanisms for this are available in XLA.
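The size and timing measurements in the list above can be approximated on a small scale. The following is a minimal sketch, not the model from the gist: `toy_model`, `dropout`, and `measure` are hypothetical names, and the layer count and widths are arbitrary. It counts the lines of the printed jaxpr and of the lowered IR text, then times the XLA compile step on its own. The exact numbers depend on the JAX and jaxlib versions, and newer releases may print far fewer PRNG lines than the counts quoted in this thread.

```python
import time

import jax
import jax.numpy as jnp


def dropout(key, x, rate=0.1):
    """Inverted dropout driven by JAX's software PRNG."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def toy_model(key, w, x, num_layers, use_dropout):
    """Hypothetical stand-in for the real model: a stack of dense layers."""
    for _ in range(num_layers):  # Python loop, so each layer is unrolled at trace time
        x = jnp.tanh(x @ w)
        if use_dropout:
            key, subkey = jax.random.split(key)
            x = dropout(subkey, x)
    return x


def measure(use_dropout, num_layers=8, width=16):
    """Return (jaxpr line count, lowered-IR line count, compile seconds)."""

    def fn(key, w, x):
        return toy_model(key, w, x, num_layers, use_dropout)

    key = jax.random.PRNGKey(0)
    w = jnp.ones((width, width)) / width
    x = jnp.ones((4, width))

    # Size of the traced program as printed by JAX.
    jaxpr_lines = len(str(jax.make_jaxpr(fn)(key, w, x)).splitlines())

    # Size of the IR text that is handed to XLA.
    lowered = jax.jit(fn).lower(key, w, x)
    ir_lines = len(lowered.as_text().splitlines())

    # Time only the XLA compilation step.
    start = time.perf_counter()
    lowered.compile()
    compile_seconds = time.perf_counter() - start
    return jaxpr_lines, ir_lines, compile_seconds


for flag in (False, True):
    print(f"dropout={flag}:", measure(flag))
```

Running `measure` with and without dropout on the real model, rather than this toy, is one way to check how much of the program size and compile time comes from PRNG code.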

All 6 comments

Thanks for taking a look!

@hawkinsp worked a bit on a jaxlib update, but there's some TF build snag he needs to iron out.

The inlining/unrolling of the PRNG is an unfortunate consequence of (1) us doing the PRNG in software and (2) XLA inlining everything.

One thing we could do is expose XLA's RNG HLO operations. Those won't have the benefits of our software PRNG, but on the other hand they won't get inlined+unrolled, so they won't inflate compile times in cases like this one.
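For readers on a newer JAX, here is a hedged sketch of what opting into XLA's RNG operation looks like. It assumes a release that provides `jax.random.key` with the `impl` argument and the `"rbg"` implementation (backed by XLA's RngBitGenerator for bit generation); neither existed when this thread was written, and the shape and keep probability below are arbitrary.

```python
import jax

# Default software PRNG (Threefry), as discussed in this thread.
key_default = jax.random.PRNGKey(0)

# Assumption: a newer JAX release where the "rbg" implementation is available.
# It generates random bits through XLA's RngBitGenerator operation.
key_rbg = jax.random.key(0, impl="rbg")

# Dropout-style keep mask drawn with the XLA-backed generator.
mask = jax.random.bernoulli(key_rbg, 0.9, (4, 8))
print(mask.shape, mask.dtype)
```

As noted above, this trades away some properties of the software PRNG, so it is worth checking the JAX documentation on PRNG implementations before switching.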

OK, compilation on GPU is now acceptably fast, around 2 minutes.

@mattjj shall we leave this open to track progress on simplifying the PRNG HLO?

So glad to hear the compilation time improved dramatically! We think we can get it down much further through both JAX improvements (e.g. maybe rolling up some loops in our software PRNG) and XLA improvements (we're discussing those internally with the XLA team).
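To illustrate what "rolling up" a loop means for program size, here is a small sketch. The mixing function is a toy and is not the Threefry code JAX uses; `mix_round`, `mix_unrolled`, `mix_rolled`, and `NUM_ROUNDS` are hypothetical names. A Python `for` loop is unrolled at trace time, so every round appears in the jaxpr, while `lax.fori_loop` keeps a single loop construct at the top level.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_ROUNDS = 20


def mix_round(x):
    """One round of a toy integer mixer (illustration only, not Threefry)."""
    x = x * jnp.uint32(1664525) + jnp.uint32(1013904223)
    return x ^ (x >> jnp.uint32(13))


def mix_unrolled(x):
    # Python loop: each round is traced separately and inlined.
    for _ in range(NUM_ROUNDS):
        x = mix_round(x)
    return x


def mix_rolled(x):
    # Structured loop: the body is traced once and kept as a loop.
    return lax.fori_loop(0, NUM_ROUNDS, lambda i, v: mix_round(v), x)


x = jnp.arange(8, dtype=jnp.uint32)
unrolled_eqns = len(jax.make_jaxpr(mix_unrolled)(x).jaxpr.eqns)
rolled_eqns = len(jax.make_jaxpr(mix_rolled)(x).jaxpr.eqns)
print("top-level equations:", unrolled_eqns, "unrolled vs", rolled_eqns, "rolled")
```

Both versions compute the same values; only the size of the program handed to the compiler differs.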

Re: simplifying the PRNG HLO, let's open a separate issue for that.
