Hello!
I was wondering why, on GPU, there is such a large timing difference (more than two orders of magnitude in the numbers below) between jit(slogdet) and the plain slogdet expression.
From the source code, it seems that slogdet is already decorated with jit.
Therefore, I would have expected no difference. However, this is what I get:
import jax.numpy as np
from jax import random, jit
key = random.PRNGKey(0)
a = random.normal(key, (2, 2))
jit_slogdet = jit(np.linalg.slogdet)
slogdet = np.linalg.slogdet
# initial run
jit_slogdet(a)
slogdet(a)
# timeit
%timeit -n 10 jit_slogdet(a)[1].block_until_ready()
%timeit -n 10 slogdet(a)[1].block_until_ready()
Output
148 µs ± 16.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
68.6 ms ± 437 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
What am I missing?
jax-version: 0.1.55
Cheers
Christian
It looks like using custom_transforms in addition to jit is slowing down raw slogdet. @mattjj, is this an issue you're already aware of?
I wasn't aware of this, but it sounds plausible. We need to rewrite custom_transforms, and this will go on the list of fixes. I might not have time until after the holidays, though.
Thanks for raising this!
Yeah, it looks like custom_transforms needs to be inside a jit to avoid recompilation. With slogdet, the custom_transforms decorator is outside jit, so the inner jit is re-traced into a new xla_call in a new jaxpr each time slogdet is called, and that new xla_call becomes a new XLA compilation.
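custom_transforms has since been removed from JAX, so the sketch below does not reproduce this exact bug. It only illustrates the general mechanism in recent JAX: jit caches traces per function object, so handing it a fresh wrapper on every call forces a re-trace (and recompilation) each time. The names `f`, `trace_count`, `reused_traces`, and `fresh_traces` are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def f(x):
    # Python side effects run only while JAX traces f, not on cached calls.
    global trace_count
    trace_count += 1
    return jnp.sin(x) * 2.0

x = jnp.ones(3)

# One jit object reused across calls: traced (and compiled) once.
jitted_f = jax.jit(f)
for _ in range(5):
    jitted_f(x)
reused_traces = trace_count

# A fresh wrapper function on every call defeats jit's cache, so every
# call re-traces and recompiles (analogous to the per-call re-trace above).
trace_count = 0
for _ in range(5):
    jax.jit(lambda y: f(y))(x)
fresh_traces = trace_count

print("reused jit traces:", reused_traces)
print("fresh-wrapper traces:", fresh_traces)
```

The counter works because Python code inside a jitted function only executes during tracing; a cache hit skips it entirely.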
For now, it's probably best to use an outer jit as you do in your code. You can also use
from jax import config
config.FLAGS.jax_log_compiles=True
to see when JAX is unexpectedly compiling things.
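In recent JAX releases the same flag is set with `jax.config.update` rather than through `config.FLAGS`. The sketch below, assuming a current JAX install, combines that logging with the outer-jit workaround: the first call should log one compilation, and the second call with the same shape and dtype should reuse it.

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Log each XLA compilation (newer spelling of config.FLAGS.jax_log_compiles).
jax.config.update("jax_log_compiles", True)

# Workaround from this thread: wrap slogdet in an outer jit so repeated
# calls hit one cached compilation instead of re-tracing.
jit_slogdet = jax.jit(jnp.linalg.slogdet)

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2, 2))

sign, logdet = jit_slogdet(a)    # first call: trace + compile (logged)
sign2, logdet2 = jit_slogdet(a)  # same shape/dtype: cache hit, no new compile

# Cross-check against NumPy's slogdet on the same matrix.
ref_sign, ref_logdet = onp.linalg.slogdet(onp.asarray(a))
print(sign, logdet, ref_sign, ref_logdet)
```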
My guess is that this eval_jaxpr in the custom_transforms logic is getting cache misses every time, for the reason outlined in #1829. However, I had to roll back #1829 because of one internal test failure (not a JAX test) that I didn't understand.
My long-promised still-vaporware rewrite of custom_transforms would attempt to avoid eval_jaxpr, so two fix options for this issue are (1) try to roll-forward #1829, (2) just wait for a custom_transforms rewrite that fixes all problems and brings about world peace.
I think this is fixed now that the custom_transforms rewrite has landed!
On a V100 GPU, I get:
10 loops, best of 3: 182 µs per loop
10 loops, best of 3: 289 µs per loop
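To repeat this comparison outside IPython (where `%timeit` is unavailable), a sketch with the standard `timeit` module follows. It assumes a recent JAX; the helper name `bench` and its `number`/`repeat` defaults are hypothetical choices that mirror the "10 loops, best of 3" settings above. Absolute numbers will depend on the device.

```python
import timeit

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2, 2))

jit_slogdet = jax.jit(jnp.linalg.slogdet)
slogdet = jnp.linalg.slogdet

def bench(fn, number=10, repeat=3):
    # Warm-up call so compilation time is excluded from the measurement.
    fn(a)[1].block_until_ready()
    # block_until_ready() waits for async dispatch to finish each call.
    times = timeit.repeat(lambda: fn(a)[1].block_until_ready(),
                          number=number, repeat=repeat)
    return min(times) / number  # best per-call time in seconds

jit_time = bench(jit_slogdet)
plain_time = bench(slogdet)
print(f"jit(slogdet): {jit_time * 1e6:.1f} us per call")
print(f"slogdet:      {plain_time * 1e6:.1f} us per call")
```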