JAX: Flattening function like in Autograd

Created on 3 Jan 2019 · 5 comments · Source: google/jax

I was reimplementing my old Autograd code in JAX, but I can't seem to find an equivalent of the "flatten" utility:
https://github.com/HIPS/autograd/tree/master/autograd/misc

Basically, given a list of the network's parameters w, I would need something like:

w_flat, unflattener = flatten(w)

where w_flat is a one-dimensional vector containing all the parameters and unflattener reverses the operation. If I understand correctly, Autograd originally used this to implement its optimizers, which JAX doesn't need it for. Still, it is useful in many other cases and fits well with the idea of chainable transformations on data.
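To make the requested behavior concrete, here is a hand-rolled sketch built on jax.tree_util. The name `flatten` is hypothetical (it is not a JAX function), and the sketch assumes the tree has at least one leaf and that all leaves are arrays or scalars of one floating-point dtype, since it does not restore per-leaf dtypes.

```python
import jax.numpy as jnp
from jax import tree_util


def flatten(tree):
    """Hypothetical stand-in for Autograd's flatten (not a JAX API).

    Assumes at least one leaf, all of a single floating-point dtype.
    """
    leaves, treedef = tree_util.tree_flatten(tree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    sizes = [int(jnp.size(leaf)) for leaf in leaves]
    # Ravel every leaf and concatenate them into one 1-D vector.
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def unflatten(vec):
        # Slice the vector back into pieces and restore each original shape.
        pieces, start = [], 0
        for shape, size in zip(shapes, sizes):
            pieces.append(jnp.reshape(vec[start:start + size], shape))
            start += size
        # Rebuild the original container structure (lists, tuples, dicts, ...).
        return tree_util.tree_unflatten(treedef, pieces)

    return flat, unflatten


w = [jnp.ones((2, 3)), jnp.zeros(4)]
w_flat, unflattener = flatten(w)
print(w_flat.shape)  # (10,)
```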

enhancement

All 5 comments

I added some simple utilities in #201, so you should be able to write

from jax.flatten_util import ravel_pytree

w_flat, unflattener = ravel_pytree(w)

I changed the name from "flatten" to "ravel_pytree" because I think it's more descriptive and in line with other names in JAX. At one point we had too many uses of "flatten" and "unflatten" meaning slightly different things.

Please reopen this issue if you run into problems with that function, or if I missed the mark somehow.
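A minimal usage sketch of ravel_pytree, assuming a hypothetical two-layer network stored as a list of (W, b) tuples and a made-up sum-of-squares loss. It shows that composing the loss with the unflattener lets jax.grad differentiate with respect to the flat vector directly.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters: a list of (W, b) tuples for two layers.
w = [
    (jnp.ones((3, 4)), jnp.zeros(4)),
    (jnp.ones((4, 2)), jnp.zeros(2)),
]

w_flat, unflattener = ravel_pytree(w)
print(w_flat.shape)  # (26,) = 12 + 4 + 8 + 2

# The unflattener restores the same list-of-tuples structure.
w_restored = unflattener(w_flat)


def loss(params):
    # Illustrative loss: sum of squares of every parameter.
    return sum(jnp.sum(W ** 2) + jnp.sum(b ** 2) for W, b in params)


# Gradient with respect to the flat vector, obtained by chaining transformations.
g_flat = jax.grad(lambda v: loss(unflattener(v)))(w_flat)
```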

Is there a flatten_func analog?

@mattjj - I'm probably missing something, but is a flatten_func analog built into JAX somewhere? Could it be in the flatten_util or tree_util files?

Thanks for the ping. My notification setup isn't great. When in doubt, open a new issue, since those are much more visible. It also makes things like feature requests easier to track.

We haven't added a flatten_func to JAX, but it's probably a good idea to put it in flatten_util.py and have it call ravel_pytree. I think it should look about the same as Autograd's. We could call it ravel_func or flatten_func; I don't have a strong opinion.
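A sketch of what such a helper might look like, built on ravel_pytree and modeled on Autograd's flatten_func. The name `flatten_func` and its return signature are assumptions for illustration; this function is not part of JAX. The `loss` and `params` in the usage lines are also hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def flatten_func(func, example):
    """Hypothetical JAX port of Autograd's flatten_func (not a JAX API).

    Returns (flat_func, unflatten, example_flat), where flat_func takes a
    1-D vector plus any extra arguments and returns a flattened output.
    """
    example_flat, unflatten = ravel_pytree(example)

    def flat_func(x_flat, *args):
        # Rebuild the pytree, call the original function, flatten the result.
        out = func(unflatten(x_flat), *args)
        return ravel_pytree(out)[0]

    return flat_func, unflatten, example_flat


# Usage with a hypothetical loss over a (W, b) pair.
def loss(params, scale):
    W, b = params
    return scale * (jnp.sum(W ** 2) + jnp.sum(b ** 2))


params = (jnp.ones((2, 2)), jnp.ones(3))
flat_loss, unflatten, x0 = flatten_func(loss, params)

# The scalar loss comes back flattened to shape (1,), so index it before jax.grad.
g = jax.grad(lambda v: flat_loss(v, 2.0)[0])(x0)
```

Because the output is flattened too, a scalar loss is returned with shape (1,); jax.grad expects a scalar output, which is why the example indexes `[0]` before differentiating.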

Want to send us a PR with that change?

Great! - will do.
