Here is a minimal repro, which worked on the previous version:
```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t, theta):
    """Lotka–Volterra equations."""
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.), jnp.ones(4))
    return jnp.sum(y)

jax.grad(f)(jnp.ones(2))
```
Running the script above raises `TypeError: Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32`. I tried to trace the error but found no hint of where the int variables are created. I think the issue was introduced by https://github.com/google/jax/pull/3562.
A simpler repro:
```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t):
    return jnp.stack([z[0], z[1]])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.))
    return jnp.sum(y)

jax.grad(f)(jnp.ones(2))
```
It seems to me that the integer indices `0` and `1` (in `z[0]` and `z[1]`) cause the issue.
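For context, the error itself just means that reverse-mode differentiation was asked to differentiate with respect to an integer-typed input. The sketch below uses a hypothetical function `g` (not `odeint`) to trigger the same class of `TypeError` directly; the exact message wording differs between JAX versions. One plausible but unverified reading is that after #3562 some integer value derived from `dz_dt` ends up among the inputs that `odeint`'s custom derivative tries to differentiate.

```python
import jax
import jax.numpy as jnp

def g(x, idx):
    # x is a float array, idx is an integer array (hypothetical example).
    return jnp.sum(x * idx)

x = jnp.ones(2)
idx = jnp.array([0, 1])  # int32 by default

# Differentiating with respect to the float argument works: d/dx sum(x * idx) = idx.
dx = jax.grad(g, argnums=0)(x, idx)

# Differentiating with respect to the integer argument is rejected with a TypeError.
try:
    jax.grad(g, argnums=1)(x, idx)
    int_grad_error = None
except TypeError as e:
    int_grad_error = str(e)

print(dx)
print(int_grad_error)
```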
Ah, this is indeed because of #3562. Thanks for catching it!
Unfortunately I've got to go afk for a while, but I should be able to fix this tonight (if no one beats me to it).
As a temporary workaround, you can use this version:
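```python
import jax.numpy as jnp
from jax.experimental.ode import _odeint_wrapper
# Temporary workaround: a replacement odeint that calls the private helper directly.
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
    return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
```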
Thanks, @mattjj!
I didn't get to it last night, but #3587 should fix this. Thanks for catching it.
I'll do another pypi release after the fix goes in.
Just pushed jax==0.1.72 to pypi.
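A quick way to confirm the fix after upgrading (assuming jax>=0.1.72) is to rerun the simpler repro and compare against the closed form. For dz/dt = z each component evolves as z0·e^t, so the gradient of the summed trajectory over t = 0, ..., 9 with respect to each entry of z0 is (e^10 − 1)/(e − 1), roughly 12818.3. The `rtol`/`atol` values below are my own choice for float32, not something from this thread.

```python
import math

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t):
    # Same dynamics as the simpler repro: dz/dt = z, using integer indexing.
    return jnp.stack([z[0], z[1]])

def f(z):
    # Tolerances are an assumption, loosened for float32.
    y = odeint(dz_dt, z, jnp.arange(10.), rtol=1e-6, atol=1e-6)
    return jnp.sum(y)

dz0 = jax.grad(f)(jnp.ones(2))

# Closed form: y_i(t) = z0_i * exp(t), so d/dz0_i sum_t y_i(t) = sum_t exp(t).
expected = sum(math.exp(k) for k in range(10))
print(dz0, expected)
```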