Jax: slogdet timing differences

Created on 23 Dec 2019 · 6 Comments · Source: google/jax

Hello!

I was wondering why, on GPU, there is a timing difference of more than two orders of magnitude between jit(slogdet) and the plain slogdet expression.

From the source code, it seems that slogdet is already decorated with jit.

Therefore, I would have expected no difference. However, this is what I get:

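```python
import jax.numpy as np
from jax import random, jit

key = random.PRNGKey(0)
a = random.normal(key, (2, 2))

jit_slogdet = jit(np.linalg.slogdet)  # explicitly wrapped in jit
slogdet = np.linalg.slogdet           # called as-is

# initial run (warm-up, so compilation is excluded from the timings)
jit_slogdet(a)[1].block_until_ready()
slogdet(a)[1].block_until_ready()

# timeit (IPython magic); block_until_ready waits for the asynchronous GPU result
%timeit -n 10 jit_slogdet(a)[1].block_until_ready()
%timeit -n 10 slogdet(a)[1].block_until_ready()
```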
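For readers running outside IPython, the two %timeit lines can be approximated with plain Python. This is a minimal sketch assuming a current JAX release (hence the jnp alias); time_call is a hypothetical helper, and it reports the best per-call time over several runs rather than the mean ± std. dev. that %timeit prints.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, x, n=10, repeats=7):
    """Hypothetical helper: best per-call wall time (seconds) over `repeats` runs of `n` calls."""
    fn(x)[1].block_until_ready()  # warm-up so compilation is not timed
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(n):
            fn(x)[1].block_until_ready()  # wait for the asynchronous result
        best = min(best, (time.perf_counter() - start) / n)
    return best


a = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
jit_time = time_call(jax.jit(jnp.linalg.slogdet), a)
plain_time = time_call(jnp.linalg.slogdet, a)
print(f"jit: {jit_time * 1e6:.1f} us, plain: {plain_time * 1e6:.1f} us")
```

The warm-up call matters: without it, the first timed loop would include compilation.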

Output

148 µs ± 16.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
68.6 ms ± 437 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

What am I missing?

jax-version: 0.1.55

Cheers
Christian

bug

All 6 comments

It looks like using custom_transforms in addition to jit is slowing down raw slogdet. @mattjj, is this an issue you're already aware of?

I wasn't aware of this, but it sounds plausible. We need to rewrite custom_transforms, and this will go on the list of fixes. I might not have time until after the holidays, though.

Thanks for raising this!

Yeah, it looks like custom_transforms needs to be inside a jit to avoid recompilation. With slogdet, the custom_transforms decorator is outside jit, so the inner jit is re-traced into a new xla_call in a new jaxpr each time slogdet is called, and that new xla_call becomes a new XLA compilation.
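A minimal sketch of the retracing effect described above, assuming a current JAX release. custom_transforms no longer exists (it was replaced by jax.custom_jvp and jax.custom_vjp), so this only mimics the problem: uncached builds a fresh jit wrapper on every call, so jit's cache, which is keyed on the wrapped function, misses every time and the function is re-traced. The names body, cached, uncached and trace_count are hypothetical; the Python counter increments only while JAX traces body, not when a compiled version runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def body(x):
    global trace_count
    trace_count += 1  # Python side effects run only during tracing
    return jnp.linalg.slogdet(x)


# One jitted callable, created once and reused: traced on the first call only.
cached = jax.jit(body)


def uncached(x):
    # A new lambda (new function identity) each call, so the jit cache misses.
    return jax.jit(lambda y: body(y))(x)


a = jnp.eye(2)

trace_count = 0
for _ in range(3):
    cached(a)
cached_traces = trace_count

trace_count = 0
for _ in range(3):
    uncached(a)
uncached_traces = trace_count

print(cached_traces, uncached_traces)  # expected: 1 3
```

This is why hoisting the jit out of the per-call path, as in the original question's jit_slogdet, avoids the repeated work.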

For now, it's probably best to wrap it in another jit, as you do in your code. You can also use

```python
from jax import config
config.FLAGS.jax_log_compiles = True
```

to see when JAX is unexpectedly compiling things.
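In newer JAX releases the config.FLAGS spelling shown above may not be available; the sketch below assumes a current release and uses jax.config.update to set the same jax_log_compiles option, combined with the suggested workaround of jitting slogdet once and reusing the wrapper. Compilation messages go to Python logging; their exact wording is not shown here because it varies between versions.

```python
import jax
import jax.numpy as jnp

# Current spelling of the compile-logging option from the comment above.
jax.config.update("jax_log_compiles", True)

# Workaround from the thread: create the jitted wrapper once and reuse it.
jit_slogdet = jax.jit(jnp.linalg.slogdet)

a = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
for _ in range(3):
    # Same shape and dtype each time, so a compile should be logged only on the first call.
    sign, logabsdet = jit_slogdet(a)

logabsdet.block_until_ready()
print(sign, logabsdet)
```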

My guess is that the eval_jaxpr call in the custom_transforms logic is getting cache misses every time, for the reason outlined in #1829. However, I had to roll back #1829 because of one internal test failure (not a JAX test) that I didn't understand.

My long-promised, still-vaporware rewrite of custom_transforms would attempt to avoid eval_jaxpr, so the two fix options for this issue are (1) try to roll forward #1829, or (2) just wait for a custom_transforms rewrite that fixes all problems and brings about world peace.
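To look at the jaxpr that a call like slogdet stages out (the kind of object eval_jaxpr evaluates, and where the nested jit call mentioned earlier shows up), jax.make_jaxpr can be used. This is a sketch assuming a current JAX release; it does not reproduce the old custom_transforms code path, and the printed equations differ between versions, so no exact output is claimed.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (2, 2))

# Trace slogdet without executing it; the result is a ClosedJaxpr.
closed = jax.make_jaxpr(jnp.linalg.slogdet)(a)

# The top-level equations show how the call was staged
# (in recent versions, typically a single nested jit call).
print(closed)
print(len(closed.jaxpr.eqns), closed.out_avals)
```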

I think this is fixed now that the custom_transforms rewrite has landed!

On a V100 GPU, I get:

10 loops, best of 3: 182 µs per loop
10 loops, best of 3: 289 µs per loop