Jax: vmap is incompatible with custom pytree

Created on 31 May 2020  ·  6 Comments  ·  Source: google/jax

The following repro code

from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special(object):
    """Hold two arrays, ``x`` and ``y``, broadcast to a common shape."""

    def __init__(self, x, y):
        # Work out the shared broadcast shape first, then expand each
        # operand to it so both attributes end up with the same shape.
        x_shape, y_shape = np.shape(x), np.shape(y)
        target = lax.broadcast_shapes(x_shape, y_shape)
        self.x = np.broadcast_to(x, target)
        self.y = np.broadcast_to(y, target)

def special_flatten(v):
    """Split ``v`` into its children ``(x, y)`` and auxiliary data (``None``)."""
    children = (v.x, v.y)
    aux_data = None
    return (children, aux_data)

def special_unflatten(aux_data, children):
    """Rebuild a Special by passing ``children`` to its constructor.

    ``aux_data`` is accepted but not used.
    """
    rebuilt = Special(*children)
    return rebuilt

# Register Special as a pytree node using the flatten/unflatten helpers above.
tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    """Build a Special that uses ``x`` for both of its fields."""
    result = Special(x, x)
    return result

# jit, lax.map and pmap are each checked to give a Special whose x has shape (3,).
assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert pmap(f)(np.ones(3)).x.shape == (3,)
# The vmap call is the one that triggers the error reported in this issue.
vmap(f)(np.ones(3))  # fail!

triggers the error

Details

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-e7c65535399a> in <module>
     18 assert jit(f)(np.ones(3)).x.shape == (2, 2, 3)
     19 assert lax.map(f, np.ones(3)).x.shape == (2, 3, 2)
---> 20 vmap(f)(np.ones(3))

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in batched_fun(*args)
    768     _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
    769     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770                               lambda: flatten_axes(out_tree(), out_axes))
    771     return tree_unflatten(out_tree(), out_flat)
    772 

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
    152     while stack:
    153       gen, out_store = stack.pop()
--> 154       ans = gen.send(ans)
    155       if out_store is not None:
    156         ans, side = ans

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in _batch_fun(sum_match, in_dims, out_dims_thunk, out_dim_dests, *in_vals, **params)
     57     out_vals = yield (master, in_dims,) + in_vals, params
     58     del master
---> 59   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
     60   out_dims = out_dims_thunk()
     61   for od, od_dest in zip(out_dims, out_dim_dests):

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in <lambda>()
    768     _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
    769     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770                               lambda: flatten_axes(out_tree(), out_axes))
    771     return tree_unflatten(out_tree(), out_flat)
    772 

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api_util.py in flatten_axes(treedef, axis_tree)
    106   # TODO(mattjj,phawkins): improve this implementation
    107   proxy = object()
