Jax: factor named_call primitive into jax core

Created on 13 Oct 2020 · 4 Comments · Source: google/jax

Both Flax and Haiku implement a named_call JAX primitive. How about we move this into JAX core?

It behaves like JAX's call primitive (see jax.core.call_p), except in how it lowers to XLA: the name it carries is included in the name of the resulting HLO computation. This makes it useful for annotating profiles.
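As a rough illustration of the idea, here is a minimal sketch that assumes a JAX release exposing a public jax.named_call(fun, name=...) wrapper (such a wrapper exists in later JAX versions and may postdate this thread). The function names mlp_block and model are made up for the example. Wrapping a function this way does not change the values it computes; it only attaches a name to the computation for tooling such as profilers.

```python
import jax
import jax.numpy as jnp

def mlp_block(x):
    # An ordinary computation we would like to see as a named region in profiles.
    return jnp.tanh(x) * 2.0

# The wrapped function computes the same values as mlp_block, but carries the
# name "mlp_block" through tracing and lowering. Exactly where the name shows
# up in the compiled output depends on the JAX version.
named_block = jax.named_call(mlp_block, name="mlp_block")

@jax.jit
def model(x):
    return named_block(x) + 1.0

x = jnp.arange(4.0)
y = model(x)
```

To look for the name in the compiled program, one option is to print model.lower(x).as_text(); whether the name is visible in that text depends on the JAX version.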

cc @LenaMartens @tomhennigan @trevorcai @levskaya @jheek @avital

enhancement

All 4 comments

I'd be a big fan of doing this! I tried suggesting it before, but I think people were too busy to think about it at the time. It would be much better to have such a simple, basic primitive live in a single place in JAX.

Yes, please!

As I understand it, there might be differences between what we'd upstream to JAX and what's currently in Haiku, possibly because of Haiku's state threading. The Flax version is roughly what I had in mind originally.

Haiku's stateful_named_call looks like a user of named_call_p rather than part of it, so it doesn't need to be upstreamed. I've attempted a draft in PR #4733.
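To illustrate why a stateful variant can be built on top of a stateless named call, here is a hypothetical sketch; it is not Haiku's actual stateful_named_call. StateContext, stateful_named_call_sketch and counting_block are invented names, with StateContext standing in for a library's implicit mutable state. It again assumes a JAX release that exposes jax.named_call. The wrapper reads the current state, passes it explicitly through the named call, and writes the updated state back, so no state-aware primitive is needed in JAX itself.

```python
import jax
import jax.numpy as jnp

class StateContext:
    """Hypothetical stand-in for a library's implicit mutable state."""

    def __init__(self, state):
        self.state = state

def stateful_named_call_sketch(fun, ctx, name):
    """Hypothetical helper, not Haiku's implementation.

    fun(state, x) -> (out, new_state) is pure. The wrapper threads the
    implicit state through the stateless jax.named_call explicitly.
    """
    named = jax.named_call(fun, name=name)

    def wrapper(x):
        # Read state, call the pure named function, then store the new state.
        out, new_state = named(ctx.state, x)
        ctx.state = new_state
        return out

    return wrapper

def counting_block(state, x):
    # Pure function: doubles its input and increments a call counter.
    return x * 2.0, {"calls": state["calls"] + 1}

ctx = StateContext({"calls": 0})
block = stateful_named_call_sketch(counting_block, ctx, name="counting_block")
y1 = block(jnp.ones(3))
y2 = block(y1)
```

After the two calls, the counter in ctx.state has been incremented twice and y2 holds 4.0 in every position; all state handling lives in the wrapper, outside the primitive.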
