The differentially private SGD example takes a very long time to JIT compile at larger batch sizes, and the compilation time spikes after a batch size of 512. Is there a way to fix this or work around it?
Below is a table with some JIT compile times. This colab has more timings.
| Batch Size | Compilation Time (s) |
|------------|----------------------|
| 64         | 21.71                |
| 128        | 43.99                |
| 256        | 63.85                |
| 512        | 107.62               |
| 1024       | 300.03               |
| 2048       | 592.16               |
The DP-SGD example in the repo currently raises errors when run, so the colab includes the fixes from #4446.
EDIT: In general, should we expect compilation time to scale with batch size? If so, why?
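For reference, here is a minimal sketch of roughly how timings like these can be measured (the linked colab is the authoritative version; the toy `loss` below is a hypothetical stand-in for the example's CNN loss and only shows the measurement pattern, not the slowdown itself):

```python
import time
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Toy linear model standing in for the DP-SGD example's loss.
    return jnp.mean((x @ params - y) ** 2)

# Per-example gradients via vmap, jit-compiled as in the example.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

params = jnp.zeros(10)
for batch_size in (64, 128, 256, 512, 1024, 2048):
    xs = jnp.ones((batch_size, 10))
    ys = jnp.ones((batch_size,))
    start = time.time()
    # The first call for each new input shape triggers compilation.
    per_example_grads(params, xs, ys).block_until_ready()
    print(batch_size, time.time() - start)
```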
Thanks for the question!
No, we shouldn't expect compilation time to scale with batch size. It could mean we're unrolling a computation over the batch somewhere in the example.
That is, I suspect this is a bug in the example file.
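To illustrate the suspicion, here is a hedged sketch (not the example's actual code) contrasting an unrolled Python loop over the batch, whose traced graph and hence compile time grow with batch size, with a vmap-based per-example gradient, which should not:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Toy single-example loss (hypothetical; not the example's actual model).
    return jnp.mean((x @ params - y) ** 2)

@jax.jit
def grads_unrolled(params, xs, ys):
    # Python loop over the batch: the traced graph contains one gradient
    # computation per example, so compilation work grows with batch size.
    return jnp.stack([jax.grad(loss)(params, x, y) for x, y in zip(xs, ys)])

@jax.jit
def grads_vmapped(params, xs, ys):
    # vmap: a single batched gradient computation, independent of batch size.
    return jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
```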
I think it may be a JIT/stax bug, as stax.MaxPool seems to be the culprit here. When I remove the maxpool layers, the compilation times for per-example gradients stop scaling linearly with the batch size. Here is a minimal example of this.
UPDATE: stax.AvgPool does not appear to have this issue, only stax.MaxPool does. Here's a colab showing that AvgPool compilation times don't double as batch sizes double.
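For completeness, a sketch of what such a minimal repro might look like (the linked colab is the authoritative version; this is an assumed reconstruction using the 2020-era `jax.experimental.stax` API):

```python
import time
import jax
import jax.numpy as jnp
from jax.experimental import stax  # later moved to jax.example_libraries.stax

init_fn, apply_fn = stax.serial(
    stax.Conv(16, (3, 3)), stax.Relu,
    stax.MaxPool((2, 2)),  # swapping in stax.AvgPool((2, 2)) avoids the scaling
    stax.Flatten, stax.Dense(10), stax.LogSoftmax)

def loss(params, x, y):
    # x is a single example (H, W, C); add a unit batch axis for the conv net.
    logits = apply_fn(params, x[None])
    return -jnp.sum(logits * y)

# Per-example gradients via vmap, jit-compiled.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

rng = jax.random.PRNGKey(0)
_, params = init_fn(rng, (-1, 28, 28, 1))
for batch_size in (64, 128, 256, 512):
    xs = jnp.ones((batch_size, 28, 28, 1))
    ys = jnp.ones((batch_size, 10))
    start = time.time()
    # The first call per input shape includes compilation time.
    jax.tree_util.tree_map(lambda g: g.block_until_ready(),
                           per_example_grads(params, xs, ys))
    print(batch_size, time.time() - start)
```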
This was already fixed at head by https://github.com/google/jax/pull/4439
Note that the PR is too new to be in a JAX release, so try doing something like:
!pip install --upgrade git+https://github.com/google/jax
in your Colab.
I get the following timings at head in a Colab:
