JAX: Differentially Private SGD Example Compilation Is Very Slow

Created on 3 Oct 2020 · 3 comments · Source: google/jax

The differentially private SGD example takes a very long time to JIT compile at larger batch sizes, and the compilation time grows sharply beyond a batch size of 512. Is there any way to fix or work around this?

The table below lists some JIT compile times; the linked Colab has more timings.

| Batch Size | Compilation Time (s) |
|-----------:|---------------------:|
| 64         | 21.71                |
| 128        | 43.99                |
| 256        | 63.85                |
| 512        | 107.62               |
| 1024       | 300.03               |
| 2048       | 592.16               |

The DP SGD example in the repo currently fails with errors, so the Colab includes the fixes from #4446.
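A common way to get numbers like these is to time the first call of a jitted function (tracing, compilation, and execution) against a second call that reuses the cached executable. The sketch below assumes a hypothetical one-layer `loss` as a stand-in for the DP SGD model; it is not the example's code, and the difference between the two timings is only an approximation of compile time.

```python
import time

import jax
import jax.numpy as jnp


def time_first_and_second_call(fn, *args):
    """Time the first call (trace + compile + run) and a second, cached call."""
    jitted = jax.jit(fn)
    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # includes tracing and XLA compilation
    first = time.perf_counter() - start
    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # reuses the compiled executable
    second = time.perf_counter() - start
    return first, second


def loss(w, x):
    # Hypothetical stand-in model: one linear layer with a squared-output loss.
    return jnp.sum((x @ w) ** 2)


w = jnp.ones((32, 10))
for batch_size in (64, 128, 256):
    x = jnp.ones((batch_size, 32))
    first, second = time_first_and_second_call(jax.grad(loss), w, x)
    # The difference approximates tracing plus compilation time.
    print(batch_size, f"compile ~ {first - second:.3f}s")
```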

EDIT: In general, should we expect compilation time to scale with batch size? If so, why?

bug


Thanks for the question!

No, we shouldn't expect compilation time to scale with batch size. If it does, it could mean that the example unrolls a computation over the array's elements somewhere, so I suspect this is a bug in the example file.
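To illustrate what "unrolling" means here: a Python loop over the batch is traced once per example, so the program handed to XLA grows with the batch size, while `jax.vmap` produces a program whose size does not depend on it. The sketch below counts primitive operations in the traced jaxpr; the model and helper names (`per_example_loss`, `num_eqns`) are hypothetical and not taken from the DP SGD example.

```python
import jax
import jax.numpy as jnp


def per_example_loss(w, x):
    # Loss for a single example x of shape (features,).
    return jnp.sum(jnp.tanh(x @ w) ** 2)


def unrolled_grads(w, xs):
    # Python loop over the batch: traced once per example, so the
    # program grows linearly with the batch size.
    return jnp.stack([jax.grad(per_example_loss)(w, x) for x in xs])


def vmapped_grads(w, xs):
    # vmap batches the per-example gradient, so the traced program
    # has the same number of operations for any batch size.
    return jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(w, xs)


def num_eqns(fn, *args):
    # Number of top-level primitive operations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


w = jnp.ones((4, 3))
for batch_size in (2, 4, 8):
    xs = jnp.ones((batch_size, 4))
    print(batch_size, num_eqns(unrolled_grads, w, xs), num_eqns(vmapped_grads, w, xs))
```

Both functions compute the same per-example gradients; only the size of the compiled program differs.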

I think it may be a JIT/stax bug instead, since stax.MaxPool seems to be the culprit. When I remove the MaxPool layers, the compilation time for per-example gradients stops growing linearly with the batch size. Here is a minimal example of this.

UPDATE: stax.AvgPool does not appear to have this issue; only stax.MaxPool does. Here's a Colab showing that AvgPool compilation times don't double as the batch size doubles.
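A comparison of this kind can be sketched as follows: build a small CNN with either pooling layer, compute per-example gradients with `vmap(grad(...))` as the DP SGD example does, and time only the compilation step. This is a minimal sketch, not the linked Colab: the model, the 28×28×1 input shape, and the helper names are hypothetical. It assumes a recent JAX, where stax lives in `jax.example_libraries.stax` (it was `jax.experimental.stax` at the time of this issue) and the ahead-of-time `.lower(...).compile()` API is available.

```python
import time

import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # jax.experimental.stax in 2020-era JAX


def make_model(pool):
    # Hypothetical small CNN; `pool` is stax.MaxPool or stax.AvgPool.
    return stax.serial(
        stax.Conv(8, (3, 3), padding="SAME"), stax.Relu,
        pool((2, 2), strides=(2, 2)),
        stax.Flatten, stax.Dense(10), stax.LogSoftmax,
    )


def per_example_grads(apply_fn, params, images, labels):
    def loss(p, image, label):
        # stax layers expect a batch axis, so add one for the single example.
        logits = apply_fn(p, image[None])
        return -jnp.sum(logits * label[None])
    return jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, images, labels)


def compile_seconds(pool, batch_size):
    init_fn, apply_fn = make_model(pool)
    _, params = init_fn(jax.random.PRNGKey(0), (-1, 28, 28, 1))
    images = jnp.zeros((batch_size, 28, 28, 1))
    labels = jnp.zeros((batch_size, 10))
    fn = jax.jit(lambda p, x, y: per_example_grads(apply_fn, p, x, y))
    start = time.perf_counter()
    fn.lower(params, images, labels).compile()  # trace + compile, no execution
    return time.perf_counter() - start


for batch_size in (32, 64, 128):
    print(batch_size,
          f"MaxPool {compile_seconds(stax.MaxPool, batch_size):.2f}s",
          f"AvgPool {compile_seconds(stax.AvgPool, batch_size):.2f}s")
```

On a JAX version that includes the fix, neither column should grow roughly in proportion to the batch size.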

This was already fixed at head by https://github.com/google/jax/pull/4439

Note that the PR is too new to be in a JAX release, so try doing something like:

!pip install --upgrade git+https://github.com/google/jax

in your Colab.

I get the following timings at head in a Colab (screenshot of the timings not preserved).
