Jax: scan abstract values and tree prefix convention bug

Created on 30 May 2019 · 5 Comments · Source: google/jax

Repro from @fehiepsi in the #772 thread:

import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(init_s):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)

        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        #level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        #level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    y = np.ones(80)
    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()

f(np.ones(38))
jit(f)(np.ones(38))  # Type Error

I think this is related to partial_eval._split_avals not instantiating tuple-trees of units but instead truncating at the same tree prefix as the second_component tree.
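To make the "tree prefix" wording concrete, here is a minimal sketch using only the public `jax.tree_util` API. It does not reproduce `partial_eval._split_avals` or any other internal code; the name `is_array_mask` is hypothetical and only illustrates the idea that every leaf of the carry needs its own entry, rather than a tree cut short at a shared prefix.

```python
import jax
import jax.numpy as jnp

s = jnp.array(1.0)
carry = (0.0, s)  # a Python constant next to an array, as in the repro

# Flatten the carry into its leaves plus a description of its structure.
leaves, treedef = jax.tree_util.tree_flatten(carry)

# Hypothetical illustration: a per-leaf mask with the SAME tree structure as
# the carry, marking which leaves are arrays and which are plain constants.
is_array_mask = jax.tree_util.tree_map(lambda x: isinstance(x, jax.Array), carry)

print(treedef)
print(is_array_mask)
```

The mask has one entry per carry leaf, so it can be zipped against the carry without any ambiguity about which leaf an entry refers to.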

bug

All 5 comments

Here is a much smaller repro that exercises this issue:

import jax.numpy as np
from jax import jit, lax

def f(s):
    def scan_fn(carry, t):
        a, b = carry
        return (a, b), a + b

    _, bs = lax.scan(scan_fn, (0., s), np.arange(10))
    return bs.sum()

print(f(np.array(1.)))
jit(f)(np.array(1.))

@mattjj I just found a trick that works around this problem: replace

    _, bs = lax.scan(scan_fn, (0., s), np.arange(10))

by

    i = 0. + s - s
    _, bs = lax.scan(scan_fn, (i, s), np.arange(10))

I think that, currently, scan under jit does not accept a constant init value in the carry, so the trick is to make that value depend on the traced input s.
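For convenience, the workaround above can be assembled into one self-contained script. This is only a sketch: the function name `f_workaround` is made up here, and `jnp` is used as the alias for `jax.numpy` instead of `np`. On versions where the bug is already fixed, the original form with `(0., s)` should run as well.

```python
import jax.numpy as jnp
from jax import jit, lax

def scan_fn(carry, t):
    # The carry passes through unchanged; each step emits a + b.
    a, b = carry
    return (a, b), a + b

def f_workaround(s):
    # Build the "zero" init from s so that it is traced under jit
    # instead of being a bare Python constant.
    i = 0. + s - s
    _, bs = lax.scan(scan_fn, (i, s), jnp.arange(10))
    return bs.sum()

eager = f_workaround(jnp.array(1.))
jitted = jit(f_workaround)(jnp.array(1.))
print(eager, jitted)
```

The eager and jitted results should agree, which is a quick way to check whether a given JAX version is affected.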

@fehiepsi I believe we have a branch that will close this. Hoping to merge sometime this week.

@mattjj Thanks! I can confirm that the issue is resolved on the master branch.

Woo!
