The following repro code
from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special(object):
    def __init__(self, x, y):
        shape = lax.broadcast_shapes(np.shape(x), np.shape(y))
        self.x = np.broadcast_to(x, shape)
        self.y = np.broadcast_to(y, shape)

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    return Special(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return Special(x, x)

assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert pmap(f)(np.ones(3)).x.shape == (3,)
vmap(f)(np.ones(3))  # fail!
triggers the error
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-36-e7c65535399a> in <module>
18 assert jit(f)(np.ones(3)).x.shape == (2, 2, 3)
19 assert lax.map(f, np.ones(3)).x.shape == (2, 3, 2)
---> 20 vmap(f)(np.ones(3))
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in batched_fun(*args)
768 _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
769 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770 lambda: flatten_axes(out_tree(), out_axes))
771 return tree_unflatten(out_tree(), out_flat)
772
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
32 # executes a batched version of `fun` following out_dim_dests
33 batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34 return batched_fun.call_wrapped(*in_vals)
35
36 @lu.transformation_with_aux
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
152 while stack:
153 gen, out_store = stack.pop()
--> 154 ans = gen.send(ans)
155 if out_store is not None:
156 ans, side = ans
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in _batch_fun(sum_match, in_dims, out_dims_thunk, out_dim_dests, *in_vals, **params)
57 out_vals = yield (master, in_dims,) + in_vals, params
58 del master
---> 59 out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
60 out_dims = out_dims_thunk()
61 for od, od_dest in zip(out_dims, out_dim_dests):
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in <lambda>()
768 _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
769 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770 lambda: flatten_axes(out_tree(), out_axes))
771 return tree_unflatten(out_tree(), out_flat)
772
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api_util.py in flatten_axes(treedef, axis_tree)
106 # TODO(mattjj,phawkins): improve this implementation
107 proxy = object()
--> 108 dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
109 axes = []
110 add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/tree_util.py in tree_unflatten(treedef, leaves)
68 structure described by `treedef`.
69 """
---> 70 return treedef.unflatten(leaves)
71
72 def tree_leaves(tree):
<ipython-input-36-e7c65535399a> in special_unflatten(aux_data, children)
9
10 def special_unflatten(aux_data, children):
---> 11 return Special(*children)
12
13 tree_util.register_pytree_node(Special, special_flatten, special_unflatten)
<ipython-input-36-e7c65535399a> in __init__(self, x)
3 class Special(object):
4 def __init__(self, x):
----> 5 self.x = np.broadcast_to(x, (2,) + np.shape(x))
6
7 def special_flatten(v):
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in broadcast_to(arr, shape)
1280 def broadcast_to(arr, shape):
1281 """Like Numpy's broadcast_to but doesn't necessarily return views."""
-> 1282 arr = arr if isinstance(arr, ndarray) else array(arr)
1283 shape = canonicalize_shape(shape) # check that shape is concrete
1284 arr_shape = _shape(arr)
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
2070 return array(onp.asarray(view), dtype, copy)
2071
-> 2072 raise TypeError("Unexpected input type for array: {}".format(type(object)))
2073
2074 if ndmin > ndim(out):
TypeError: Unexpected input type for array: <class 'object'>
The issue here is that np.broadcast_to is being applied to a placeholder object x (instead of an ndarray) when unflattening the tree under vmap. Is there any workaround available in JAX?
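For reference, the flatten_axes frame in the traceback shows why: JAX rebuilds the output tree from dummy object() leaves, so the same failure can be reproduced without vmap at all:

leaves, treedef = tree_util.tree_flatten(Special(1., 2.))
tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)  # raises the same TypeError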
You need a "direct" constructor for unflatten that doesn't apply broadcast_to but rather just assigns the value.
You can use Python tricks to skip __init__ when constructing objects, but the more idiomatic option is to keep class constructors minimal, just doing variable assignment, and to make another dedicated constructor (make_special below) for typical use:
from jax import tree_util, lax, vmap, numpy as np, jit

class Special(object):
    def __init__(self, x):
        self.x = x

def make_special(x):
    return Special(np.broadcast_to(x, (2,) + np.shape(x)))

def special_flatten(v):
    return ((v.x,), None)

def special_unflatten(aux_data, children):
    return Special(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return make_special(x)

assert jit(f)(np.ones(3)).x.shape == (2, 3)
assert lax.map(f, np.ones(3)).x.shape == (3, 2)
assert vmap(f)(np.ones(3)).x.shape == (3, 2)
(This split between user-facing and lower-level constructors is pretty common, e.g., it's what NumPy does for its ndarray class, with array() for construction/coercion.)
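For example, in NumPy (a quick sketch of the split mentioned above):

import numpy as onp
a = onp.array([1.0, 2.0, 3.0])  # user-facing constructor: coerces and copies as needed
b = onp.ndarray(shape=(3,))     # low-level class constructor: just allocates, rarely called directly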
By the way, one clue that your original unflatten is doing something really strange is that merely adding jit changes the output:
>>> f(np.ones(3)).x
DeviceArray([[1., 1., 1.],
             [1., 1., 1.]], dtype=float32)
>>> jit(f)(np.ones(3)).x
DeviceArray([[[1., 1., 1.],
              [1., 1., 1.]],

             [[1., 1., 1.],
              [1., 1., 1.]]], dtype=float32)
You probably don't want that!
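What's happening: jit flattens its output to arrays and unflattens it on the way out, and with the original one-argument Special (whose unflatten calls __init__ directly) that roundtrip re-applies broadcast_to. A minimal sketch of the same roundtrip, assuming that original definition:

leaves, treedef = tree_util.tree_flatten(f(np.ones(3)))  # f(...).x has shape (2, 3)
rebuilt = tree_util.tree_unflatten(treedef, leaves)      # unflatten runs __init__ again
assert rebuilt.x.shape == (2, 2, 3)                      # extra leading 2, matching the jit output above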
> one clue that your original unflatten is doing something really strange
Yeah, that's a bad example. I just updated the code to illustrate the issue better. The result is consistent now: jit, lax.map, and pmap work as expected, but vmap fails.
> You need a "direct" constructor for unflatten that doesn't apply broadcast_to but rather just assigns the value.
Thanks! Unfortunately, I need some logic in the constructor to simplify the broadcasting code in other methods of a pytree class. Actually, we want to convert our distributions in NumPyro to pytrees to support scan. And we want to allow constructors such as Normal(loc, 1), where loc can have an arbitrary shape. The code Normal(loc, np.ones(np.shape(loc))) is a bit ugly if we incorporate that restriction.
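For concreteness, the constructor we want to keep looks roughly like this (a simplified sketch, not NumPyro's actual code):

class Normal:
    def __init__(self, loc=0., scale=1.):
        # broadcast once here so methods like log_prob and sample
        # can assume loc and scale share a shape
        shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
        self.loc = np.broadcast_to(loc, shape)
        self.scale = np.broadcast_to(scale, shape)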
Like I said, you can use some Python magic for an alternative constructor that skips __init__:
from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special:
    def __init__(self, x, y):
        shape = lax.broadcast_shapes(np.shape(x), np.shape(y))
        self.x = np.broadcast_to(x, shape)
        self.y = np.broadcast_to(y, shape)

    @classmethod
    def restore(cls, x, y):
        obj = object.__new__(cls)
        obj.x = x
        obj.y = y
        return obj

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    return Special.restore(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return Special(x, x)

assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert vmap(f)(np.ones(3)).x.shape == (3,)
I don't really recommend it as a design pattern (cheap constructors that just do assignment/validation are generally preferred, for this among other reasons), but it works if you need it for backwards compatibility. If doing it from scratch, I would consider having a normal() function that constructs a Normal() object.
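That from-scratch design could look something like this (normal and Normal here are hypothetical sketches, not NumPyro's API):

class Normal:
    def __init__(self, loc, scale):
        # cheap constructor: assignment only, so unflatten can call it safely
        self.loc = loc
        self.scale = scale

def normal(loc=0., scale=1.):
    # the user-facing factory owns all the broadcasting logic
    shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
    return Normal(np.broadcast_to(loc, shape), np.broadcast_to(scale, shape))

tree_util.register_pytree_node(
    Normal,
    lambda d: ((d.loc, d.scale), None),
    lambda aux, children: Normal(*children))  # the plain __init__ is now safe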
Oh, good idea!! I believe we can use your trick to create a temporary instance (with unbroadcasted arguments) and invoke __init__ again after scan. Thanks a lot, @shoyer!
Edit: It turns out that we won't have to worry about this issue. The following way of using vmap works:
def f(x):
    def g(x):
        return Special(x, x)
    return lax.map(g, x).x

vmap(f)(np.ones((3, 4)))
I think this issue is resolved, but if not please let us know @fehiepsi ! By the way, it's great to hear from you again :D