I was reimplementing my old Autograd code in JAX, but I can't seem to find an equivalent of the "flatten" utility:
https://github.com/HIPS/autograd/tree/master/autograd/misc
Basically, given a list of the network's parameters w, I would need something like:
w_flat, unflattener = flatten(w)
where w_flat is a one-dimensional vector containing all parameters, and unflattener reverts this operation. If I understand correctly, this was used in Autograd at the beginning for the implementation of the optimizers, but it is not needed here. However, it is useful in many cases and also goes well with the idea of chainable transformations on data.
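To make the request concrete, here is a minimal sketch of the behavior being asked for. It is not Autograd's or JAX's implementation: `flatten` is a hypothetical helper that uses `jax.tree_util` to concatenate every leaf into one 1-D vector and returns a closure that undoes the operation. It assumes the parameters form a pytree with at least one array leaf.

```python
import jax.numpy as jnp
from jax import tree_util


def flatten(params):
    """Hypothetical flatten: concatenate all pytree leaves into one 1-D array."""
    leaves, treedef = tree_util.tree_flatten(params)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    sizes = [int(jnp.size(leaf)) for leaf in leaves]
    # Assumes at least one leaf; jnp.concatenate fails on an empty list.
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def unflatten(vec):
        # Slice the flat vector into consecutive pieces and restore each shape.
        pieces = []
        start = 0
        for shape, size in zip(shapes, sizes):
            pieces.append(jnp.reshape(vec[start:start + size], shape))
            start += size
        return tree_util.tree_unflatten(treedef, pieces)

    return flat, unflatten


# Example: a list of two parameter arrays (6 + 4 = 10 entries in total).
w = [jnp.ones((2, 3)), jnp.zeros(4)]
w_flat, unflattener = flatten(w)
restored = unflattener(w_flat)
print(w_flat.shape)  # (10,)
```

Because the closure records the tree structure and leaf shapes at flattening time, `unflattener` only accepts vectors of the same total length.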
I added some simple utilities in #201, so you should be able to write
from jax.flatten_util import ravel_pytree
w_flat, unflattener = ravel_pytree(w)
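A short usage sketch of `ravel_pytree`, assuming `w` is a list of `(weights, bias)` tuples for a hypothetical two-layer network. It shows the round trip and that gradients can be taken with respect to the flat vector, since the returned unravel function is differentiable. The `loss` function is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters: (weights, bias) for two layers.
w = [(jnp.ones((3, 4)), jnp.zeros(4)), (jnp.ones((4, 2)), jnp.zeros(2))]

# 12 + 4 + 8 + 2 = 26 parameters in total.
w_flat, unflattener = ravel_pytree(w)
print(w_flat.shape)  # (26,)

# The unravel function restores the original list-of-tuples structure.
w_restored = unflattener(w_flat)
print(w_restored[1][0].shape)  # (4, 2)


def loss(params):
    # Illustrative loss: sum of squares of every parameter.
    return sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))


# An optimizer that works on flat vectors can differentiate through unflattener.
flat_grad = jax.grad(lambda v: loss(unflattener(v)))(w_flat)
```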
I changed the name from "flatten" to "ravel_pytree" because I think it's more descriptive and in line with other names in JAX. At one point we had too many uses of "flatten" and "unflatten" to mean slightly different things.
Please re-open this issue if you have issues with that function, or if I missed the mark somehow.
Is there a flatten_func analog?
@mattjj, I'm probably just overlooking it, but is a flatten_func analog built into JAX somewhere, perhaps in the flatten_util or tree_util modules?
Thanks for the ping. My notification setup isn't great. When in doubt, open a new issue, since those are much more visible. It also makes things like feature requests easier to track.
We haven't added a flatten_func to JAX yet, but it's probably a good idea to put it in flatten_util.py and have it call ravel_pytree. I think it should look about the same as Autograd's. We could call it ravel_func or flatten_func; I don't have a strong opinion.
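A minimal sketch of what such a function might look like, modeled on Autograd's `flatten_func(func, example)` as we recall it (returning the flat function, the unflattener, and the flattened example). `ravel_func` is a hypothetical name; nothing by that name is assumed to exist in JAX. It assumes `func` returns a pytree of arrays, whose output is raveled as well.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ravel_func(func, example):
    """Hypothetical analog of Autograd's flatten_func built on ravel_pytree.

    Returns (flat_func, unravel, flat_example), where flat_func maps a flat
    vector (plus any extra args) to a flat vector.
    """
    flat_example, unravel = ravel_pytree(example)

    def flat_func(flat_x, *args):
        # Rebuild the pytree, call the original function, flatten its output.
        out = func(unravel(flat_x), *args)
        return ravel_pytree(out)[0]

    return flat_func, unravel, flat_example


# Example: turn a gradient function on a dict of params into one on vectors.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grad_fn = jax.grad(lambda p: jnp.sum(p["w"] ** 2) + jnp.sum(p["b"]))
flat_grad_fn, unravel, flat_params = ravel_func(grad_fn, params)
g = flat_grad_fn(flat_params)
```

Here the gradient has the same structure as `params`, so `unravel(g)` recovers it as a dict; for outputs with a different structure, you would need a separate unravel function for the output.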
Want to send us a PR with that change?
Great! - will do.