Hi all,
I'm trying to use classes with jax and one of the problems I have is that I can't print a value that is manipulated in a JIT compiled class method. Example code:
import jax.numpy as np
from jax import jit
from functools import partial

class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @partial(jit, static_argnums=(0,))
    def step(self, dt):
        a = -9.8
        self.v += a * dt
        self.p += self.v * dt

world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world.step(0.01)
print(world.p)
This prints Traced<ShapedArray(float32[2]):JaxprTrace(level=-1/1)>
I'm aware this is expected behavior when you print something inside the function, but this is not inside the function, right?
More generally, I'm wondering if object-oriented programming is well suited for jax? Should I avoid this kind of stuff? Is JIT capable of working optimally this way?
Thanks for your time!
Thanks for the question!
The issue is that the step method violates functional purity: it has a side-effect (of updating self.v and self.p). You can only use jit on pure functions. (Unfortunately, we don't have good ways of checking whether a function is pure and warning you; it's just an un-checked user promise.)
See also "What's supported" in the README.
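To make the purity point concrete, here is a minimal sketch (not from the original thread) of what happens to a Python side effect under jit. The names `trace_log` and `scale` are hypothetical, chosen only for illustration. The side effect runs while JAX traces the function, and is skipped when the cached compiled code is reused:

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical list, used only to observe the side effect

@jit
def scale(x):
    # Python side effect: this runs while JAX traces the function,
    # not each time the compiled computation executes.
    trace_log.append("traced")
    return x * 2.0

x = jnp.ones(3)
scale(x)
scale(x)  # same shape and dtype, so the cached compilation is reused
print(len(trace_log))  # 1: the append happened only during tracing
```

The same mechanism explains the tracer in the question: `self.v += ...` ran during tracing, so the attribute was overwritten with the tracer object rather than with a concrete array.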
> More generally, I'm wondering if object-oriented programming is well suited for jax? Should I avoid this kind of stuff?
Objects are okay, but jit functions can't have side-effects, and that means the pattern in your example (which is pretty canonical Python OOP) won't work with jit.
(@dougalm your leak-detector could have caught this issue. I wonder if we should look into reviving it.)
WDYT?
Here are a couple styles that do work well with JAX:
import jax.numpy as np
from jax import jit
from collections import namedtuple

World = namedtuple("World", ["p", "v"])

@jit
def step(world, dt):
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)

world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = step(world, 0.01)
print(world.p)
That's just a functional version of your code. The key is that step returns a new World, rather than modifying the existing one.
We can organize the same thing into Python classes if we want:
from jax.tree_util import register_pytree_node
from functools import partial

class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @jit
    def step(self, dt):
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        return World(new_p, new_v)

# By registering 'World' as a pytree, it turns into a transparent container and
# can be used as an argument to any JAX-transformed functions.
register_pytree_node(World,
                     lambda x: ((x.p, x.v), None),
                     lambda _, tup: World(tup[0], tup[1]))

world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = world.step(0.01)
print(world.p)
The key difference there is that step returns a new World instance.
Here's one last pattern that works, using your original World class, though it's a bit more subtle:
class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    def step(self, dt):
        a = -9.8
        self.v += a * dt
        self.p += self.v * dt

@jit
def run(init_p, init_v):
    world = World(init_p, init_v)
    for i in range(1000):
        world.step(0.01)
    return world.p, world.v

out = run(np.array([0, 0]), np.array([1, 1]))
print(out)
(That last one takes much longer to compile, because we're unrolling 1000 steps into a single XLA computation and compiling that; in practice we'd use something like lax.fori_loop or lax.scan to avoid those long compile times.)
The reason your original class works in that last example is that we're only using it under a jit, so the jit function itself doesn't have any side-effects.
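For the loop primitives mentioned above, here is a rough sketch of the 1000-step run written with `lax.fori_loop`, reusing the namedtuple `World` from the first style. It is one possible arrangement, not the only one. Note one assumption that differs from the earlier snippets: the inputs are float arrays, because the value carried through the loop must keep the same shape and dtype on every iteration.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import jit, lax

World = namedtuple("World", ["p", "v"])

def step(world, dt):
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)

@jit
def run(init_p, init_v):
    world = World(init_p, init_v)
    # fori_loop traces the body once instead of unrolling 1000 copies of it.
    # The carried World must keep the same structure, shapes, and dtypes.
    world = lax.fori_loop(0, 1000, lambda i, w: step(w, 0.01), world)
    return world.p, world.v

# Float inputs, so the loop carry has a stable dtype from the first iteration.
p, v = run(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
print(p, v)
```

Because the loop body is traced only once, compile time no longer grows with the number of steps.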
Of those styles, I personally have grown to like the first. I wrote all my code in grad school in an OOP-heavy style, and I regret it: it was hard to compose with other code, even other code that I wrote, and that really limited its reach. Functional code, by forcing explicit state management, solves the composition problem. It's also a great fit for numerical computing in general, since numerical computing is much closer to math than, say, writing a web server.
Hope that's helpful :)
Hi mattjj
Thanks a lot for your elaborate answers! They're very enlightening.
It's interesting to hear your experiences from grad school, I'll try to learn from them.
I've been thinking about why exactly I want to use OOP, and the pros and cons in relation to JAX.
The main reason to use OOP in my case is because I'm building a physics engine, and OOP provides a neat way of defining objects in the scene with properties and states.
Eventually you get a global state vector with the states of all the objects in it, so I think I'll still try to use OOP to build up the scene and the objects, and define what's what in the global state vector.
But I could keep the vector itself out of the OOP.
That way, there's nothing in self that changes throughout the simulation. And jit would work, right?
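A minimal sketch of that arrangement, under some assumptions: the `Scene` class, its `gravity` and `num_bodies` fields, and the `(positions, velocities)` state layout are all hypothetical names made up for illustration. The object only describes the scene; the state is passed in and returned explicitly, so the jitted function stays pure.

```python
import jax.numpy as jnp
from jax import jit

class Scene:
    """Hypothetical scene description; treated as fixed after construction."""

    def __init__(self, gravity, num_bodies):
        self.gravity = gravity
        self.num_bodies = num_bodies

    def initial_state(self):
        # Assumed state layout for this sketch: (positions, velocities).
        return jnp.zeros(self.num_bodies), jnp.ones(self.num_bodies)

    def make_step(self):
        gravity = self.gravity  # captured as a constant when step is traced

        @jit
        def step(state, dt):
            p, v = state
            new_v = v + gravity * dt
            new_p = p + new_v * dt
            return new_p, new_v  # new state is returned, nothing is mutated

        return step

scene = Scene(gravity=-9.8, num_bodies=2)
step = scene.make_step()
state = scene.initial_state()
for _ in range(1000):
    state = step(state, 0.01)
print(state[0])
```

One caveat with this design: values read from `self` are baked in when `step` is built, so changing `scene.gravity` afterwards would not affect an already created `step`.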
Just wanted to let you know that this answer allowed me to convert from an OOP to a functional mindset for the first time! Was v helpful and illustrative, and I can now @jax.jit everything! ;)
I believe we can close this issue, will keep the useful_read label.
Otherwise, why not add support for attrs classes? They're basically namedtuples on steroids.
Looks to me like it would need a small contrib to the pytree file, mostly at lines:
Then you'd be able to do something like:
@attr.s
class World:
    p: np.ndarray = attr.ib()
    v: np.ndarray = attr.ib()

    @jax.jit
    def step(self, dt):
        a = -9.8
        v = self.v + a * dt  # update velocity from its previous value
        p = self.p + v * dt  # then advance position with the new velocity
        return attr.evolve(self, p=p, v=v)

Because World instances would then be recognized as acceptable arguments and outputs of jitted functions.
EDIT:
Might even help to do AOT compilation à la numba.
EDIT 2:
Additionally you could verify that the methods are pure by enforcing that the class be declared as being frozen.
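Something close to this is already possible without changes to JAX, by registering the class as a pytree by hand. The sketch below swaps in a standard-library frozen dataclass as a stand-in for an attrs class (an assumption made to keep the example dependency-free; it is not the attrs integration proposed above), and `dataclasses.replace` plays the role of `attr.evolve`.

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

@dataclass(frozen=True)  # frozen: assigning to self.p or self.v raises an error
class World:
    p: jnp.ndarray
    v: jnp.ndarray

    @jax.jit
    def step(self, dt):
        a = -9.8
        v = self.v + a * dt
        p = self.p + v * dt
        return replace(self, p=p, v=v)  # build a new instance, no mutation

# Manual pytree registration: the two arrays are the children, no aux data.
register_pytree_node(
    World,
    lambda w: ((w.p, w.v), None),
    lambda _, children: World(*children),
)

world = World(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
for _ in range(1000):
    world = world.step(0.01)
print(world.p)
```

Freezing the class does not prove the methods are pure, but it does turn the in-place update pattern from the original question into an immediate error instead of a leaked tracer.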