JAX: namedtuple support in arguments to transformed functions

Created on 25 Feb 2019 · 6 Comments · Source: google/jax

It would be great if xla.abstractify would also accept namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a plain tuple would make things nicer.

enhancement

Most helpful comment

There's actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for grad, jit, vmap, etc, all at once. Of course it's not documented at all... :)

You can register a custom type as a "pytree" (tree-like Python container) like this:

from collections import namedtuple
from jax.tree_util import register_pytree_node
from jax import grad, jit
import jax.numpy as np

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),  # tell JAX how to unpack to an iterable
    lambda _, xs: Point(*xs)       # tell JAX how to pack back into a Point
)


def f(pt):
  return np.sqrt(pt.x**2 + pt.y**2)

pt = Point(1., 2.)

print(f(pt))        # 2.236068
print(grad(f)(pt))  # Point(x=..., y=...)

g = jit(f)
print(g(pt))  # 2.236068

So that's an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts.
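To make the nesting claim concrete, here is a minimal sketch that assumes the same register_pytree_node approach as above. The Params type, the loss function, and the dict layout are hypothetical names chosen only for illustration; they are not from the original comment.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_leaves

# Hypothetical container, used only for this illustration.
Params = namedtuple("Params", ["w", "b"])

register_pytree_node(
    Params,
    lambda p: (tuple(p), None),   # children, no auxiliary data
    lambda _, xs: Params(*xs),    # rebuild a Params from its children
)

def loss(tree):
    # `tree` is a dict that holds a list of Params plus a scalar.
    total = 0.0
    for layer in tree["layers"]:
        total = total + jnp.sum(layer.w ** 2) + jnp.sum(layer.b ** 2)
    return total * tree["scale"]

tree = {
    "layers": [
        Params(jnp.ones(3), jnp.zeros(3)),
        Params(jnp.ones(2), jnp.ones(2)),
    ],
    "scale": 0.5,
}

# The gradient has the same nested structure as the input.
grads = jax.grad(loss)(tree)
print(type(grads["layers"][0]).__name__)  # Params
print(len(tree_leaves(tree)))             # 5
```

The gradient comes back as a dict containing a list of Params, so field access such as grads["layers"][0].w keeps working after the transformation.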

(By the way, the extra data that can be returned by the to-iterable function and consumed by the to-pytree function is for things like dict keys. In the above example, we just return None when mapping to an iterable and then ignore it when reconstructing.)
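As a rough illustration of what that extra data can carry, the sketch below registers a small class whose static label travels as auxiliary data while its array is the only child. The Tagged class and both helper functions are hypothetical, and the label is assumed to be hashable so it can become part of the tree structure.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

class Tagged:
    """Hypothetical container: one array plus a static string label."""

    def __init__(self, value, label):
        self.value = value
        self.label = label

def tagged_flatten(t):
    # Children are the values JAX should trace; the label is static aux data.
    return (t.value,), t.label

def tagged_unflatten(label, children):
    # The aux data comes back first, followed by the (possibly new) children.
    return Tagged(children[0], label)

register_pytree_node(Tagged, tagged_flatten, tagged_unflatten)

@jax.jit
def double(t):
    return Tagged(t.value * 2, t.label)

out = double(Tagged(jnp.arange(3.0), "velocity"))
print(out.label)  # velocity

# Only the array is a leaf; the label lives in the tree definition.
leaves, treedef = tree_flatten(out)
rebuilt = tree_unflatten(treedef, [leaf + 1 for leaf in leaves])
print(rebuilt.label)  # velocity
```

Because the label is stored in the tree definition rather than among the leaves, it is never traced and survives the round trip through jit unchanged.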

However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?

All 6 comments

I revised the issue title because we'd handle the issue in api.py and xla.abstractify would never need to see these types (just like it never sees tuples/lists/dicts).

Ha, that's awesome! Regarding namedtuple support: given that namedtuples are real subclasses of tuple, I think supporting all namedtuples out of the box would be the most intuitive solution.

+1 to having JAX work with all namedtuple classes

+1 Our existing codebase relies heavily on namedtuples and it would be great to support them in JAX.

#736 made namedtuple classes transparent by default. Let us know if you have any issues with it!
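For the loop-state use case from the original request, a minimal sketch might look like the following. It assumes a JAX version that includes this change, so the namedtuple is not registered anywhere. LoopState, body, and run are hypothetical names used only for illustration.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical loop state; note that it is never registered as a pytree.
LoopState = namedtuple("LoopState", ["step", "total"])

def body(i, state):
    # Return a carry with the same structure and dtypes as the input carry.
    return LoopState(step=state.step + 1, total=state.total + state.step * 0.5)

@jax.jit
def run(init):
    return lax.fori_loop(0, 5, body, init)

final = run(LoopState(step=jnp.array(0), total=jnp.array(0.0)))
print(int(final.step))     # 5
print(float(final.total))  # 5.0
```

The carry comes back as a LoopState, so fields can still be read by name after the loop.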
