JAX: backward pass of scan is very slow to compile on CPU

Created on 26 May 2019 · 10 comments · Source: google/jax

The following script took more than a minute to compile on my system (with the newest versions of jax & jaxlib). A very strange phenomenon: the compile time drops to about 2s after some seemingly random changes to the body function (e.g. replacing 0.5 * level with level).

import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(s):
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = np.clip(exp_val, a_min=0)
        level = y_t / s[0] + 0.5 * level  # replacing 0.5 * level with level makes this compile fast
        s = y_t / level + 0.5 * s
        return (level, s), exp_val

    return lax.scan(scan_fn, (1., s), np.ones(80))[1].sum()

jit(grad(f))(np.ones(2))
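
For reference, a hedged way to measure this (timings vary by machine; block_until_ready just forces the first call, which traces and compiles, to finish):

import time

t0 = time.time()
jit(grad(f))(np.ones(2)).block_until_ready()
print("trace + compile took {:.1f}s".format(time.time() - t0))
# With the change mentioned above (0.5 * level -> level in scan_fn),
# this reportedly drops to about 2s.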


All 10 comments

This feels pretty strange to me. Just running make_jaxpr(grad(f))(np.ones(2)) also exhibits the 2min+ computation time, which I believe is mostly spent in LLVM (but it seems to be caused by an enormous HLO graph that gets compiled somewhere during grad-of-scan; I haven't yet looked into why).

The jaxpr for grad(f) itself is pretty small and looks fine, though.
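
For anyone reproducing this, a minimal sketch (assuming f as defined in the original report above):

from jax import make_jaxpr

# Building the jaxpr alone exhibits the slowdown in this case, even though
# the resulting jaxpr prints as a small program.
print(make_jaxpr(grad(f))(np.ones(2)))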

This looks like it's either an LLVM bug or an XLA bug. The computation isn't all that large, but lots of time is being spent in the LLVM SelectionDAG logic. I filed an (internal) bug for the XLA team.

Thanks @jekbradbury and @hawkinsp for looking into this issue!

I agree that the computation isn't that large in absolute terms, but it's still surprising to me (the HLO that takes a lot of time in LLVM looks vaguely unrolled, or something?).

We can probably hope for XLA/LLVM to get faster, but we can also improve the kind of graph we're staging out.

I noticed (because @hawkinsp pointed it out in related circumstances) that we're adding a lot of scalar literals to the loop-carry tuple, when instead those could be staged into the XLA computation as literals. @dougalm and I recently added some logic to instantiate scalar constants as XLA literals rather than hoisting them into the loop carry (in d27bc0a, part of #704), but we conservatively only switched that on for the Python types int and float. In particular, DeviceArray constants still got hoisted.
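
As a rough illustration of the distinction (a sketch of the assumed behavior, not an exact trace of JAX internals): a Python float closed over by a scan body could be staged as an XLA literal, while an equivalent DeviceArray constant was still hoisted into the loop carry before #780:

import jax.numpy as np
from jax import lax

py_half = 0.5             # Python float: staged as an XLA literal after d27bc0a
arr_half = np.array(0.5)  # DeviceArray constant: hoisted into the carry pre-#780

def body(carry, x):
    # Both constants are closed over; only how they were staged differed.
    return carry * py_half + x * arr_half, carry

lax.scan(body, 1., np.ones(4))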

I noticed that this computation was doing a good amount of scalar-hoisting, and so in #780 (specifically 9c931dd) I sketched out some logic that allows more types to be treated as literals in jaxprs (and hence in the staged-out XLA computations). That seems to make the compile time for the OP's code essentially instantaneous.

I want to look over that code with fresh eyes tomorrow, but I'm optimistic it (or something like it) will handle this scan compilation time issue and a lot of related ones.

@fehiepsi can you sync past #780 and verify that it solved this issue for you? If not, let's reopen and keep exploring!

Thanks so much, @mattjj! #780 has solved the compile-time issue for the script in this thread. However, I am currently getting a shape-error issue in my original script (after the update), so I can't test it yet. I'll isolate the error and get back to you.

@mattjj Here is a repro code to trigger that error:

import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(init_s):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)

        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    y = np.ones(80)
    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()

f(np.ones(38))
jit(f)(np.ones(38))
grad(f)(np.ones(38))
jit(grad(f))(np.ones(38))  # TypeError: (ShapedArray(float32[]), ())

The good news is that your fix runs pretty well for many other scripts. :D I think it's just a typo somewhere in that PR.

This script triggers an error in the forward pass:

import jax.numpy as np
from jax import jit, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(init_s):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)

        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        #level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        #level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    y = np.ones(80)
    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()

f(np.ones(38))
jit(f)(np.ones(38))  # TypeError

Thanks for the repros! I think this is exercising a bug that I noticed while working on #780 (but which #780 didn't cause, as far as I know... though maybe you have evidence to the contrary!). I'm going to open a separate issue.

