Thanks for this project! Looking forward to using it more.
This is a feature request; feel free to close if this isn't a good place to track it.
I'd love to be able to export TensorFlow ops (.so files) from functions defined via JAX. The main use case is embedding these functions in a serving context. For training this is less necessary, because the two systems can interact at the Python level, though I'm not clear on how to eliminate memory copies in that scenario.
Ideally the API would be something like passing a tf.placeholder into the function, or otherwise using the annotations being introduced in TF 2.0. This would be fine as a separate package, to avoid a direct dependency on TF in JAX.
Thanks!
Great idea! Some version of this should be possible. One way might be to use XRT to call JAX-generated XLA computations directly from a TF graph, though I'm not sure whether the overheads there would be acceptable for a serving use case. Another might be to have XLA generate AOT-style .so files, as you suggested, which may be possible. Yet another possibility is to have JAX export TF graphs directly, which wouldn't be as big a lift as it might sound.
@froystig and @hawkinsp have much more expertise here, and might be able to say more. Anything to correct with the above guesswork?
Thanks a lot for considering this. From a UX point of view, embedding directly into the graph is even better than generating a separate .so. I've tried to take a look at XRT, but it's above my pay grade, as it were... if that's possible, it would be quite slick. But really, any solution path would be great. Thanks again.
Super interested in this. I'm investigating JAX as an additional frontend to optorch, which lets you write nonlinear losses as TorchScript modules that are then saved and loaded into C++ with the torch::jit runtime. A JAX frontend would probably have better performance for Jacobian evaluation. Is there currently any way to define a function, jit it, and export a frozen TF graph?
edit: I see it's part of the 0.2 milestone, so it's probably not possible yet. I don't have that much experience with the TF backend, but if there's some Python/C++ grunt work to be done, feel free to kick it my way!
JAX doesn't involve any TF graphs, but as of #853 you can dump an XLA HLO proto at build time (e.g. with a bazel build rule) which can then be JIT-compiled and run entirely from C++ (without involving any Python at run-time).
Can you take a look at the docstrings in that PR and see if that might work for you?
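To make that concrete, here's a minimal sketch of dumping serialized HLO, assuming the jax.xla_computation tracing API; the toy function and file name are just placeholders, and the exact method names on the resulting computation may differ across versions:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

# Trace f into an XLA computation for a concrete input shape/dtype.
comp = jax.xla_computation(f)(jnp.zeros((3,), jnp.float32))

# Human-readable HLO, handy for sanity checking.
print(comp.as_hlo_text())

# Serialized HloModuleProto that a C++ XLA client can load, compile,
# and run without any Python at run time.
with open("f_hlo.pb", "wb") as out:
    out.write(comp.as_serialized_hlo_module_proto())
```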
Beautiful! Perfect! Thanks @mattjj @jlebar!
HLO dumping takes us a decent distance to solving this, nice!
Is it possible to load the HLO and execute it using the XRT ops in TF? If so, an example would be awesome :)
Hi, re-upping this.
Last I checked, I wasn't able to figure out a way to load the HLO via the Python TF APIs.
One reason to want this (and maybe I'm wrong here) is that, AFAICT from the docs, running both TF and JAX in the same process leads to competition for GPU memory.
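The usual mitigation, as far as I understand it, is to disable preallocation on both sides; a rough sketch, assuming the XLA_PYTHON_CLIENT_PREALLOCATE env var and TF's memory-growth option behave as documented:

```python
import os
# Must be set before JAX initializes its backend: allocate GPU memory
# on demand instead of reserving most of the device up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported after setting the env var
import tensorflow as tf

# Likewise, let TF grow its allocation as needed rather than
# reserving all remaining memory.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
```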
Given the advent of tf.function, maybe there is some totally transparent way to integrate the two systems?
Thanks!
I think your best path to trying this right now would be to combine two features: jax.experimental.jax2tf, which exports JAX computations in a form that TF can consume, and tf.function(experimental_compile=True), which invokes XLA from TF to compile a function. Try it out?
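Roughly, a minimal sketch (the toy loss function is just an illustration):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def loss(x):
    return jnp.sum(jnp.sin(x) ** 2)

# jax2tf.convert wraps the JAX function so it accepts TF tensors;
# experimental_compile=True asks TF to compile the whole thing with XLA.
loss_tf = tf.function(jax2tf.convert(loss), experimental_compile=True)

print(loss_tf(tf.constant([0.1, 0.2, 0.3])))
```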
On a tangent: when I was working on this last year, I had a PR (https://github.com/tensorflow/tensorflow/pull/30520) for the tfcompile CLI tool to accept HLO instead of TF graphs. Unfortunately, priorities shuffled and it fell by the wayside.
I'm poking around here again, and I'm wondering: what was the purpose of writing a completely new jaxpr -> TF transformation? Obviously GraphDefs are accepted pretty commonly across tools, but what are the tradeoffs versus the HLO IR?
I guess my real question is: if we wanted to generate the .so and .h, should it go through HLO, or is there a good reason to go through TF?