Jax: Extremely slow runtime when using jacfwd()

Created on 24 Feb 2020 · 2 comments · Source: google/jax

I recently switched from autograd to JAX because JAX supports jit, which I hoped would make my code faster. However, comparing autograd's jacobian(fun) against JAX's jit(jacfwd(fun)), the results are disappointing. I haven't done careful timings yet, but with jacobian(fun) my program produces results in about 5 minutes, whereas with jit(jacfwd(fun)) it looks like it would take over an hour (I stopped the program after about 15 minutes when I saw it wasn't making progress).

I feel like I must be doing something wrong. Has anyone else experienced such a large runtime difference between autograd's jacobian(fun) and JAX's jit(jacfwd(fun))?
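When timing a comparison like this, note that the first call to a jit-compiled function includes tracing and XLA compilation, and JAX dispatches work asynchronously, so a naive wall-clock measurement can mix compile time with run time. Below is a minimal benchmarking sketch that separates the two. The function `fun` and the helper `time_jacobian` are hypothetical stand-ins, since the original code was not posted. Because autograd's jacobian is built on reverse-mode VJPs, the sketch also times jax.jacrev, which is arguably the closer analogue to it.

```python
import time

import jax
import jax.numpy as jnp


def fun(x):
    # Hypothetical stand-in for the user's function (not from the original issue).
    return jnp.sin(x) * jnp.sum(x ** 2)


def time_jacobian(jac_transform, x, n_runs=10):
    """Return (first_call_seconds, mean_later_call_seconds) for jit(jac_transform(fun))."""
    jac = jax.jit(jac_transform(fun))

    # First call: tracing + XLA compilation + execution.
    # block_until_ready() waits for the asynchronously dispatched result.
    t0 = time.perf_counter()
    jac(x).block_until_ready()
    first = time.perf_counter() - t0

    # Later calls with the same shape/dtype reuse the cached compiled executable.
    t0 = time.perf_counter()
    for _ in range(n_runs):
        jac(x).block_until_ready()
    later = (time.perf_counter() - t0) / n_runs
    return first, later


x = jnp.linspace(0.0, 1.0, 100)
for name, transform in [("jacfwd", jax.jacfwd), ("jacrev", jax.jacrev)]:
    first, later = time_jacobian(transform, x)
    print(f"{name}: first call {first:.3f} s, later calls {later * 1e3:.3f} ms")
```

If the first-call time dominates while later calls are fast, the slowdown is compilation rather than execution, and jit only pays off when the compiled function is called repeatedly.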

Labels: performance, question

All 2 comments

Can you provide a small runnable reproduction? It's impossible to speculate about performance without something concrete we can benchmark ourselves.

It's gotta be all compilation time (and something's getting unrolled). +1 to a repro!
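To illustrate what "something's getting unrolled" can mean: a Python for loop inside a traced function runs at trace time, so jit and jacfwd see one copy of the loop body per iteration, and compilation time grows with the trip count. A structured control-flow primitive such as jax.lax.fori_loop (or jax.lax.scan) keeps the loop as a single operation. The sketch below is an assumption about the cause, since the issue's code was never shown; `step`, `unrolled`, and `rolled` are hypothetical names. It compares the size of the traced program for each version using jax.make_jaxpr.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # One hypothetical iteration of some update rule.
    return jnp.tanh(x) + 0.1 * x


def unrolled(x, n=100):
    # Python loop: executed during tracing, so the traced program contains n copies of step.
    for _ in range(n):
        x = step(x)
    return x


def rolled(x, n=100):
    # lax.fori_loop: the loop stays a single primitive in the traced program.
    return lax.fori_loop(0, n, lambda i, x: step(x), x)


x = jnp.ones(3)
# Count top-level equations in the traced Jacobian computation for each version.
n_unrolled = len(jax.make_jaxpr(jax.jacfwd(unrolled))(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(jax.jacfwd(rolled))(x).jaxpr.eqns)
print("top-level equations, Python loop:", n_unrolled)
print("top-level equations, fori_loop:  ", n_rolled)
```

Both versions compute the same Jacobian, but the unrolled program hands XLA a far larger graph to compile; that difference grows with the number of iterations.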
