It would be great if xla.abstractify would also accept namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a plain tuple would make things nicer.
There's actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for grad, jit, vmap, etc, all at once. Of course it's not documented at all... :)
You can register a custom type as a "pytree" (tree-like Python container) like this:
from collections import namedtuple
from jax.tree_util import register_pytree_node
from jax import grad, jit
import jax.numpy as np

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),  # tell JAX how to unpack to an iterable
    lambda _, xs: Point(*xs)       # tell JAX how to pack back into a Point
)

def f(pt):
    return np.sqrt(pt.x**2 + pt.y**2)

pt = Point(1., 2.)
print(f(pt))  # 2.236068
print(grad(f)(pt))  # Point(x=..., y=...)
g = jit(f)
print(g(pt))  # 2.236068
So that's an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts.
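To make the nesting point concrete, here is a small sketch. It assumes the same Point registration as above; the `nested` structure and its values are made up for illustration, and `tree_map` / `tree_leaves` are the generic helpers in jax.tree_util.

```python
from collections import namedtuple

from jax.tree_util import register_pytree_node, tree_leaves, tree_map

Point = namedtuple("Point", ["x", "y"])

# Same registration as in the example above.
register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),
    lambda _, xs: Point(*xs),
)

# Hypothetical structure: Points nested inside a dict and a list.
nested = {"a": Point(1.0, 2.0), "b": [Point(3.0, 4.0), 5.0]}

# Leaves are collected through every container level (dict keys in sorted order).
leaves = tree_leaves(nested)

# tree_map applies the function to each leaf and rebuilds the same structure.
doubled = tree_map(lambda v: v * 2, nested)
print(doubled)
```

The Point instances come back as Point instances, so attribute access like doubled["a"].x keeps working after the transformation.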
(By the way, the extra data that can be returned by the to-iterable function and consumed by the to-pytree function is for things like dict keys. In the above example, we're just returning None when mapping to an iterable and then ignoring it when reconstructing.)
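As a rough illustration of where that extra data is useful, here is a sketch with a dict-like container. The `Params` class and the two helper functions are hypothetical names invented for this example; only `register_pytree_node`, `tree_flatten`, and `tree_unflatten` come from jax.tree_util.

```python
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


class Params:
    """Hypothetical dict-like container, used only for this illustration."""

    def __init__(self, **kwargs):
        self.values = dict(kwargs)


def params_flatten(p):
    # The values are the children; the sorted keys travel as the extra data.
    keys = tuple(sorted(p.values))
    children = tuple(p.values[k] for k in keys)
    return children, keys


def params_unflatten(keys, children):
    # The extra data (keys) arrives first, then the children.
    return Params(**dict(zip(keys, children)))


register_pytree_node(Params, params_flatten, params_unflatten)

p = Params(w=1.0, b=2.0)
leaves, treedef = tree_flatten(p)

# Rebuild from transformed leaves; the keys are restored from the extra data.
rebuilt = tree_unflatten(treedef, [v * 10 for v in leaves])
print(rebuilt.values)
```

Sorting the keys is a choice made here so that the leaf order does not depend on insertion order; the keys themselves never show up as leaves, so they are not traced or differentiated.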
However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?
I revised the issue title because we'd handle the issue in api.py and xla.abstractify would never need to see these types (just like it never sees tuples/lists/dicts).
Ha, that's awesome! Regarding namedtuple support: given that namedtuples are real subclasses of tuple, I think supporting all namedtuples out of the box would be the most intuitive solution.
+1 to having JAX work with all namedtuple classes
+1 Our existing codebase relies heavily on namedtuple, and it would be great to have it supported in JAX.