JAX: Support ahead-of-time compilation

Created on 4 Mar 2019 · 2 comments · Source: google/jax

Context:

I would like to use JAX to express a set of tensor computations in NumPy-ish syntax, but delay the actual execution until later -- ideally by registering the compiled function in a shared library so that it can be looked up and called from a C++ program or library.

My initial idea was to:

  • use jax.numpy to describe the computations
  • export the XLA HLO when jitting on materialized tensors with the shapes/dtypes of interest
  • compile the XLA into executable functions and link them into an .so using the approach in tensorflow/compiler/aot

Assuming this approach makes sense (please let me know if there is a better way), could you tell me how I could extract the XLA HLO during that second step?

enhancement

All 2 comments

Thanks for your interest in JAX!

Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We've discussed things along these lines, but haven't done anything concrete yet.

One idea would be to add a new Python API jax.aot_compile (probably not that exact name), which, rather than running the computation immediately as JIT does, writes a .so file and .h file to disk that you can link into your code (or whatever language headers/wrappers seem appropriate). I think we could definitely improve on the ergonomics of tensorflow/compiler/aot!

If you'd like to try prototyping something along these lines, you might start from the undocumented function jax.xla_computation (https://github.com/google/jax/blob/master/jax/api.py#L155), which returns a Computation object from the XLA client. In particular, that object has a GetSerializedProto() method that returns a serialized xla.HloModule proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720).

PRs welcome!

Any updates?