--> 108   dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    109   axes = []
    110   add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/tree_util.py in tree_unflatten(treedef, leaves)
     68     structure described by `treedef`.
     69   """
---> 70   return treedef.unflatten(leaves)
     71 
     72 def tree_leaves(tree):

<ipython-input-36-e7c65535399a> in special_unflatten(aux_data, children)
      9 
     10 def special_unflatten(aux_data, children):
---> 11     return Special(*children)
     12 
     13 tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

<ipython-input-36-e7c65535399a> in __init__(self, x)
      3 class Special(object):
      4     def __init__(self, x):
----> 5         self.x = np.broadcast_to(x, (2,) + np.shape(x))
      6 
      7 def special_flatten(v):

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in broadcast_to(arr, shape)
   1280 def broadcast_to(arr, shape):
   1281   """Like Numpy's broadcast_to but doesn't necessarily return views."""
-> 1282   arr = arr if isinstance(arr, ndarray) else array(arr)
   1283   shape = canonicalize_shape(shape)  # check that shape is concrete
   1284   arr_shape = _shape(arr)

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   2070       return array(onp.asarray(view), dtype, copy)
   2071 
-> 2072     raise TypeError("Unexpected input type for array: {}".format(type(object)))
   2073 
   2074   if ndmin > ndim(out):

TypeError: Unexpected input type for array: <class 'object'>

The issue here is that we are applying np.broadcast_to to an "object" x (instead of an ndarray) when unflattening the tree under vmap. Is there any workaround available in JAX?

question

All 6 comments

You need a "direct" constructor for unflatten that doesn't apply broadcast_to but rather just assigns the value.

You can use Python tricks to skip __init__ when constructing objects, but the more idiomatic option is to keep class constructors minimal, just doing variable assignment, and to make another dedicated constructor (make_special below) for typical use:

from jax import tree_util, lax, vmap, numpy as np, jit

class Special:
    """Minimal container that stores ``x`` exactly as given."""

    def __init__(self, x):
        # Assignment only: no broadcasting or conversion happens here.
        self.x = x

def make_special(x):
    """Construct a Special from ``x`` broadcast to shape ``(2,) + np.shape(x)``."""
    target_shape = (2,) + np.shape(x)
    broadcasted = np.broadcast_to(x, target_shape)
    return Special(broadcasted)

def special_flatten(v):
    """Split ``v`` into a one-element children tuple ``(x,)`` and ``None`` aux data."""
    children = (v.x,)
    aux_data = None
    return (children, aux_data)

def special_unflatten(aux_data, children):
    """Rebuild a Special by passing ``children`` to its constructor.

    ``aux_data`` is accepted but not used.
    """
    rebuilt = Special(*children)
    return rebuilt

# Register Special as a pytree node using the flatten/unflatten helpers above.
tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    """Return the Special produced by ``make_special`` for ``x``."""
    wrapped = make_special(x)
    return wrapped

# Expected shapes of x after jit, lax.map and vmap respectively.
assert jit(f)(np.ones(3)).x.shape == (2, 3)
assert lax.map(f, np.ones(3)).x.shape == (3, 2)
assert vmap(f)(np.ones(3)).x.shape == (3, 2)

(This split between user-facing and lower-level constructors is pretty common; e.g., it's what NumPy does for its ndarray class, with array() for construction/coercion.)

By the way, one clue that your original unflatten is doing something really strange is that merely adding jit is changing the output:

>>> f(np.ones(3)).x
DeviceArray([[1., 1., 1.],
             [1., 1., 1.]], dtype=float32)
>>> jit(f)(np.ones(3)).x
DeviceArray([[[1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.]]], dtype=float32)

You probably don't want that!

one clue that your original unflatten is doing something really strange

Yeah, that's a bad example. I just updated the code to illustrate the issue better. The result is consistent now. jit, map, pmap work as expected but vmap failed.

You need a "direct" constructor for unflatten that doesn't apply broadcast_to but rather just assigns the value.

Thanks! Unfortunately, I need some logic in the constructor to simplify the broadcasting code in other methods of a pytree class. Actually, we want to convert our distributions in NumPyro to pytree to support scan. And we want to allow constructors such as Normal(loc, 1) where loc can have an arbitrary shape. The code Normal(loc, np.ones(np.shape(loc))) is a bit ugly if we incorporate that restriction.

Like I said, you can use some Python magic for an alternative constructor that skips __init__:

from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special:
    """Hold two arrays, ``x`` and ``y``, broadcast to a common shape."""

    def __init__(self, x, y):
        # Work out the shared broadcast shape first, then expand each
        # operand to it so both attributes end up with the same shape.
        x_shape, y_shape = np.shape(x), np.shape(y)
        target = lax.broadcast_shapes(x_shape, y_shape)
        self.x = np.broadcast_to(x, target)
        self.y = np.broadcast_to(y, target)

    @classmethod
    def restore(cls, x, y):
        """Create an instance without running ``__init__``.

        ``x`` and ``y`` are stored as given, with no broadcasting applied.
        """
        instance = object.__new__(cls)
        instance.x, instance.y = x, y
        return instance

def special_flatten(v):
    """Split ``v`` into its children ``(x, y)`` and auxiliary data (``None``)."""
    children = (v.x, v.y)
    aux_data = None
    return (children, aux_data)

def special_unflatten(aux_data, children):
    """Rebuild a Special through ``Special.restore`` rather than the constructor.

    ``aux_data`` is accepted but not used.
    """
    rebuilt = Special.restore(*children)
    return rebuilt

# Register Special as a pytree node using the flatten/unflatten helpers above.
tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    """Build a Special that uses ``x`` for both of its fields."""
    result = Special(x, x)
    return result

# jit, lax.map and vmap are each checked to give a Special whose x has shape (3,).
assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert vmap(f)(np.ones(3)).x.shape == (3,)

I don't really recommend it as a design pattern (cheap constructors that just do assignment/validation are generally preferred, for this among other reasons), but it works if you need it for backwards compatibility. If doing it from scratch, I would consider having a normal() function that constructs a Normal() object.

Oh, good idea!! I believe we can use your trick to create a temporary instance (with unbroadcasted arguments) and invoke __init__ again after scan. Thanks a lot, @shoyer !

Edit: It turns out that we won't have to worry about this issue. The following way of using vmap works:

def f(x):
    """Apply ``g`` across ``x`` with ``lax.map`` and return the result's ``x`` field."""
    def g(item):
        # Each mapped element is used for both fields of the Special.
        return Special(item, item)

    mapped = lax.map(g, x)
    return mapped.x

# Call the vmapped f on a (3, 4) array of ones.
vmap(f)(np.ones((3, 4)))

I think this issue is resolved, but if not please let us know @fehiepsi ! By the way, it's great to hear from you again :D

Was this page helpful?
0 / 5 - 0 ratings