Jax: Add graph_grad function

Created on 11 Dec 2020 · 6 comments · Source: google/jax

In the Deep Implicit Layers tutorial from NeurIPS 2020, a very nice grad_graph function was used to make things clearer, around minute 23.

Is that function available elsewhere? If not, could it be added to the library?

I think it would be generally useful as a way to explore how things work in JAX.

enhancement

All 6 comments

Thanks for this suggestion! It certainly seems worth considering.

Here's the version I used in the demo. It'll probably bit-rot unless we merge it into JAX.

I agree it seems like a useful tool for user exploration as well as for presentations, and it is pretty lightweight. I'm interested to hear other opinions though!

There's a longer Twitter thread here: https://twitter.com/lukasheinrich_/status/1340042422223572993, and I was asked to add a (not-so-naive, apparently) suggestion from it: https://twitter.com/SingularMattrix/status/1340306739191672832

It would be nice to have a way to attach labels in JAX (to both input arrays and computational calls) so I can annotate such a graph. I guess an expected API might look like

x = jnp.array(..., label="observations")
y = jnp.sum(x, axis=0, label="calculate sum(logpdf)")

or similar.

Good to see the issue already opened :+1:
Here's also the variation from said Twitter thread, which reuses @mattjj's gist to graph any jaxpr: jaxpr_graph.py
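
Not having the gist inline here, a rough sketch of that approach might look like the following. The function name and styling are mine, and it pokes at jaxpr internals (invars, eqns, primitive) as they existed around this time, so treat it as illustrative rather than the gist's actual code:

import jax
import jax.numpy as jnp
from graphviz import Digraph

def jaxpr_graph(fn, *args):
    # Trace fn to a jaxpr, then emit one graphviz node per input variable
    # and per primitive equation, with edges following the data flow.
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    jaxpr = closed_jaxpr.jaxpr
    g = Digraph()
    for v in jaxpr.invars:
        g.node(str(v), f"input {v}", shape="box")
    for eqn in jaxpr.eqns:
        eqn_id = str(id(eqn))
        g.node(eqn_id, eqn.primitive.name, shape="ellipse")
        for v in eqn.invars:
            if not isinstance(v, jax.core.Literal):  # skip inlined constants
                g.edge(str(v), eqn_id)
        for v in eqn.outvars:
            g.node(str(v), str(v), shape="plaintext")
            g.edge(eqn_id, str(v))
    for v in jaxpr.outvars:  # mark the jaxpr outputs explicitly
        g.node(f"out_{v}", "output", shape="box")
        g.edge(str(v), f"out_{v}")
    return g

# e.g. jaxpr_graph(jax.grad(lambda x: jnp.sum(x ** 2)), jnp.ones(5))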

I just found that jaxlib.xla_extension.XlaComputation.as_hlo_dot_graph exists and can be used to get similar results:

import jax
import graphviz

def hlo_graph(f, *args, **kwargs):
    # Build the XLA computation for f at the given example arguments,
    # then render XLA's own dot output with graphviz.
    comp = jax.xla_computation(f)(*args, **kwargs)
    graph = graphviz.Source(comp.as_hlo_dot_graph())
    return graph

which gives

import jax.numpy as jnp
f = lambda x: jnp.sum(x**2)
x = jnp.ones(5)
hlo_graph(jax.grad(f), x)
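
In a notebook the returned graphviz.Source object renders inline; elsewhere the standard graphviz API can write it to disk, for example (the output filename here is arbitrary):

graph = hlo_graph(jax.grad(f), x)
graph.render("grad_f_hlo", format="png")  # writes grad_f_hlo and grad_f_hlo.png (needs the dot binary)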

Indeed! Some brainstorming on the differences:

  • grad_graph can handle data-dependent Python control flow, unlike xla_computation (this was actually the main desideratum for that demo in the implicit layers tutorial, since the forward iteration implementation had a Python while loop in it!); see the sketch just after this list
  • the jaxpr IR has some differences from the XLA HLO IR, e.g. it shows custom derivative rules and some primitives like scan, so it's a slightly higher-level representation in some respects (while the HLO representation has the advantage of being more explicit about what actually gets executed, especially the optimized HLO)
  • the colors, labels, etc are more hackable in the pure-Python version
  • the XLA HLO tools are much more developed than the quick hack in the gist
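
To make the first point concrete, here is a small sketch (mine, not the tutorial's grad_graph) of why going through the vjp sidesteps the control-flow issue: the Python branch is resolved while linearizing at concrete inputs, and only the resulting linear backward function is staged out to a jaxpr:

import jax
import jax.numpy as jnp

def f(x):
    # data-dependent Python control flow: needs concrete values to trace
    if x[0] > 0:
        return jnp.sum(x ** 2)
    return jnp.sum(jnp.sin(x))

x = jnp.ones(5)

# jax.make_jaxpr(jax.grad(f))(x) fails here: it traces with abstract values,
# so `x[0] > 0` cannot be turned into a Python bool.

# Linearizing at a concrete x resolves the branch first; the backward pass
# is then an ordinary, control-flow-free function we can stage out or graph.
y, f_vjp = jax.vjp(f, x)
print(jax.make_jaxpr(f_vjp)(1.0))  # cotangent for the scalar output y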

Cool! Also, I see now that my jaxpr_graph example, which uses jax.make_jaxpr, fails on data-dependent Python control flow, while grad_graph works fine by extracting the jaxpr directly from the vjp. Is there a better/similar-to-vjp way to get a traced jaxpr for this purpose?

Regarding labelled subcomputations, would this already call for a dedicated Tracer?

I certainly appreciate the hackability aspect! Learning a lot about jaxprs here.
