Jax: add an option to tree_flatten that flattens maximally

Created on 30 Dec 2019 · 2 comments · Source: google/jax

The tree_flatten function currently returns the pytree's leaves as a list (typically a list of arrays) together with a tree definition, and this is very useful in many cases.
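To make the starting point concrete, here is a small sketch of what tree_flatten returns for a pytree. The `params` dict is a made-up example, not something from the original post; note that dict keys are flattened in sorted order.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure

# Illustrative pytree: a dict holding a tuple of arrays and a scalar array.
params = {"layer1": (jnp.ones((2, 3)), jnp.zeros(3)), "scale": jnp.array(2.0)}

# tree_flatten returns the list of leaves plus a PyTreeDef describing the structure.
leaves, treedef = tree_flatten(params)
print([x.shape for x in leaves])  # [(2, 3), (3,), ()]

# tree_unflatten puts the leaves back into the original structure.
rebuilt = tree_unflatten(treedef, leaves)
```

The leaves are still separate arrays of different shapes, which is the starting point for the request below.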

Sometimes it would also be useful to flatten that list further into a single contiguous 1-D array, and then to have a way to split it back into the original list of arrays. A trivial wrapper can do this, but an optional argument to tree_flatten that does it would be convenient.
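As a rough idea of what such a "trivial wrapper" could look like, here is a minimal function-based sketch built only on tree_flatten and tree_unflatten. The names `flatten_to_vector` and `unflatten_from_vector` are hypothetical, and the sketch assumes the pytree has at least one leaf and that all leaves share a dtype (jnp.concatenate would otherwise promote them).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def flatten_to_vector(pytree):
    """Hypothetical helper: concatenate every leaf into one 1-D array.

    Returns the flat vector and a function that rebuilds the original pytree.
    """
    leaves, treedef = tree_flatten(pytree)
    shapes = [jnp.shape(x) for x in leaves]
    sizes = [int(np.prod(s)) for s in shapes]  # np.prod(()) is 1 for scalars
    # Split points between consecutive leaves, as plain Python ints.
    offsets = np.cumsum(sizes)[:-1].tolist()
    flat = jnp.concatenate([jnp.ravel(x) for x in leaves])

    def unflatten_from_vector(vec):
        # Cut the vector at the offsets and restore each leaf's shape.
        chunks = jnp.split(vec, offsets)
        return tree_unflatten(treedef, [c.reshape(s) for c, s in zip(chunks, shapes)])

    return flat, unflatten_from_vector


params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}
flat, unflatten = flatten_to_vector(params)
# The closure only captures static Python values, so it can be jitted directly.
restored = jax.jit(unflatten)(flat)
```

Returning a closure keeps the shape bookkeeping out of the traced computation, so no `self` or other non-array argument has to be passed through jit.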

Here's a class that builds on tree_util to do this. It assumes that every leaf of the pytree (i.e. the data stored in the tree) has .shape and .size attributes.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit
from jax.tree_util import tree_flatten


class NN_Tree(object):

    def __init__(self, pytree):
        # Assumes every leaf is an array with .shape and .size attributes.
        flattened, self.tree = tree_flatten(pytree)

        self.shapes = [x.shape for x in flattened]
        self.sizes = [int(x.size) for x in flattened]

    # `self` is not a valid JAX type, so it must be marked static for jit;
    # jit then treats the instance as a compile-time constant.
    @partial(jit, static_argnums=0)
    def flatten(self, data):
        # Flatten maximally: concatenate every leaf into one 1-D array.
        return jnp.concatenate([x.ravel() for x in self.tree.flatten_up_to(data)])

    @partial(jit, static_argnums=0)
    def unflatten(self, data):
        # Restore the original shapes by slicing consecutive chunks.
        cum_size = 0
        data_flattened = []
        for shape, size in zip(self.shapes, self.sizes):
            data_flattened.append(data[cum_size:cum_size + size].reshape(shape))
            cum_size += size

        # Rebuild the pytree from the reshaped leaves.
        return self.tree.unflatten(data_flattened)
```
Labels: enhancement, question

Comments (2)

Thanks for the idea!

I think this functionality might already be in ravel_pytree in jax.flatten_util. Can you take a look at that? (Notice our strange and beautiful use of vjp to automatically generate the inverse function!)
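For reference, here is a minimal sketch of using ravel_pytree, followed by a hand-written version of the vjp trick mentioned above. The `params` dict and `my_ravel` are illustrative only; the second half does not rely on ravel_pytree's internals, which may have changed in later JAX versions. The idea is that concatenating raveled leaves is a linear map that only rearranges entries, so its transpose (which is what vjp computes) is also its inverse.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}

# ravel_pytree returns the flat 1-D vector and a function that undoes it.
flat, unravel = ravel_pytree(params)
restored = unravel(flat)


# Illustrative re-creation of the vjp trick with a hand-written ravel.
def my_ravel(tree):
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.concatenate([jnp.ravel(x) for x in leaves])


# vjp returns the output plus a function mapping output cotangents back to
# input cotangents, which have the same pytree structure as `params`.
flat2, my_unravel = jax.vjp(my_ravel, params)
(restored2,) = my_unravel(flat2)
```

Because the transpose of a pure rearrangement of entries is its inverse, `my_unravel` recovers the original pytree without any explicit shape or offset bookkeeping.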

Yes, this is precisely what jax.flatten_util does :) It's likely even more generic than my class.

Whenever you get the chance, you should update the online docs to reflect the new functionality :)

Anyway, thanks for the quick reply, I'm closing this issue now.
