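import collections

import numpy as np
from jax.lax import scan

Point = collections.namedtuple('Point', ['x', 'y'])

class ZeroPoint(Point):
    def is_zero(self):
        # Fields must be accessed through self.
        return (self.x == 0) and (self.y == 0)

def loop(in_, idx):
    # scan expects the body to return a (carry, output) pair.
    return in_, None

a = ZeroPoint(1., 0.)
b = ZeroPoint(1., 1.)
scan(loop, [a, b], np.arange(2))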
The error was TypeError: <class '__main__.ZeroPoint'> is not a valid Jax type
A related question: should a class that inherits from a pytree class be considered a valid pytree type as well?
So far we haven't followed the convention that a subclass of a pytree is a pytree. I think there are clear ways to implement that, but it's a bit more magic that can lead to less predictable behavior (plus potential overheads, though those would need to be profiled before we worry about them). So it's less clear whether we should add that complexity.
When there's an ambiguous decision like whether pytrees should follow subclassing, we usually take the conservative route and keep things the way they are (more explicit, less magic) until a concrete use case comes along, or we get enough evidence that we're really violating user expectations.
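To make the trade-off concrete, here is a minimal pure-Python sketch of a type registry with two lookup strategies. The names `_registry`, `register`, `lookup_exact`, and `lookup_mro` are hypothetical and are not JAX's actual implementation; they only illustrate the difference between exact-type dispatch and dispatch that follows subclassing.

```python
# Toy registry, NOT JAX internals: every name here is hypothetical.
_registry = {}

def register(cls, handler):
    _registry[cls] = handler

def lookup_exact(obj):
    # Explicit behavior: only the registered class itself matches.
    return _registry.get(type(obj))

def lookup_mro(obj):
    # Subclass-following behavior: the nearest registered ancestor matches.
    for klass in type(obj).__mro__:
        if klass in _registry:
            return _registry[klass]
    return None

class Base:
    pass

class Child(Base):
    pass

register(Base, "base-handler")

print(lookup_exact(Base()))   # base-handler
print(lookup_exact(Child()))  # None
print(lookup_mro(Child()))    # base-handler
```

With exact lookup, `Child` has to be registered separately, which is the "more explicit, less magic" behavior. The subclass-following lookup adds a loop to every dispatch, which is the kind of overhead that would need profiling.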
#506 explored one way to implement a pytree mechanism that follows subclassing, though I think it added significant complexity. If we decide to go that route, there might be simpler alternative mechanisms (like using isinstance instead of manually traversing the MRO).
We might want to separate out the narrower case of namedtuple subclasses, though. Namedtuples were requested several times to act like pytrees by default (without having to explicitly register them), and users expected them to work like other Python builtins. But our special handling for namedtuples only works for direct subclasses of tuple (like namedtuple classes are, but not their children). I think we could make namedtuple subclasses work just by writing isinstance there.
That would fix the original issue, but not address the broader question in the second comment. Is it worth handling all subclassing? Opinions welcome!
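As a rough illustration of the narrower namedtuple case, the sketch below contrasts a check that only accepts direct subclasses of tuple with an isinstance-based one. Both helper functions (`is_namedtuple_strict`, `is_namedtuple_relaxed`) are hypothetical names written for this example and are assumptions about the shape of such a check, not the code JAX actually uses.

```python
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

class ZeroPoint(Point):
    def is_zero(self):
        return (self.x == 0) and (self.y == 0)

def is_namedtuple_strict(obj):
    # Hypothetical check: matches only classes whose sole direct base is tuple.
    return type(obj).__bases__ == (tuple,) and hasattr(obj, '_fields')

def is_namedtuple_relaxed(obj):
    # Hypothetical isinstance-based check: also matches namedtuple subclasses.
    return isinstance(obj, tuple) and hasattr(obj, '_fields')

p = Point(1., 0.)
z = ZeroPoint(1., 0.)

print(is_namedtuple_strict(p), is_namedtuple_strict(z))    # True False
print(is_namedtuple_relaxed(p), is_namedtuple_relaxed(z))  # True True
```

The relaxed check still rejects plain tuples because they have no `_fields` attribute, so it widens the match only to children of namedtuple classes.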
By the way, for posterity (since I know @zhongwen already knows this), as a workaround in the meantime you can always register your class as a pytree using jax.tree_util.register_pytree_node, and for this special case of a namedtuple you can use this utility by @sschoenholz:
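from jax.tree_util import register_pytree_node

def register_pytree_namedtuple(cls):
    register_pytree_node(
        cls,
        lambda xs: (tuple(xs), None),  # flatten: children plus (empty) aux data
        lambda _, xs: cls(*xs))        # unflatten: rebuild the instance from its children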
In particular, you'd just have to import that function and write after your class definition:
register_pytree_namedtuple(ZeroPoint)
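Putting the workaround together with the original example gives something like the sketch below. It assumes a JAX version in which `jax.tree_util.register_pytree_node`, `tree_flatten`, `tree_unflatten`, and `jax.lax.scan` are available under these names, and it assumes a `loop` body that returns a `(carry, output)` pair, as `scan` requires.

```python
import collections

import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

Point = collections.namedtuple('Point', ['x', 'y'])

class ZeroPoint(Point):
    def is_zero(self):
        return (self.x == 0) and (self.y == 0)

def register_pytree_namedtuple(cls):
    register_pytree_node(
        cls,
        lambda xs: (tuple(xs), None),  # flatten: children, no aux data
        lambda _, xs: cls(*xs))        # unflatten: rebuild from children

# Explicit registration, written once after the class definition.
register_pytree_namedtuple(ZeroPoint)

# Round trip through the pytree machinery.
leaves, treedef = tree_flatten(ZeroPoint(1., 0.))
rebuilt = tree_unflatten(treedef, leaves)

def loop(carry, idx):
    # Pass the carry through unchanged; no per-step output.
    return carry, None

carry, _ = lax.scan(loop, [ZeroPoint(1., 0.), ZeroPoint(1., 1.)], jnp.arange(2))
```

After registration the carry comes back as a list of `ZeroPoint` instances whose fields are JAX arrays, so methods defined on the subclass remain available on the result.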
As a tradeoff that requires more explicitness from the user but keeps the core system simpler and faster, that seems pretty good!
I forgot to say, thanks for opening this with such a clear and concise explanation!