Jax: jitted function is slow with JAX 0.1.42

Created on 23 Aug 2019 · 5 Comments · Source: google/jax

Here is a repro script:

import jax
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    """Welford's online estimator of the running mean and diagonal covariance."""
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        # JAX 0.1.x signature: lax.cond(pred, true_operand, true_fun,
        #                               false_operand, false_fun)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # commenting out this line makes it fast

z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)

With the jitted wa_update, this prints (seconds per iteration):

0.0958707332611084
0.08851313591003418
0.08699154853820801
0.09005379676818848
0.08801078796386719
0.08790731430053711
0.09052276611328125
0.08893036842346191
0.0877068042755127
0.08900189399719238

whereas with the non-jitted wa_update (the jit line commented out), we get:

0.12209296226501465
0.002318859100341797
0.0018045902252197266
0.002086162567138672
0.0018031597137451172
0.0017747879028320312
0.0020046234130859375
0.002245187759399414
0.0020575523376464844
0.0024704933166503906

I think this is just a typo somewhere (similar to #1237).

cc @neerajprad
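The welford_covariance helper in the repro is Welford's online algorithm: after n samples, m2 holds the per-dimension sum of squared deviations from the running mean, so m2 / (n - 1) is the unbiased (diagonal) sample variance that final_fn returns. A quick sanity check against a batch computation, written for a current JAX release with hypothetical random sample data:

```python
import jax
import jax.numpy as jnp


def welford_update(sample, state):
    # Same update rule as update_fn inside welford_covariance above.
    mean, m2, n = state
    n = n + 1
    delta_pre = sample - mean
    mean = mean + delta_pre / n
    m2 = m2 + delta_pre * (sample - mean)
    return mean, m2, n


# Hypothetical data: 100 samples of a 3-dimensional vector.
samples = jax.random.normal(jax.random.PRNGKey(0), (100, 3))

state = (jnp.zeros(3), jnp.zeros(3), 0)
for s in samples:
    state = welford_update(s, state)

mean, m2, n = state
online_var = m2 / (n - 1)                        # what final_fn computes as cov
batch_var = jnp.var(samples, axis=0, ddof=1)     # two-pass reference
print(jnp.allclose(online_var, batch_var, rtol=1e-4, atol=1e-6))  # True
```

The online version never stores the samples, which is why the warmup adapter can update it one step at a time.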

Labels: bug, performance


All 5 comments

Thanks for flagging this. As a stop-gap, you might need to tell your users to pip install jax==0.1.41.

We'll get this squashed ASAP though.

It's hitting a compile every time (not sure why yet). I added a print("COMPILING") before this line and it's printing for every iteration. Pretty much all of these bad performance issues come from missing the compilation cache.

I hereby promise that in the PR closing this bug I'll also add at least a print statement there that can be enabled with a flag! That will make it super easy to check when a performance problem is a recompilation issue.
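One way to reproduce the "compiling on every iteration" diagnosis without editing JAX itself is to count traces: the Python body of a jitted function only runs while JAX traces it for a new cache entry, so a Python-side counter goes up once per cache miss rather than once per call. The sketch below assumes a current JAX release (not 0.1.42); step and traced_calls are hypothetical names, and the loop imitates the repro's call pattern (a Python int t, a Python float, and a fixed array).

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only while JAX traces `step`,
# i.e. once per compilation-cache miss, not once per call.
traced_calls = 0


@jax.jit
def step(t, accept_prob, z):
    global traced_calls
    traced_calls += 1  # Python side effect: runs at trace time only
    return z * accept_prob + t


z = jnp.ones(3)
for t in range(10):
    # Same shapes and dtypes on every call, so the cached executable is reused.
    out = step(t, 0.1 * t, z)

# 1 when the cache works; 10 would mean a recompile on every call.
print("traces:", traced_calls)
```

If the counter reaches 10, every call is missing the cache, which is the symptom reported in this issue. Recent JAX versions also provide a jax_log_compiles configuration option (jax.config.update("jax_log_compiles", True)) that logs each compilation; confirm it is available in your installed version before relying on it.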

Commit 7539325, included in #1240, will squash this. The fix needed was this line.

I made a mistake in compilation caching when I merged a change in #1224 with a change made on master that affected how xla.abstractify works on DeviceArrays.
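For context, the jit cache is keyed on the abstract value (shape and dtype) of each traced argument plus the values of any static arguments, so a bug in how arguments are abstractified can turn every call into a miss. A minimal sketch on a current JAX release showing which calls hit and which miss; scale and traces are hypothetical names used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = []  # one entry per trace, i.e. per cache miss


@partial(jax.jit, static_argnums=1)
def scale(x, power):
    traces.append((x.shape, x.dtype, power))  # runs at trace time only
    return x ** power


x = jnp.ones(3, dtype=jnp.float32)
scale(x, 2)                          # miss: first call
scale(x + 1.0, 2)                    # hit: new values, same shape/dtype
scale(jnp.ones(4, jnp.float32), 2)   # miss: new shape
scale(jnp.ones(3, jnp.int32), 2)     # miss: new dtype
scale(x, 3)                          # miss: new value of a static argument

print(len(traces))  # 4
```

Because t in the repro is passed as a traced (non-static) argument, changing its value between iterations should not, by itself, trigger a recompile.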

The real lesson here is that we need tests to catch performance regressions.
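A performance-regression test of this kind does not have to time anything; it can assert that repeated calls with same-shaped inputs trace the function only once. A minimal sketch using unittest on a current JAX release; make_counted_update and NoRecompileTest are hypothetical names, and this is not JAX's own test suite.

```python
import unittest

import jax
import jax.numpy as jnp


def make_counted_update():
    """Return a jitted Welford update and a dict that counts its traces."""
    counter = {"traces": 0}

    @jax.jit
    def update(sample, mean, m2, n):
        counter["traces"] += 1  # runs only when JAX (re)traces `update`
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        m2 = m2 + delta_pre * (sample - mean)
        return mean, m2, n

    return update, counter


class NoRecompileTest(unittest.TestCase):
    def test_update_traces_once(self):
        update, counter = make_counted_update()
        sample = jnp.ones(3)
        state = (jnp.zeros(3), jnp.zeros(3), jnp.int32(0))
        for _ in range(10):
            # Outputs have the same shapes/dtypes as the inputs, so every call
            # after the first should reuse the cached executable.
            state = update(sample, *state)
        self.assertEqual(counter["traces"], 1)
```

Run it with python -m unittest; a regression like the one in this issue would make the trace count 10 instead of 1.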

> It's hitting a compile every time

Thanks @mattjj, I think so too.

It seems related to an issue I hit a few days ago: the code used all of the GPU memory even though I was only running CPU code (with jaxlib 0.1.23 installed from https://storage.googleapis.com/jax-releases). Things went back to normal when I installed the jaxlib version from PyPI. I'll check whether #1240 fixes it too.
