Repro from @fehiepsi in the #772 thread:
import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")
def f(init_s):
    """Sum the per-step ``exp_val`` outputs of a ``lax.scan`` over 79 time steps.

    Reported repro: ``f(np.ones(38))`` runs, but ``jit(f)(np.ones(38))``
    raised a TypeError at the time of the report.

    Args:
        init_s: array whose leading dimension is used as ``seasonality``; it is
            rotated left by one position to form the initial ``s`` carry.

    Returns:
        Sum of the ``exp_val`` values stacked by ``lax.scan``.
    """
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        # carry is (level, s, moving_sum); t is the scanned index from np.arange.
        level, s, moving_sum = carry
        season = s[0] * level ** 0.6
        exp_val = level + 2 * level ** 0.5 + season
        exp_val = np.clip(exp_val, a_min=0)
        # Add y[t]; for t >= seasonality also subtract y[t - seasonality].
        # For smaller t that index is negative, so np.where substitutes 0. instead.
        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        # The level update below was commented out in the report, so level is
        # only clipped and otherwise carried through unchanged.
        #level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        #level = 0.2 * level_p + (1 - 0.2) * level
        level = np.clip(level, a_min=0)
        new_s = (0.3 * (y[t] - level) / season + (1 - 0.3)) * s[0]
        # Shift s left by one and append new_s at the end.
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    # scan_fn reads y through its closure. y is bound here, before lax.scan runs.
    y = np.ones(80)
    level_init = y[0]
    # Rotate init_s left by one: element 0 moves to the end.
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # Scan over t = 1 .. y.shape[0] - 1, which is 79 steps.
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals.sum()
f(np.ones(38))  # runs without jit
jit(f)(np.ones(38))  # raised TypeError at the time of the report (later confirmed resolved on master in this thread)
I think this is related to partial_eval._split_avals not instantiating tuple-trees of units, but instead truncating them at the same tree prefix as the second_component tree.
Here is a much smaller repro that exercises this issue:
import jax.numpy as np
from jax import jit, lax
def f(s):
    """Minimal repro: scan with a carry whose first element is the constant ``0.``.

    ``f(np.array(1.))`` runs without jit, but ``jit(f)(np.array(1.))``
    exercised the same failure at the time of the report.

    Args:
        s: value placed in the second slot of the scan carry.

    Returns:
        Sum of the 10 values ``a + b`` emitted by the scan.
    """
    def scan_fn(carry, t):
        # The carry passes through unchanged; each step emits a + b.
        a, b = carry
        return (a, b), a + b
    # The 0. here is a constant init value. The thread suspects this is what
    # triggered the failure under jit.
    _, bs = lax.scan(scan_fn, (0., s), np.arange(10))
    return bs.sum()
print(f(np.array(1.)))  # runs without jit
jit(f)(np.array(1.))  # failed under jit at the time of the report
@mattjj I just found a trick to work around this problem: replace
_, bs = lax.scan(scan_fn, (0., s), np.arange(10))  # init carry contains the constant 0.
with
# Workaround: for finite s, 0. + s - s equals 0. in value but is computed
# from s, so the init carry no longer contains a bare constant.
i = 0. + s - s
_, bs = lax.scan(scan_fn, (i, s), np.arange(10))
I think that scan currently does not accept a constant init value.
@fehiepsi I believe we have a branch that will close this. Hoping to merge sometime this week.
@mattjj Thanks! I can confirm that the issue is resolved in the master branch.
Woo!
Most helpful comment
@fehiepsi I believe we have a branch that will close this. Hoping to merge sometime this week.