One of the problems with making random number generators into objects (https://github.com/google/jax/issues/2294) is that objects are not closed under JAX's transformation rules like batching. I'm creating this issue as a discussion point.
It seems like what's necessary is an abstract base class for objects that need to support transformation rules. I don't know enough to propose something too concrete, but maybe:
from abc import ABC, abstractmethod

class JaxArrayLike(ABC):
    @abstractmethod
    def as_jax_array(self):
        """
        Returns T, where T is either a JAX array or a sequence of T.
        """
        raise NotImplementedError
Currently, a simple neural network object is a mess of static members:
from jax import grad, jit, random, vmap
import jax.numpy as jnp
from jax.nn import relu
from jax.scipy.special import logsumexp

step_size = 0.01  # learning rate (example value)


class NeuralNetwork:
    def __init__(self, sizes, key):
        keys = random.split(key, len(sizes))
        self.sizes = sizes
        self.weights = [self.random_layer_params(m, n, k)
                        for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    # Initialization ----------------------------------------------------------
    @staticmethod
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return (scale * random.normal(w_key, (n, m)),
                scale * random.normal(b_key, (n,)))

    # Learning ----------------------------------------------------------------
    @staticmethod
    @jit
    def _update(weights, images, targets):
        grads = grad(NeuralNetwork.loss)(weights, images, targets)
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(weights, grads)]

    @staticmethod
    def loss(weights, images, targets):
        preds = batched_predict(weights, images)
        return -jnp.sum(preds * targets)

    def update(self, images, targets):
        self.weights = self._update(self.weights, images, targets)

    def accuracy(self, images, targets):
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(
            batched_predict(self.weights, images), axis=1)
        return jnp.mean(predicted_class == target_class)


# Inference -------------------------------------------------------------------
def predict(weights, image):
    # per-example predictions
    activations = image
    for w, b in weights[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = weights[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)


batched_predict = vmap(predict, in_axes=(None, 0))
But by deriving from the abstract class, hopefully most of those static methods could be replaced with regular methods, and as_jax_array could return self.weights. Even better would be to add object-oriented structure to the weights (say, make them a list or graph of objects), which works as long as as_jax_array is also called on the components of its return value.
JAX transformation rules don't let you safely mutate state, but you can absolutely use immutable objects with transformations. You just need to register them as "pytrees": https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html
Why not return a pytree from a special method? That would allow loss and _update and predict to be nonstatic.
(I know this doesn't address state mutation, which is a different problem.)
Why not return a pytree from a special method? That would allow loss and _update and predict to be nonstatic.
You can do this by registering your class as a pytree. Is there any particular advantage to using special methods?
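As a minimal sketch of that suggestion (the `LinearModel` class, its fields, and the learning rate are hypothetical), registering a class with `jax.tree_util.register_pytree_node_class` lets `loss` and `update` be ordinary methods: `jax.grad` can differentiate with respect to `self`, and the gradient comes back as a `LinearModel`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class LinearModel:
    """Hypothetical model whose loss is an ordinary (non-static) method."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def predict(self, x):
        return x @ self.w + self.b

    def loss(self, x, y):
        return jnp.mean((self.predict(x) - y) ** 2)

    def update(self, x, y, step_size=0.1):
        # Differentiate with respect to `self`: grads is a LinearModel.
        grads = jax.grad(LinearModel.loss)(self, x, y)
        # Build a new object instead of mutating this one.
        return jax.tree_util.tree_map(lambda p, g: p - step_size * g, self, grads)

    # The two methods register_pytree_node_class requires:
    def tree_flatten(self):
        return (self.w, self.b), None  # (array children, static aux data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


x = jnp.array([[1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0
model = LinearModel(jnp.zeros((1,)), jnp.array(0.0))
train_step = jax.jit(LinearModel.update)  # methods can be jitted directly
for _ in range(500):
    model = train_step(model, x, y)
print(model.w, model.b)
```

Because `update` returns a new object rather than assigning to `self`, it stays compatible with `jit`; state mutation is the separate problem mentioned above.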
The main one I can think of is that it allows the method to be overridden. For example, a child class may want to do something like:
class SomeJaxySubclass(SomeJaxySuperclass):
    def as_jax_array(self):
        return [super().as_jax_array(), self.more_weights]
This facilitates cooperative inheritance, since you might not know which class your parent actually is or what it puts into the JAX array. You only need to know what your subclass adds.
The registry system also makes polymorphic code more verbose since every class in the polymorphic tree has to be registered. I'm not sure what registering entails, but if there's some preparation work, it could probably be done using __init_subclass__ on an abstract base class.
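A minimal sketch of how that could look, assuming `__init_subclass__` does the registration. The method name `load_jax_array` and the classes `Base`/`Sub` are hypothetical (the proposal above only names `as_jax_array`); unflattening needs a cooperative counterpart so that each class restores only the state it added.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class JaxArrayLike(ABC):
    """Hypothetical base class: every subclass is registered as a pytree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        def flatten(obj):
            return obj.as_jax_array(), None  # (children, static aux data)

        def unflatten(aux_data, children):
            obj = cls.__new__(cls)  # state comes from children, not __init__
            obj.load_jax_array(list(children))
            return obj

        register_pytree_node(cls, flatten, unflatten)

    @abstractmethod
    def as_jax_array(self):
        """Return a list of arrays (or nested lists) holding this object's state."""

    @abstractmethod
    def load_jax_array(self, values):
        """Inverse of as_jax_array: restore state from `values`."""


class Base(JaxArrayLike):
    def __init__(self, weights):
        self.weights = weights

    def as_jax_array(self):
        return [self.weights]

    def load_jax_array(self, values):
        (self.weights,) = values


class Sub(Base):
    def __init__(self, weights, more_weights):
        super().__init__(weights)
        self.more_weights = more_weights

    def as_jax_array(self):
        # Only describe what this class adds; defer the rest to the parent.
        return [super().as_jax_array(), self.more_weights]

    def load_jax_array(self, values):
        base_values, self.more_weights = values
        super().load_jax_array(base_values)


sub = Sub(jnp.ones(2), jnp.zeros(3))
doubled = jax.tree_util.tree_map(lambda a: 2 * a + 1, sub)  # still a Sub
print(isinstance(sub, JaxArrayLike), type(doubled).__name__)
```

The same base class also gives user code the `isinstance(obj, JaxArrayLike)` check mentioned below.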
Finally, it's nice to be able to do something like isinstance(some_object, JaxArrayLike) in user code, and it might allow JAX to give better error messages.
I suspect it might work just fine to register a generic base class as a pytree. You should give that a try, and if it works well we can probably put it into JAX.
Oh right, because it will call its polymorphic method. I see, good thinking.
So the only question left is the mutable state. I'm still trying to wrap my head around your example.
Thanks for being patient with me. I think I understand now: It seems like the main problem is that vmap wants to build vectorized versions of user-defined types, but doesn't know how to do that. I think something like this might work:
from abc import ABC, abstractmethod
from typing import Optional, Tuple

from jax.tree_util import register_pytree_node

__all__ = ['JaxValuesLike']


class JaxValuesLike(ABC):
    def __init__(self, *, jax_shape: Optional[Tuple[int, ...]]):
        super().__init__()
        self.jax_shape = jax_shape

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        def special_unflatten(aux_data, children):
            (values,) = children  # undo the 1-tuple wrapping in special_flatten
            jax_shape = aux_data
            return cls.from_jax_values(values, jax_shape=jax_shape)

        register_pytree_node(cls,
                             special_flatten,
                             special_unflatten)

    @classmethod
    @abstractmethod
    def from_jax_values(cls,
                        values,
                        jax_shape: Optional[Tuple[int, ...]] = None):
        """
        Args:
            values: A JAX pytree of values from which the objects are
                constructed.
            jax_shape: If None, the values represent a single object;
                otherwise the values represent an array of objects of the
                given jax_shape.
        Returns:
            An object with the given jax_shape constructed from the given
            values.
        """
        raise NotImplementedError

    @abstractmethod
    def as_jax_values(self):
        """
        Returns:
            values: A JAX pytree of values representing the object.
        """
        raise NotImplementedError


def special_flatten(obj: JaxValuesLike):
    # Wrap the value pytree in a 1-tuple so JAX treats it as a single child
    # and flattens it recursively; jax_shape travels as static aux data.
    children = (obj.as_jax_values(),)
    return children, obj.jax_shape
Now, in your RNG example, vmap applied to a Generator (whose self.jax_shape is None) would call from_jax_values to create a vectorized Generator with the correct self.jax_shape. The uniform method on Generator would use self.jax_shape, and everything should work?
Thanks for the design discussion!
We're not planning to make JAX extensible by OOP subclassing/inheritance, which IMO isn't a good means of composition. (I used a lot of OOP patterns in grad school.)
Our goal with JAX is to maximize composability, and the best way we know how to do that is to provide a minimal functional substrate. As you've demonstrated, users can create wrappers if they like, for OOP abstractions or even for side-effects (as Haiku and Flax both do). That separation of concerns is working well so far.
We're kind of drowning in GitHub issues right now. Since AIUI the main suggestion here is to add subclassing-based extensibility to JAX itself (rather than doing it in user code), and since we don't plan to do that, I'm inclined to close this issue. WDYT?
@mattjj Thanks for the discussion, Matt. I wrote up this issue pretty early on in learning JAX. Since then I've written much more JAX code, but I'm still not sure what the best way to organize my code is. Since there's not a lot of documentation, I've been going through all of the JAX issues to gain insight. (You've written some very informative comments in them!)
One of the reasons that object-oriented code seems to be useful is that it's a natural way of attaching behavior to data. For example, passing array-typed random seeds around and forcing the programmer to remember the type and to always use the appropriate functions is annoying. And then what do you do if you decide to use a different random number generator? Are you going to pass a flag to every function that accepts a seed (or calls a function that does)? This flag is essentially the polymorphic type.
Let's consider a more complicated example like an object with many arrays in it, and many methods that act on it. vmap doesn't know how to broadcast the methods. I think what you end up having to do is vmap static methods and have the object call the vmapped methods? One of my big questions for you is: is vmap ever called non-explicitly? That is, if I call vmap on a function f, could that vmap make other vmap calls on functions called by f? If so, there may be significant problems with mixing in OO code because vmap won't be able to figure out how to apply to objects and their methods.
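One way to see what happens, sketched with a hypothetical `Generator` class registered as a pytree: if the object itself is passed to `vmap`, its leaves are batched and its methods run on per-example values. Functions called from inside the vmapped function (here the hypothetical `shifted_sample`) are traced with those batched values too, so they are vectorized without any vmap calls of their own.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Generator:
    """Hypothetical wrapper that attaches sampling behavior to a PRNG key."""

    def __init__(self, key):
        self.key = key

    def uniform(self, shape=()):
        return random.uniform(self.key, shape)

    def tree_flatten(self):
        return (self.key,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def shifted_sample(gen, offset):
    # Ordinary helper: under vmap it sees per-example values automatically.
    return gen.uniform() + offset


def f(gen, offset):
    return shifted_sample(gen, offset) * 2.0


keys = random.split(random.PRNGKey(0), 3)  # leading axis of size 3
gens = Generator(keys)                      # one object whose leaf is batched
offsets = jnp.arange(3.0)
out = jax.vmap(f)(gens, offsets)            # inside f, gen.key is a single key
print(out.shape)  # (3,)
```

In this style no per-object `jax_shape` bookkeeping is needed: the batch axis lives on the leaves, and `vmap` strips it before the methods run.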
I understand that there are bigger priorities and you might want to defer this kind of problem so feel free to close this issue. If it's a question of work, I may be able to help out a bit. I'm a Xoogler, and I've contributed to a variety of open-source projects (including Python itself), although I really need to finish my PhD. JAX is really helping with this, so a huge thank you for your wonderful project!!
I think a baseclass in jax.tree_util that allows for defining a few special methods for flattening/unflattening and is automatically registered as a pytree would be perfectly safe: https://github.com/google/jax/issues/2396. Essentially inheritance is just defining an interface here, which users can implement in whatever way they like.
Overall:
JAX is really helping with this, so a huge thank you for your wonderful project!!
Thanks so much for the kind words! That's our dream.
(You've written some very informative comments in them!)
Thanks for noticing! It can be hard work.
forcing the programmer to remember the type and to always use the appropriate functions is annoying
I'm not sure what the alternative being proposed is: don't you still have to remember which names are bound to arrays versus PRNG keys? Perhaps this is hard to discuss in the abstract, and we need concrete examples.
And then what do you do if you decide to use a different random number generator?
My thinking is we can cross that bridge when/if we get there.
vmap doesn't know how to broadcast the methods. I think what you end up having to do is vmap static methods and have the object call the vmapped methods?
Hmm, it's possible I don't know what you mean, but this seems to work:
import jax.numpy as np
from jax import vmap


class A:
    def __init__(self, val):
        self.val = val

    def foo(self, x):
        assert np.ndim(x) == 0
        return np.sin(x) + self.val


a = A(3)
out = vmap(a.foo)(np.arange(3))
print(out)  # [3. 3.841471 3.9092975]
What other kind of vmapping-of-methods do you want?
One of my big questions for you is: is vmap ever called non-explicitly? That is, if I call vmap on a function f, could that vmap make other vmap calls on functions called by f?
Sorry, I'm not sure I understand. Maybe a concrete code example would help (ideally one with a real use case in mind!).
I understand that there are bigger priorities and you might want to defer this kind of problem so feel free to close this issue. If it's a question of work, I may be able to help out a bit.
I think it's best to close until we have some concrete real-use-case issues to pore over together. The discussion is a bit abstract, and it sounds like we have some pressing issues on both sides (ours on GitHub, yours in a PhD!).
Maybe we can leave it open until we pin down whether the above is the kind of vmapping-of-methods you had in mind.
@mattjj:
Sorry, I'm not sure I understand. Maybe a concrete code example would help (ideally one with a real use case in mind!).
Maybe I can give you (and anyone who wants it) access to what I've written? Would that be too much to look over (about 500 lines of JAX)? I could open-source the whole thing, but that makes me a bit uncomfortable.
What other kind of vmapping-of-methods do you want?
Well, what if you want vmap to also vectorize some (but not all) of the members like self.val? This is what I'm running into.
@shoyer:
I think a baseclass in jax.tree_util that allows for defining a few special methods for flattening/unflattening and is automatically registered as a pytree would be perfectly safe…
Yeah, I was really pleasantly surprised when a registered class passed as the carry to scan was automatically disassembled and reassembled by these tree routines, so that the scan function could just work on objects of that class transparently. It might be nice to add an example of this to the scan documentation or to the examples folder (I could write one if you want).
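A small example of that pattern, assuming a hypothetical `RunningStats` class: the carry of `lax.scan` is an object with methods, and `scan` flattens and rebuilds it at each step.

```python
import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class RunningStats:
    """Hypothetical scan carry: a count and a running sum."""

    def __init__(self, count, total):
        self.count = count
        self.total = total

    def update(self, x):
        # Functional update: return a new object instead of mutating self.
        return RunningStats(self.count + 1, self.total + x)

    def mean(self):
        return self.total / self.count

    def tree_flatten(self):
        return (self.count, self.total), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def step(stats, x):
    new_stats = stats.update(x)
    return new_stats, new_stats.mean()  # (new carry, per-step output)


init = RunningStats(jnp.array(0.0), jnp.array(0.0))
final, running_means = lax.scan(step, init, jnp.arange(1.0, 5.0))
print(final.mean())    # 2.5
print(running_means)   # [1.  1.5 2.  2.5]
```

The carry's leaves must keep the same shapes and dtypes on every step, which is why `count` starts as a float array here.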
Figuring out what it means to "vmap an object" might be important if we want to support writing classes with methods. But the fact that it's confusing is probably an indication that sticking to pure functional programming could be a good idea...
You may be right, but would I be able to interest you in looking at the JAX code I've written to evaluate how you would write it? I am confused about the best way to organize code. For example, in a model whose inference is done by a call to scan, is there any harm in having the parameters in the carry? Does the compilation of the scan recognize that these parameters don't change, or should I move the parameters to a different object that is passed into the scan function?
When JAX works, it feels magical, but when it doesn't, I realize I really don't understand what it's doing underneath.
When JAX works, it feels magical, but when it doesn't, I realize I really don't understand what it's doing underneath.
I know you want to write object oriented code, but hopefully everything "just works" if you keep things fully functional?
I know you want to write object oriented code, but hopefully everything "just works" if you keep things fully functional?
Would you mind taking a look at my code? It's not that easy to write it in a fully functional style.
While we like to work with users on code in general, what I think we're talking about in this issue is changes/additions to JAX's function transformations API to be OOP-y. That's just not on our roadmap. (You can always build your own OOPy interface on top of JAX however you like!)
I wish I had the bandwidth to dig into 500 lines of code right now, but unfortunately I don't. In general, the time it takes to get a response to an issue is inversely proportional to the length of the issue description (perhaps raised to some greater-than-1 power). If you can distill specific, concrete challenges you have, please open issues for those.
Well, what if you want vmap to also vectorize some (but not all) of the members like self.val? This is what I'm running into.
That's a pretty concrete question!
You can play a lot of Python games to automatically get at the real functions underneath your objects. Here's one example:
import jax.numpy as np
from jax import vmap


class A:
    def __init__(self, y, z):
        self.y = y  # let's map over this one
        self.z = z  # but not this one

    def foo(self, x):
        return np.sin(x) + np.cos(self.y) + self.z


def my_vmap(method):
    self = method.__self__
    cls = self.__class__

    def function(x, y, z):
        return cls(y, z).foo(x)

    return lambda x: vmap(function, (0, 0, None))(x, self.y, self.z)


out = my_vmap(A(np.arange(3.), 4).foo)(np.arange(3.))
print(out)  # [5. 5.381773 4.4931507]
Of course, you can abstract that to be more general. And if you don't want to mess with Python's runtime object representation, you can just define your own methods/abstractions to do the same thing.
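For instance, a more general helper (the name `vmap_method` is hypothetical) could take the names of the attributes to map and close over the rest. This is only a sketch: it rebuilds the object with `cls.__new__` and copies `__dict__`, so it assumes a plain class without `__slots__` and without initialization logic that must run.

```python
import jax.numpy as jnp
from jax import vmap


class A:
    def __init__(self, y, z):
        self.y = y  # mapped
        self.z = z  # not mapped

    def foo(self, x):
        return jnp.sin(x) + jnp.cos(self.y) + self.z


def vmap_method(method, mapped_attrs):
    """Hypothetical helper: vmap a bound method over its positional arguments
    and over the named attributes of its object (all along axis 0); the
    remaining attributes are closed over unchanged."""
    obj = method.__self__
    cls = type(obj)
    attrs = dict(vars(obj))
    mapped = {name: attrs[name] for name in mapped_attrs}

    def function(mapped_vals, *args):
        # Rebuild a per-example object without calling __init__.
        new = cls.__new__(cls)
        new.__dict__.update(attrs)
        new.__dict__.update(mapped_vals)
        return method.__func__(new, *args)

    return lambda *args: vmap(function)(mapped, *args)


out = vmap_method(A(jnp.arange(3.0), 4).foo, ["y"])(jnp.arange(3.0))
print(out)
```

It should reproduce the values printed by `my_vmap` above, without hard-coding the method name or the attribute layout.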
Let's close this issue, since its current framing (as indicated by the issue title) is pretty big and not something we're going to pursue, but please follow up with small concrete real-use-case issues :)
For example, in a model whose inference is done by a call to scan, is there any harm in having the parameters in the carry? Does the compilation of the scan recognize that these parameters don't change, or should I move the parameters to a different object that is passed into the scan function?
It can't hurt to make sure constant parameters like this are part of the scan's non-carried arguments (or are closed over by its body function) rather than being part of its carry. We don't do any optimization in JAX to detect constant-over-time values in carry position, so we'd be relying on XLA to perform loop-invariant code motion (and XLA's optimizations on loops are generally less mature than for straight-line code).
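A minimal sketch of the two options, using a hypothetical one-parameter recurrence: `run_carry` threads `params` through the carry, while `run_closure` closes over it, so it enters `scan` as a loop-invariant input. Both compute the same values; the difference is only in what XLA is left to optimize.

```python
import jax.numpy as jnp
from jax import lax


def run_carry(params, h0, xs):
    # Parameters threaded through the carry: constant, but XLA must notice.
    def step(carry, x):
        w, h = carry
        h = jnp.tanh(w * h + x)
        return (w, h), h

    _, hs = lax.scan(step, (params, h0), xs)
    return hs


def run_closure(params, h0, xs):
    # Parameters closed over by the body: loop-invariant by construction.
    def step(h, x):
        h = jnp.tanh(params * h + x)
        return h, h

    _, hs = lax.scan(step, h0, xs)
    return hs


params = jnp.array(0.5)
h0 = jnp.array(0.0)
xs = jnp.linspace(-1.0, 1.0, 5)
print(jnp.allclose(run_carry(params, h0, xs), run_closure(params, h0, xs)))  # True
```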