Jax: polyval is exceedingly slow to compile for high degree polynomials

Created on 12 May 2020 · 8 comments · Source: google/jax

In my latest attempt at Twitter Driven Development (TM), I wanted to see if JAX could be as fast as Julia at evaluating a 10,000-degree polynomial.

Unfortunately, the implementation of polyval uses a Python for loop, so I can't even run this example; it takes forever and then crashes:

import jax.numpy as jnp
import jax
import numpy as np

x = np.random.rand()
p = np.random.randn(10000)
jnp.polyval(p, x)  # works
jax.jit(jnp.polyval)(p, x)  # never finishes

There are probably more sophisticated algorithms (like whatever @sethaxen was using here) but a simple improvement would be to switch from a Python for loop to lax.scan.
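For concreteness, here is a minimal sketch of what a Horner's-rule polyval built on lax.scan could look like (coefficients ordered highest degree first, as in numpy.polyval). The polyval_scan name and the exact formulation are only illustrative, not the actual fix:

import jax.numpy as jnp
from jax import lax

def polyval_scan(p, x):
  # Horner's rule: fold in one coefficient per scan step.
  def horner_step(acc, coeff):
    return acc * x + coeff, None
  init = jnp.zeros_like(jnp.asarray(x) * p[0])
  result, _ = lax.scan(horner_step, init, p)
  return result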

performance

All 8 comments

Twitter-driven development is the best development. I was also curious how JAX would do here, but I don't have it on this machine.

The algorithm I'm using is similar to that in Section 2.3.4 of Giles' paper. It's slightly modified to support complex or real x and p, in matrix/matrix, scalar/scalar, or mixed scalar/matrix combinations. The WIP PR with the implementation is here: https://github.com/JuliaDiff/ChainRules.jl/pull/190.

Although I think you need to switch your x and p to match my function call.

@sethaxen thanks for sharing the details!

Those derivative rules from Giles look exactly the same as the standard gradient rules for the composite operations in JAX, so I think we could do just as well simply by switching to lax.scan and using JAX's standard autodiff rules.

Exactly, yeah, there aren't any efficiency shortcuts there. If JAX is sufficiently fast at differentiating the polynomial-evaluation implementation, then there's no benefit from a custom rule (with Zygote, at least, it benchmarked pretty badly for some reason, hence the PR).
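For example, assuming jnp.polyval is rewritten in terms of lax.scan as suggested above, JAX's ordinary reverse-mode machinery differentiates straight through it with no hand-written adjoint (a quick illustrative check, not a benchmark):

import jax
import jax.numpy as jnp
import numpy as np

p = jnp.asarray(np.random.randn(100))
x = 0.5
# Derivative of the polynomial with respect to x via standard reverse mode.
dpoly_dx = jax.grad(jnp.polyval, argnums=1)(p, x)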

Although I think you need to switch your x and p to match my function call.

It looks like the arguments for NumPy's polyval are flipped compared to Julia's evalpoly: https://numpy.org/doc/1.18/reference/generated/numpy.polyval.html

It looks like the arguments for NumPy's polyval are flipped compared to Julia's evalpoly: https://numpy.org/doc/1.18/reference/generated/numpy.polyval.html

Oh you're right. I was looking at np.polynomial.polynomial.polyval, which is the same as Julia's.
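For reference, a quick illustration of the two conventions in plain NumPy (note that both the coefficient ordering and the argument order differ):

import numpy as np

coeffs = [2.0, 3.0, 4.0]
x = 0.5
np.polyval(coeffs, x)                        # 2*x**2 + 3*x + 4, highest degree first
np.polynomial.polynomial.polyval(x, coeffs)  # 2 + 3*x + 4*x**2, lowest degree first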

https://github.com/google/jax/pull/3076 switches from the for loop to lax.scan, which resolves the slow compilation.

It turns out this is a little hard to benchmark in JAX, because 22 µs is less than the overhead of calling pre-compiled code (~200 µs).

Instead, we can use lax.map to repeat the computation 100 times, on different values of x, in a loop:

import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from jax.config import config

config.update("jax_enable_x64", True)

@jax.jit
def bench(p, x):
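  # For each scalar x in the batch, compute the gradient of jnp.polyval with respect to the coefficients p.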
  return jax.lax.map(partial(jax.grad(jnp.polyval), p), x)

x = jnp.asarray(np.random.randn(100))
p = jnp.asarray(np.random.randn(10000))
bench(p, x).block_until_ready()
%timeit bench(p, x).block_until_ready()
# 3.11 ms ± 326 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

31 µs on average (3.11 ms / 100 evaluations) for the forward+backward computation is in the same ballpark as Julia's 36 µs. It's a nice win that we get this from idiomatic JAX, without needing to write a special backward pass.

I see now that Seth ran this on a 2.4 GHz Intel Core i5:
https://twitter.com/sethaxen/status/1260342201306898432

My benchmark is on a several-year-old 2.8 GHz Intel Core i7.

I'm sure there's a more nuanced way to do this comparison, but that 17% difference in clock speed (2.8 / 2.4 ≈ 1.17) happens to match the difference in performance (36 µs / 31 µs ≈ 1.16) almost exactly. It seems plausible that XLA CPU and Julia are generating essentially the same machine instructions via LLVM.
