Code to reproduce is in this gist.
The slowness also seems to occur on GPU; you can reproduce it by running this Colab.
Massively appreciate any help fixing this!
I looked into this a bit last night:
Thanks for taking a look!
@hawkinsp worked a bit on a jaxlib update, but there's some TF build snag he needs to iron out.
The inlining/unrolling of the PRNG is an unfortunate consequence of (1) us doing the PRNG in software and (2) XLA inlining everything.
One thing we could do is expose XLA's RNG HLO operations. Those won't have the benefits of our software PRNG, but on the other hand they won't get inlined and unrolled, so they won't inflate compile times in cases like this one.
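To make the inlining/unrolling point concrete, here is a rough sketch (not the code from the gist) that compares how large the lowered program gets as the number of random draws grows. The names `sample_sum` and `lowered_size` are made up for illustration, and counting lines of the lowered text is only a crude proxy for how much work the compiler is handed.

```python
import jax
import jax.numpy as jnp


def sample_sum(key, n_draws):
    # Hypothetical example function. The Python loop is unrolled at trace
    # time, so every iteration adds more PRNG operations (or calls to them)
    # to the traced program.
    total = jnp.zeros(())
    for _ in range(n_draws):
        key, subkey = jax.random.split(key)
        total = total + jax.random.normal(subkey)
    return total


def lowered_size(n_draws):
    # Number of lines in the lowered program text; a crude size proxy.
    key = jax.random.PRNGKey(0)
    lowered = jax.jit(sample_sum, static_argnums=1).lower(key, n_draws)
    return len(lowered.as_text().splitlines())


small = lowered_size(1)
large = lowered_size(8)
print("lines for 1 draw:", small)
print("lines for 8 draws:", large)
```

The exact counts depend on the JAX version and backend, so treat the numbers as relative only.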
OK, compilation on GPU is now acceptably fast, around 2 minutes.
@mattjj shall we leave this open to track progress on simplifying the PRNG HLO?
So glad to hear the compilation time improved dramatically! We think we can get it down much further through both JAX improvements (e.g. maybe rolling up some loops in our software PRNG) and XLA improvements (we're discussing those internally with the XLA team).
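As a user-level illustration of what "rolling up" a loop means, here is a small sketch using `jax.lax.fori_loop`. This is not the change to JAX's internal PRNG mentioned above, just the same idea applied to user code; `sample_sum_rolled` is a hypothetical name. The loop body is traced once, and the number of draws is a runtime argument rather than something baked into the program.

```python
import jax
import jax.numpy as jnp


def sample_sum_rolled(key, n_draws):
    # Hypothetical example: the body is traced once and run n_draws times
    # by a loop in the compiled program, instead of being unrolled in Python.
    def body(_, carry):
        key, total = carry
        key, subkey = jax.random.split(key)
        return key, total + jax.random.normal(subkey)

    _, total = jax.lax.fori_loop(0, n_draws, body, (key, jnp.zeros(())))
    return total


rolled = jax.jit(sample_sum_rolled)
key = jax.random.PRNGKey(0)
value_3 = rolled(key, 3)  # n_draws is traced, so the same trace serves
value_8 = rolled(key, 8)  # both calls
print(value_3, value_8)
```

Whether this helps compile time in a given model depends on where the unrolling actually comes from, so it is worth measuring before and after.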
Re: simplifying the PRNG HLO, let's open a separate issue for that.