Here is a simple example that numerically integrates the product of two Gaussian pdfs. One of the Gaussians is fixed, with its mean always at 0; the other varies in its mean:
import time
import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf
# set up evaluation points for numerical integration
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution
# integrate with new mean
def integrate(mu_new):
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)
    return total_proba
print('starting jit')
start = time.perf_counter()
integrate = jit(integrate)
integrate(1)
stop = time.perf_counter()
print('took: ', stop - start)
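As a quick sanity check on the result (an aside, not part of the original timing question): the product of two unit-variance Gaussian pdfs integrates in closed form to a variance-2 Gaussian pdf evaluated at the mean offset, so, continuing the session above, the numerical answer can be compared against the exact value:

# closed-form reference (sanity check, not in the original code):
# the integral of N(x; 0, 1) * N(x; mu, 1) over x equals N(mu; 0, 2)
expected = pdf(1.0, loc=0.0, scale=np.sqrt(2.0))
print('numerical:', integrate(1), 'closed form:', expected)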
The function looks simple enough, but the time to jit and run it doesn't scale at all:
integr_resolution | seconds to execute
------------ | -------------
100 | 0.107
200 | 0.23
400 | 0.537
800 | 1.52
1600 | 5.2
3200 | 19
6400 | 134
For reference, the unjitted function applied to integr_resolution=6400 takes 0.02s.
I thought this might be related to the fact that the function accesses global variables, but moving the code that sets up the integration points inside the function has no notable influence on the timing. The following code takes 5.36s to run; it corresponds to the table entry for 1600, which previously took 5.2s:
# integrate with new mean
def integrate(mu_new):
    # set up evaluation points for numerical integration
    integr_resolution = 1600
    lower_bound = -100
    upper_bound = 100
    integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
    proba = pdf(integr_grid)
    integration_weight = (upper_bound - lower_bound) / integr_resolution
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)
    return total_proba
What is happening here?
It's because the code says sum where it should say np.sum.
sum is a Python built-in that extracts each element of a sequence and adds them one by one using the + operator. This has the effect of building a large, unrolled chain of adds, which XLA takes a long time to compile. (To be honest, I'm kind of amazed this worked at all!)
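Conceptually, under tracing the built-in sum behaves like this hand-unrolled loop (a sketch of the effect, not Python's actual implementation):

values = proba * proba_new * integration_weight
total_proba = 0.0
for i in range(integr_resolution):
    # each iteration appends a slice/reshape/add triple to the traced program
    total_proba = total_proba + values[i]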
If you use np.sum, then JAX builds a single XLA reduction operator, which is much faster to compile.
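Concretely, the fix is a one-word change. A sketch of the corrected function, which can be jitted and timed exactly as above:

def integrate(mu_new):
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    # np.sum traces to a single XLA reduce_sum instead of an unrolled add chain
    total_proba = np.sum(proba * proba_new * integration_weight)
    return total_proba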
Does this resolve the question? I'm not sure what we could do better here, although I admit it's a bit surprising!
And just to show how I figured this out: I used jax.make_jaxpr, which dumps JAX's internal trace representation of a function. Here, it shows:
In [3]: import jax
In [4]: jax.make_jaxpr(integrate)(1)
Out[4]:
{ lambda b c ; ; a.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = sub c d
f = sub e 0.0
g = pow f 2.0
h = div g 1.0
i = add 1.8378770351409912 h
j = neg i
k = div j 2.0
l = exp k
m = mul b l
n = mul m 2.0
o = slice[ start_indices=(0,)
limit_indices=(1,)
strides=(1,)
operand_shape=(100,) ] n
p = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] o
q = add p 0.0
r = slice[ start_indices=(1,)
limit_indices=(2,)
strides=(1,)
operand_shape=(100,) ] n
s = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] r
t = add q s
u = slice[ start_indices=(2,)
limit_indices=(3,)
strides=(1,)
operand_shape=(100,) ] n
v = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] u
w = add t v
x = slice[ start_indices=(3,)
limit_indices=(4,)
strides=(1,)
operand_shape=(100,) ] n
y = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] x
z = add w y
... similarly ...
and it's then obvious why this is slow: the program is very big.
Contrast the np.sum version:
In [5]: def integrate(mu_new):
   ...:     x_new = integr_grid - mu_new
   ...:
   ...:     proba_new = pdf(x_new)
   ...:     total_proba = np.sum(proba * proba_new * integration_weight)
   ...:
   ...:     return total_proba
   ...:
In [6]: jax.make_jaxpr(integrate)(1)
Out[6]:
{ lambda b c ; ; a.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = sub c d
f = sub e 0.0
g = pow f 2.0
h = div g 1.0
i = add 1.8378770351409912 h
j = neg i
k = div j 2.0
l = exp k
m = mul b l
n = mul m 2.0
o = reduce_sum[ axes=(0,)
input_shape=(100,) ] n
in [o] }
Thank you very much for this prompt and insightful help.
I posted the same question on SO, if you want to get the rep for this great support :D https://stackoverflow.com/questions/59068666/jax-time-to-jit-a-function-grows-superlinear-with-memory-accessed-by-function