I recently switched from autograd to jax because jax supports jit, which will (hopefully) make my code run faster. However, when I compare autograd's jacobian(fun) against jax's jit(jacfwd(fun)), the results are disappointing. I haven't done precise timings yet, but with jacobian(fun) my program finishes in about 5 minutes, whereas with jit(jacfwd(fun)) it looks like it would take over an hour (I stopped it after about 15 minutes when it clearly wasn't making progress).
I feel like I must be doing something wrong. Has anyone else experienced such large runtime differences between autograd's jacobian(fun) and jax's jit(jacfwd(fun))?
Can you provide a small runnable reproduction? It's impossible to speculate about performance without something concrete we can benchmark ourselves.
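A minimal sketch of the kind of benchmark that makes a repro useful. The function `fun` and the input size are hypothetical stand-ins, not the asker's code. The key idea is to build the jitted Jacobian once, then time the first call (trace + compile + run) separately from later calls (run only), using `block_until_ready()` so JAX's asynchronous dispatch doesn't hide the real cost.

```python
import time

import jax
import jax.numpy as jnp


def fun(x):
    # Hypothetical stand-in for the asker's function (not from the original post).
    return jnp.sin(x) * jnp.sum(x ** 2)


# Build the jitted Jacobian once and reuse it; re-creating jit(jacfwd(fun))
# inside a loop can trigger a fresh trace (and possibly a compile) every time.
jac_fn = jax.jit(jax.jacfwd(fun))

x = jnp.linspace(0.0, 1.0, 100)

# First call: includes tracing and XLA compilation.
t0 = time.perf_counter()
J = jac_fn(x).block_until_ready()
first_call = time.perf_counter() - t0

# Later calls reuse the compiled executable, so they measure runtime only.
t0 = time.perf_counter()
for _ in range(10):
    J = jac_fn(x).block_until_ready()
steady_state = (time.perf_counter() - t0) / 10

print(f"first call (compile + run): {first_call:.4f} s")
print(f"steady-state per call:      {steady_state:.6f} s")
```

If the first number dominates and the steady-state number is small, the slowdown is compilation rather than execution.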
It's gotta be all compilation time (and something's getting unrolled). +1 to a repro!
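One common way "something's getting unrolled" happens, sketched here as an assumption since we haven't seen the asker's code: a Python `for` loop inside a jitted function runs at trace time, so XLA receives one copy of the loop body per iteration and compile time grows with the iteration count. `jax.lax.fori_loop` stages out a single loop instead. The `step` function and `n=200` are hypothetical; the sketch just compares how many equations each version produces in its jaxpr.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical per-iteration update, chosen only for illustration.
    return jnp.sin(x) + 0.1 * x


def unrolled(x, n=200):
    # The Python loop executes while tracing, so the jaxpr holds n copies of step().
    for _ in range(n):
        x = step(x)
    return x


def rolled(x, n=200):
    # lax.fori_loop emits a single loop primitive, so the traced program stays small.
    return lax.fori_loop(0, n, lambda i, carry: step(carry), x)


x = jnp.ones(3)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x).jaxpr.eqns)
print("jaxpr equations, unrolled vs rolled:", n_unrolled, n_rolled)

# Both versions compute the same Jacobian; only the size of what gets compiled differs.
J_unrolled = jax.jit(jax.jacfwd(unrolled))(x)
J_rolled = jax.jit(jax.jacfwd(rolled))(x)
```

The same growth applies when the unrolled function is wrapped in `jacfwd`, since the derivative is traced through every copy of the loop body.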