The tree_flatten function currently returns a list of leaves (typically arrays) together with the tree structure, and this is very useful in many cases.
Sometimes it would also be useful to flatten that list further into a single contiguous 1D array, with the option to unpack it into a list of arrays again. This can be done with a trivial wrapper, but an optional argument to tree_flatten that does this would be convenient.
Here's a class that builds on tree_util to do this, assuming all leaves of the pytree (i.e. the data stored in the tree) have the attributes .shape and .size.
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax import jit
from jax.tree_util import tree_flatten


class NN_Tree(object):

    def __init__(self, pytree):
        # assumes flattened elements are all arrays
        flattened, self.tree = tree_flatten(pytree)
        self.shapes = []
        self.sizes = np.zeros(len(flattened), dtype=int)
        for j, x in enumerate(flattened):
            self.shapes.append(x.shape)
            self.sizes[j] = x.size

    # self is marked static: jit traces every non-static argument,
    # and an NN_Tree instance is not a valid JAX type
    @partial(jit, static_argnums=0)
    def flatten(self, data):
        # flatten maximally: ravel each leaf and concatenate into one 1D array
        return jnp.concatenate([x.flatten() for x in self.tree.flatten_up_to(data)])

    @partial(jit, static_argnums=0)
    def unflatten(self, data):
        # restore shapes by slicing the 1D array leaf by leaf
        cum_size = 0
        data_flattened = []
        for j, size in enumerate(self.sizes):
            size = int(size)
            data_flattened.append(data[cum_size:cum_size + size].reshape(self.shapes[j]))
            cum_size += size
        # unflatten
        return self.tree.unflatten(data_flattened)
Thanks for the idea!
I think this functionality might already be in ravel_pytree in jax.flatten_util. Can you take a look at that? (Notice our strange and beautiful use of vjp to automatically generate the inverse function!)
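For reference, here is a minimal sketch of a round trip with ravel_pytree. The pytree `params` is made up for illustration; the only assumption is that its leaves are arrays.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree, made up for this example.
params = {
    "w": jnp.ones((2, 3)),
    "b": jnp.zeros(3),
}

# flat is a single 1D array holding every leaf;
# unravel maps a vector of that length back to the original structure.
flat, unravel = ravel_pytree(params)

# Round trip: the restored pytree has the same keys and leaf shapes.
restored = unravel(flat)
```

The returned `unravel` callable plays the role of the `unflatten` method in the class above, so no shapes or sizes have to be tracked by hand.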
Yes, this is precisely what jax.flatten_util does :) It's likely even more generic than my class.
It would be good to update the online docs to cover this functionality whenever you get the chance :)
Anyway, thanks for the quick reply, I'm closing this issue now.