Jax: stax.serial.apply_fun is not a valid JAX type inside odeint

Created on 1 May 2020 · 9 comments · Source: google/jax

Hi,
FWIW, I'm using a self-built jax and jaxlib following instructions from #2083.

#
# Name                    Version                   Build  Channel
jax                       0.1.64                    <pip>
jaxlib                    0.1.45                    <pip>

I'm trying to get gradients through an ODE solver. First I ran into the AssertionError from issue #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the instructions for another AssertionError, issue #2531, and switched to taking the vmap of grads instead of grads of vmap. Now I'm getting the following error.
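A toy sketch of that pattern, just to show what I mean (not my real loss):

import jax.numpy as np
from jax import grad, vmap

def toy_loss(y0, params):
    return np.sum((y0 * params) ** 2)  # scalar loss for a single initial condition

# vmap of grad: one gradient w.r.t. params per row of batch_y0
per_example_grads = vmap(grad(toy_loss, argnums=1), in_axes=(0, None))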


Full traceback:

----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])

~/Code/jax/jax/api.py in batched_fun(*args)
    805     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    806     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807                               lambda: _flatten_axes(out_tree(), out_axes))
    808     return tree_unflatten(out_tree(), out_flat)
    809 

~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    436     f_partial, dyn_args = argnums_partial(f, argnums, args)
    437     if not has_aux:
--> 438       ans, vjp_py = _vjp(f_partial, *dyn_args)
    439     else:
    440       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
   1437   if not has_aux:
   1438     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1440     out_tree = out_tree()
   1441   else:

~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     97   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154                        name=flat_fun.__name__)
    155     return tree_unflatten(out_tree(), out)
    156 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    342     name = params.get('name', f.__name__)
    343     params = dict(params, name=wrap_name(name, 'jvp'))
--> 344     result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
    345     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    346     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    175     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    176     fun, aux = partial_eval(f, self, in_pvs)
--> 177     out_flat = call_primitive.bind(fun, *in_consts, **params)
    178     out_pvs, jaxpr, env = aux()
    179     env_tracers = map(self.full_raise, env)

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
    146     else:
    147       f, dims_out = batch_subtrace(f, self.master, dims)
--> 148       vals_out = call_primitive.bind(f, *vals, **params)
    149       return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
    150 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
    999   if top_trace is None:
   1000     with new_sublevel():
-> 1001       outs = primitive.impl(f, *args, **params)
   1002   else:
   1003     tracers = map(top_trace.full_raise, args)

~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
    460 
    461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462   compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
    463   try:
    464     return compiled_fun(*args)

~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
    219       fun.populate_stores(stores)
    220     else:
--> 221       ans = call(fun, *args)
    222       cache[key] = (ans, fun.stores)
    223     return ans

~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
    477   pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    478   jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479       fun, pvals, instantiate=False, stage_out=True, bottom=True)
    480 
    481   _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-17-de50dc731d85> in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
      1 @partial(jit, static_argnums=(4,))
      2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3     pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
      4     loss = np.mean(np.abs(pred_y-batch_y))
      5     return loss

~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    152     shape/structure as `y0` except with a new leading axis of length `len(t)`.
    153   """
--> 154   return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
    155 
    156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    149       dyn_args = args
    150     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151     _check_args(args_flat)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,

~/Code/jax/jax/api.py in _check_args(args)
   1558     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1559       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560                       .format(arg, type(arg)))
   1561 
   1562 def _valid_jaxtype(arg):

TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type

I'm passing two stax.serial modules, each with three hidden Dense layers, as inputs to odeint to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and params of the stax.serial modules.

def lv_UDE(y,t,params,ufuncs,uparams):
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.array([dRdt,dFdt])

I'm trying to get gradients through odeint w.r.t. uparams. Is there a workaround for passing stax.serial modules as arguments? Thanks in advance.

Labels: better_errors, documentation, question

All 9 comments

Could you please share a full example of how you get this error? Ideally something that I could copy into a terminal and run.

Hi,
I just noticed that even the non-vmapped version of a function with stax.serial as an input fails with the same error message. Here's the full example. Thanks

import jax 
import jax.numpy as np
import numpy as onp
from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental.ode import odeint
from jax.experimental import stax
from functools import partial


def lv(y,t,params):
    """
    original lotka-volterra equations
    """
    R,F = y
    alpha, beta, gamma, theta = params
    dRdt = alpha*R - beta*R*F
    dFdt = gamma*R*F - theta*F
    return np.hstack([dRdt,dFdt])

t = np.linspace(0.,4.,num=1000)
y0 = np.array([0.44249296,4.6280594])

true_y = odeint(partial(lv,params=[1.3,0.9,0.5,1.8]),y0=y0,t=t) #training data generation


def lv_UDE(y,t,params,ufuncs,uparams):
    """
    additional parameters include stax.Serial 
    modules and uparams associated with them
    """
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha*R - U1(U1_params, y)
    dFdt = -theta*F + U2(U2_params, y)
    return np.hstack([dRdt,dFdt])

#two modules of stax Serial
U1_init, U1 = stax.serial(stax.Dense(32),stax.Tanh, 
                            stax.Dense(32), stax.Tanh, 
                            stax.Dense(32),stax.Tanh,
                           stax.Dense(1))
U2_init, U2 = stax.serial(stax.Dense(32),stax.Tanh, 
                            stax.Dense(32), stax.Tanh, 
                            stax.Dense(32),stax.Tanh,
                           stax.Dense(1))

key, subkey = random.split(random.PRNGKey(0))

_,U1_params = U1_init(key,(2,)) #inputs of size 2
_,U2_params = U2_init(subkey,(2,))
key,subkey = random.split(subkey)


def get_batch():
    """
    Get batches of inital conditions and 
    times along with true time history
    """
    s = onp.random.choice(onp.arange(1000 - 20, 
                        dtype=onp.int64), 20, replace=False)
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:20]  # (T)
    batch_y = np.stack([true_y[s + i] for i in range(20)])  # (T, M, D)
    return batch_y0, batch_t, batch_y


def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
    """
    Mean absolute loss 
    """
    pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams) # integrate using odeint
    loss = np.mean(np.abs(pred_y-batch_y)) #calculate loss
    return loss


grads = value_and_grad(loss,(5,)) #grads w.r.t uparams 
batch_grad = vmap(grads,(0, None, None, None, None, None)) #vectorize over initial conditions (batch_y0)


grads(y0,t,true_y,[1.3,1.8], [U1,U2], 
      [U1_params,U2_params]) #non vmappped  doesn't work
batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], 
           [U1,U2], [U1_params,U2_params]) #vmap version same error

Hey @skrsna , thanks for the question!

In your example, it seems the lv_UDE is never called. Is that intentional?

The underlying issue here is that odeint can't take function-valued arguments in *args; those must be arrays (or potentially-nested containers of arrays, like lists/tuples/dicts of arrays). Instead of passing ufuncs via the *args of odeint, maybe you could just write something like:

def lv_UDE(ufuncs,y,t,params,uparams):  # moved ufuncs to front
    ...

odeint(partial(lv_UDE, ufuncs), ...)
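For example, filled in with the names from your snippet (an untested sketch, just to make the idea concrete):

def lv_UDE(ufuncs, y, t, params, uparams):  # ufuncs moved to the front
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return np.hstack([dRdt, dFdt])

# close over the apply functions so odeint's *args only contain arrays
# and pytrees of arrays
pred_y = odeint(partial(lv_UDE, (U1, U2)), y0, t, [1.3, 1.8],
                (U1_params, U2_params))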

WDYT?

It's possible we could support passing function-valued arguments in *args, but I'm not sure it'd be worth the extra complexity. We could at least raise a better error...

Hi @mattjj , thanks for the super fast response. My bad, I forgot to add lv_UDE while refactoring the code to make it look nicer. I'll try your solution and update the issue with the workaround. Thanks again.

Awesome, glad to hear that might help!

I just pushed #2931 to improve the error message. Now running your test program we get:

TypeError: The contents of odeint *args must be arrays or scalars, but got
<function serial.<locals>.apply_fun at 0x7f17fc69ca70>.

I also improved the docstring from this:

     *args: tuple of additional arguments for `func`.

To this:

    *args: tuple of additional arguments for `func`, which must be arrays,
      scalars, or (nested) standard Python containers (tuples, lists, dicts,
      namedtuples, i.e. pytrees) of those types.

To make odeint handle those types in *args automatically, we could try to hoist non-arrays out of *args inside odeint. But maybe we can open a separate issue for that enhancement if it's a high priority for anyone. (@shoyer interested to hear if you have a strong opinion!)

I'm going to let #2931 close this issue, just so as to keep our issues under control. Let me know if that's a bad idea :)

Sure, please close the issue. I'm currently trying out your suggestions and I'll update the issue with working code in case anyone else runs into the same error.

Hi @mattjj , I tried your solution and it works seamlessly with vmap. Thanks again.
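For anyone who runs into the same error later, the end-to-end pattern looks roughly like this (an illustrative sketch, not my exact code; it reuses lv_UDE with ufuncs moved to the front, as suggested above):

batch_y0, batch_t, batch_y = get_batch()

def loss(batch_y0, batch_t, batch_y, params, uparams):
    # the apply functions are closed over via partial, so odeint's *args
    # only contain arrays and pytrees of arrays
    pred_y = odeint(partial(lv_UDE, (U1, U2)), batch_y0, batch_t, params, uparams)
    return np.mean(np.abs(pred_y - batch_y))

grads = value_and_grad(loss, argnums=4)  # gradients w.r.t. uparams only
# batch over initial conditions (axis 0 of batch_y0) and the matching
# example axis of batch_y (axis 1 in the (T, M, D) layout from get_batch)
batch_grad = vmap(grads, in_axes=(0, None, 1, None, None))
batch_grad(batch_y0, batch_t, batch_y, [1.3, 1.8], [U1_params, U2_params])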
