Jax: Regression in performance of samplers

Created on 10 May 2019 · 7 Comments · Source: google/jax

With jax versions > 0.1.24, random samplers are slow. I think this is due to recent changes in how jit handles static args/kwargs. A script to reproduce:

import time
from jax import random

t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)
t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)

which prints 0.12923526763916016 and 0.12221717834472656.

However, if we wrap the sampler in a function, the second call is fast. For example:

def f():
    return random.normal(random.PRNGKey(0), shape=(1,))

t = time.time()
f()
print(time.time() - t)
t = time.time()
f()
print(time.time() - t)

prints 0.12787413597106934 and 0.0010831356048583984.

I think there is a small bug somewhere that forces the sampler to recompile. If this is expected behaviour, how should we wrap these samplers so that they are not recompiled?
cc @neerajprad

bug


All 7 comments

Thanks so much for catching this! We will fix it asap.

(Longer term we aim to set up a continuous benchmarking solution to avoid regressions.)

This was an issue with all uses of static_argnums, though it only cropped up recently in random.py because of how we changed random.py following the kwargs-handling change.

#692 should fix it.

Using static_argnums only produced compilation cache hits based on object identity. That meant random.py functions cached on the identity of their shape parameters (x is y) rather than on equality (x == y). In your first code example, you passed in one (1,) literal the first time and a fresh (1,) literal the second time; those two objects compare equal but are not the same object, so the second call missed the cache. In your second code example, the same literal inside f was used on both calls, so the cache was hit on object identity.
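A minimal pure-Python sketch of the two cache-key policies described above. IdentityCache, EqualityCache, fresh_shape, and constant_shape are hypothetical toys, not JAX internals, and each "compile" just increments a counter. The constant_shape case relies on CPython storing the literal (1,) as a code-object constant, so every call returns the same tuple object.

```python
class IdentityCache:
    """Toy compile cache that only hits when the same object is passed again."""

    def __init__(self):
        self._entries = []  # (key_object, compiled) pairs; also keeps keys alive
        self.compiles = 0

    def get(self, shape):
        for key, compiled in self._entries:
            if key is shape:  # identity check: x is y
                return compiled
        self.compiles += 1  # cache miss: "recompile"
        compiled = f"compiled for shape={shape}"
        self._entries.append((shape, compiled))
        return compiled


class EqualityCache:
    """Toy compile cache that hits whenever an equal, hashable key is passed."""

    def __init__(self):
        self._entries = {}  # dict lookup uses __hash__ and __eq__: x == y
        self.compiles = 0

    def get(self, shape):
        if shape not in self._entries:
            self.compiles += 1  # cache miss: "recompile"
            self._entries[shape] = f"compiled for shape={shape}"
        return self._entries[shape]


def fresh_shape():
    # Builds a new tuple object on every call, like writing a fresh (1,) literal.
    return tuple([1])


def constant_shape():
    # In CPython, (1,) here is a code-object constant: same object on every call.
    return (1,)


a, b = fresh_shape(), fresh_shape()
print(a == b, a is b)  # True False

ident = IdentityCache()
ident.get(fresh_shape())
ident.get(fresh_shape())
print(ident.compiles)  # 2: equal but distinct objects miss the cache

ident2 = IdentityCache()
ident2.get(constant_shape())
ident2.get(constant_shape())
print(ident2.compiles)  # 1 on CPython: the same object is passed twice

eq = EqualityCache()
eq.get(fresh_shape())
eq.get(fresh_shape())
print(eq.compiles)  # 1: equal tuples share one cache entry
```

The IdentityCache runs mirror the two timings in the report: the repeated fresh literal recompiles, while the function returning its constant literal does not.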

The fix in #692 is simply to make static_argnums use equality checks when possible (i.e. when the objects passed as static args have __hash__ and, by implication, __eq__), and fall back to object identity when not. This should still allow arbitrary objects, even unhashable ones, to be passed as static args, while also giving cache hits where appropriate, especially for shape args like those in random.py.

Does that make sense? What do you think?

I can't thank you enough for spotting this and providing such a clear repro. Your sleuthing made this an easy fix, and without it we would have missed it for who knows how long!

Can you verify the fix on your end, and let me know what you see?

(I just pushed 0.1.28 out to PyPI with this fix. Hopefully it actually worked!)

Whoa, it is so fast. Thank you so much for the very clear explanation! On my end, the samplers are fast again. :)

Thanks for such a quick resolution and for cutting a new release, @mattjj! :)

JAX bugs fixed in 30 minutes or less, or your money* back!

*JAX costs $0

