In the Deep Implicit Layers tutorial from NeurIPS 2020, a very nice grad_graph function was used to make things clearer, around minute 23.
Is that function available elsewhere? If not, could it be added to the library?
I think it would be generally useful as a way to explore how things work in Jax.
Thanks for this suggestion! It certainly seems worth considering.
Here's the version I used in the demo. It'll probably bit-rot unless we merge it into JAX.
I agree it seems like a useful tool for user exploration as well as for presentations, and it is pretty lightweight. I'm interested to hear other opinions though!
Longer Twitter thread here: https://twitter.com/lukasheinrich_/status/1340042422223572993, where I was asked to add a (not-so-naive, apparently) suggestion: https://twitter.com/SingularMattrix/status/1340306739191672832
It would be nice to have a way to attach labels in JAX (to input arrays as well as to computational calls) so that such a graph can be annotated. I'd imagine an API along the lines of
```python
x = jnp.array(..., label="observations")
y = jnp.sum(x, axis=0, label='calculate sum(logpdf)')
```
or similar.
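For what it's worth, one partial mechanism that may already cover the "computational calls" half is wrapping a sub-computation so it carries a name into the traced representation. Here's a minimal sketch, assuming your JAX version has `jax.named_call` (the `label=` argument on input arrays has no counterpart I know of, and `sum_logpdf` is just an illustrative name):

```python
import jax
import jax.numpy as jnp

def sum_logpdf(x):
    return jnp.sum(x, axis=0)

# jax.named_call tags the wrapped sub-computation with a name, which a
# graph drawer could surface as a label for that subgraph
labeled_sum = jax.named_call(sum_logpdf, name='calculate_sum_logpdf')

x = jnp.ones((4, 3))
y = labeled_sum(x)
```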
Good to see the issue already opened :+1:
Here's also the variation from the Twitter thread above, reusing @mattjj's gist to graph any jaxpr: jaxpr_graph.py
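The core of such a jaxpr walker is quite small. Here's a minimal sketch in the same spirit (not the gist itself, and details of the jaxpr data structures may vary across JAX versions): trace the function to a jaxpr, then emit one graphviz node per equation with edges following the variables.

```python
import jax
from jax import core
import graphviz

def jaxpr_graph(f, *args):
    # trace f to a jaxpr, then draw one node per primitive equation
    jaxpr = jax.make_jaxpr(f)(*args).jaxpr
    g = graphviz.Digraph()
    for v in jaxpr.invars:
        g.node(str(v), str(v), shape='box')  # inputs as boxes
    for eqn in jaxpr.eqns:
        eqn_id = str(id(eqn))
        g.node(eqn_id, eqn.primitive.name)
        for v in eqn.invars:
            if not isinstance(v, core.Literal):  # skip inline constants
                g.edge(str(v), eqn_id)
        for v in eqn.outvars:
            g.node(str(v), str(v), shape='box')
            g.edge(eqn_id, str(v))
    return g
```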
Just found that `jaxlib.xla_extension.XlaComputation.as_hlo_dot_graph` exists and can be used for similar results:
```python
import jax
import graphviz

def hlo_graph(f, *args, **kwargs):
    # trace f to an XLA computation and render its HLO dot graph
    comp = jax.xla_computation(f)(*args, **kwargs)
    graph = graphviz.Source(comp.as_hlo_dot_graph())
    return graph
```
which gives
```python
import jax.numpy as jnp

f = lambda x: jnp.sum(x**2)
x = jnp.ones(5)
hlo_graph(jax.grad(f), x)
```
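In a notebook the returned `graphviz.Source` renders inline; elsewhere it can be written to a file, e.g.:

```python
graph = hlo_graph(jax.grad(f), x)
graph.render('grad_f_hlo', format='png')  # writes grad_f_hlo.png
```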

Indeed! Some brainstorming on the differences:
- grad_graph can handle data-dependent Python control flow, unlike xla_computation (actually, this was the main desideratum for that demo in the implicit layers tutorial, since the forward-iteration implementation had a Python while loop in it!)

Cool! I also see now that my jaxpr_graph example, which uses jax.make_jaxpr, fails on data-dependent Python control flow, while grad_graph works fine because it extracts the jaxpr directly from the vjp. Is there a better (or similarly vjp-based) way to get a traced jaxpr for this purpose?
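To make that difference concrete, here's a minimal sketch (the function `f` and its failure mode are my own illustration, not from the demo):

```python
import jax
import jax.numpy as jnp

def f(x):
    # data-dependent Python control flow: the loop condition depends on
    # the value of x, so it can only be resolved with concrete values
    while jnp.sum(x) > 1.0:
        x = x / 2.0
    return jnp.sum(x ** 2)

x = jnp.ones(5)

# Abstract tracing cannot resolve the `while` condition, so this raises
# a ConcretizationTypeError:
# jax.make_jaxpr(f)(x)

# jax.vjp runs the forward pass with concrete values, so the Python loop
# executes normally and the backward pass is recorded as a jaxpr
# internally (which grad_graph then digs out):
y, vjp_fun = jax.vjp(f, x)
```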
Regarding labelled subcomputations: would supporting this call for a dedicated Tracer?
I certainly appreciate the hackability aspect! Learning a lot about jaxprs here.