Both Flax and Haiku implement a named_call JAX primitive. How about we move this into JAX core?
It behaves like JAX's generic call primitive (see jax.core.call_p), except in how it compiles to XLA: the name it carries is included in the name of the generated HLO computation. This makes it useful for annotating profiles.
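Here is a minimal usage sketch. It assumes a JAX version that exposes the primitive through a wrapper `jax.named_call(fun, name=...)`. That wrapper name and signature are an assumption here, not something settled in this thread. `mlp_layer` and `forward` are hypothetical.

```python
import jax
import jax.numpy as jnp

def mlp_layer(w, x):
    # A small function whose compiled computation we want to label in profiles.
    return jnp.tanh(x @ w)

# Assumed API: wrap the function so that, when staged out under jit, it carries
# the name "mlp_layer" into the compiled program.
named_layer = jax.named_call(mlp_layer, name="mlp_layer")

@jax.jit
def forward(w, x):
    return named_layer(w, x)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))
y = forward(w, x)

# The name annotation does not change the numbers.
print(bool(jnp.allclose(y, mlp_layer(w, x))))  # True
```

Numerically the name is a no-op. Its only effect is on how the compiled program is named, and that naming is what appears in profiler traces.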
cc @LenaMartens @tomhennigan @trevorcai @levskaya @jheek @avital
I'd be a big fan of doing this! I tried suggesting it before, but I think people were too busy to think about it at the time. It would be much better to have such a simple, basic primitive live in a single place in JAX.
Yes, please!
As I understand it, there might be differences between what we'd upstream to JAX and what's currently in Haiku, namely due to Haiku's state threading (?). The Flax version is roughly what I had in mind originally.
The stateful_named_call in Haiku looks like a user of named_call_p that doesn't need to be upstreamed. I attempted a draft in PR #4733.