Context:
I would like to use JAX to express a bunch of tensor computations in numpy-ish syntax, but delay the actual execution of the computation until later -- ideally by registering the compiled function as a function that could be looked up from a shared lib. (the function would need to be called from a c++ program / library).
My initial idea was to:
1. stage the computation out to XLA HLO (presumably via the jit tracing machinery?), and
2. compile that HLO into a .so using the approach in tensorflow/compiler/aot.

Assuming this approach makes sense (please let me know if there is a better way), could you let me know how I could extract the XLA HLO needed for that second step?
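For concreteness, here's the kind of thing I mean; jit already traces and compiles via XLA, but executes right away rather than producing a linkable artifact:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # the kind of numpy-ish tensor computation I have in mind
    return jnp.sum(jnp.tanh(x) ** 2)

# jit traces and compiles f via XLA on first call, but then executes it
# immediately -- whereas I'd like the compiled artifact itself as a .so.
f(jnp.ones((3, 3)))
```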
Thanks for your interest in JAX!
Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We've discussed things along these lines, but haven't done anything concrete yet.
One idea would be to add a new Python API jax.aot_compile (probably not that exact name), which, rather than running the computation immediately as JIT does, writes a .so file and .h file to disk that you can link into your code (or whatever language headers/wrappers seem appropriate). I think we could definitely improve on the ergonomics of tensorflow/compiler/aot!
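To illustrate, usage might look something like the following sketch; the name, signature, and outputs are all made up, per the caveat above:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # function to compile ahead of time
    return jnp.tanh(jnp.dot(w, x))

# Hypothetical API (name and signature invented for illustration):
# trace `predict` at fixed input shapes, AOT-compile it with XLA, and
# write predict.so / predict.h for a C++ program to link against.
jax.aot_compile(predict,
                jnp.ones((64, 64), jnp.float32),
                jnp.ones((64,), jnp.float32),
                output_prefix="predict")
```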
If you'd like to try prototyping something along these lines, you might start from the undocumented function jax.xla_computation (https://github.com/google/jax/blob/master/jax/api.py#L155), which returns a Computation object from the XLA client. In particular, it has a method GetSerializedProto() that returns an xla.HloModule proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720).
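A minimal sketch of that flow, with a placeholder function f (assuming the xla_computation and Computation APIs at the links above):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(jnp.dot(x, x))

# Trace f into an XLA Computation without executing it; the example
# argument just supplies shapes and dtypes for tracing.
c = jax.xla_computation(f)(jnp.ones((2, 2), jnp.float32))

# Serialize the xla.HloModule proto to disk, ready to hand off to an
# AOT pipeline such as tensorflow/compiler/aot.
with open("f_hlo.pb", "wb") as out:
    out.write(c.GetSerializedProto())
```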
PRs welcome!
Any updates?