It would be great to support calls into Numba via an XLA CustomCall (which works inside jax.jit). This would let you use Numba as an alternative for writing low-level kernels in JAX.
xref https://github.com/google/jax/issues/81, https://github.com/google/jax/issues/1100
We are also interested in how this would work. (cc: @sklam, @esc, @stuartarchibald)
Looks like this could go via Numba's @cfunc machinery.
For examples of how JAX wraps XLA's CustomCall interface, take a look at lapack.pyx and cusolver.cc/cusolver.py in jaxlib:
https://github.com/google/jax/tree/master/jaxlib
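As far as I can tell, those jaxlib modules register each kernel the same way. They wrap its raw function pointer in a PyCapsule named `xla._CUSTOM_CALL_TARGET` and hand that capsule to `xla_client.register_custom_call_target`. A Numba `@cfunc` address would need the same wrapping, and ctypes can do it from pure Python. The sketch below is standalone: the ctypes callback `_dummy` is a placeholder for a real kernel address, and the pointer is read back to check the round trip.

```python
import ctypes

# CPython capsule C-API functions, called through ctypes.pythonapi.
PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)

PyCapsule_New = ctypes.pythonapi.PyCapsule_New
PyCapsule_New.restype = ctypes.py_object
PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)

PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer
PyCapsule_GetPointer.restype = ctypes.c_void_p
PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p)

PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName
PyCapsule_GetName.restype = ctypes.c_char_p
PyCapsule_GetName.argtypes = (ctypes.py_object,)

# The capsule stores a pointer to this name, so the bytes object must stay alive.
CAPSULE_NAME = b"xla._CUSTOM_CALL_TARGET"


def encapsulate(address):
    """Wrap a raw function address in a named capsule with no destructor."""
    return PyCapsule_New(address, CAPSULE_NAME, PyCapsule_Destructor(0))


# Placeholder target: a do-nothing C-callable function (stands in for kernel_cpu.address).
_dummy = ctypes.CFUNCTYPE(None)(lambda: None)
address = ctypes.cast(_dummy, ctypes.c_void_p).value
capsule = encapsulate(address)
```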
Ah very cool! Does this work on TPUs as well?
We've been a little curious how one generates custom functions for the TPU as well. We haven't seen an LLVM backend for the TPU anywhere, so I assume the toolchain is not public?
I've been having a little play to see how doable this is. I've got the basics of Numba for CPU working.
```python
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
import numba

# A new primitive; eager execution goes through XLA via apply_primitive.
prim = jax.core.Primitive("Foo")
prim.def_impl(functools.partial(jax.interpreters.xla.apply_primitive, prim))

# The output has the same shape and dtype as the input.
def shape_rule(aval): return aval.shape
def dtype_rule(aval): return aval.dtype
prim.def_abstract_eval(functools.partial(jax.lax.standard_abstract_eval, prim, shape_rule, dtype_rule))

def encapsulate(address):
    # Wrap a raw function pointer in the named PyCapsule that XLA expects.
    import ctypes
    PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
    PyCapsule_New.restype = ctypes.py_object
    PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
    capsule = PyCapsule_New(address, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
    return capsule

# Is it called a kernel on CPU?
# XLA calls it with the output buffer and an array of operand buffers (in CustomCall order).
@numba.cfunc(numba.types.void(numba.types.voidptr, numba.types.CPointer(numba.types.voidptr)))
def kernel_cpu(output, inbuffers):
    size = numba.carray(inbuffers[1], 1, numba.types.int32)[0]  # TODO: Find the proper way to do this
    inp = numba.carray(inbuffers[0], size, numba.types.float32)
    out = numba.carray(output, size, numba.types.float32)  # new name: don't rebind the void* argument
    for i in range(size):
        out[i] = inp[i] + 1

xc.register_custom_call_target("test", encapsulate(kernel_cpu.address), platform='cpu')

def translation_cpu(builder: jax.lib.xla_client.XlaBuilder, op):
    shape = builder.get_shape(op)
    assert shape.rank() == 1
    # The length is passed as a second operand; kernel_cpu reads it as int32,
    # which assumes 64-bit mode (jax_enable_x64) is off.
    size_const = xb.constant(builder, shape.dimensions()[0])
    return xc.ops.CustomCall(builder, b'test', operands=(op, size_const), shape=shape)

jax.interpreters.xla.backend_specific_translations['cpu'][prim] = translation_cpu

# TEST
prim.bind(7 * jnp.ones(20, dtype=jnp.float32))
```
It took a bit of fiddling about, but it's pretty easy in the end!
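What makes the CPU version work is the fixed C signature XLA uses to call the registered target. For a single, non-tuple output, it passes a pointer to the output buffer followed by an array of pointers to the operands, in the order they were given to `CustomCall`. So `inbuffers[0]` is the data and `inbuffers[1]` is the size constant. That convention can be exercised without XLA or Numba. In the ctypes sketch below, `fake_xla_call` is a hypothetical stand-in for XLA (not a real API), and `add_one_kernel` mirrors the body of `kernel_cpu`.

```python
import ctypes

# C signature of a single-output CPU custom call target: void fn(void* out, void** ins)
KERNEL_SIG = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p))


@KERNEL_SIG
def add_one_kernel(out_ptr, ins):
    # Same body as kernel_cpu: ins[0] is the float32 data, ins[1] the int32 length.
    size = ctypes.cast(ins[1], ctypes.POINTER(ctypes.c_int32))[0]
    src = ctypes.cast(ins[0], ctypes.POINTER(ctypes.c_float))
    dst = ctypes.cast(out_ptr, ctypes.POINTER(ctypes.c_float))
    for i in range(size):
        dst[i] = src[i] + 1.0


def fake_xla_call(kernel, values):
    """Hypothetical stand-in for XLA: allocate buffers and call the target."""
    n = len(values)
    data = (ctypes.c_float * n)(*values)
    length = ctypes.c_int32(n)
    out = (ctypes.c_float * n)()
    # Operand pointers in CustomCall order: (data, length).
    ins = (ctypes.c_void_p * 2)(ctypes.addressof(data), ctypes.addressof(length))
    kernel(ctypes.addressof(out), ins)
    return list(out)


result = fake_xla_call(add_one_kernel, [7.0] * 4)
```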
GPU, however, is proving a little more tricky.
Using Numba, I can make a kernel and get a handle for it.
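```python
import numba
import numba.cuda

@numba.cuda.jit("void(float32[:], float32[:])")
def kernel_gpu(input, output):
    i = numba.cuda.grid(1)
    if i < output.size:  # skip threads that fall past the end of the array
        output[i] = input[i] + 1

# Reaches into Numba internals (not a public API) to get the CUfunction handle.
handle = kernel_gpu[1, 10]._func.get().handle
```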
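The `[1, 10]` launch configuration above is one block of 10 threads, so it only covers a 10-element array. For an arbitrary length, the usual approach is to round the block count up. The kernel then skips indices past the end with a guard such as `if i < output.size`, where `numba.cuda.grid(1)` computes `blockIdx.x * blockDim.x + threadIdx.x`. Here is a minimal sketch of the host-side arithmetic; the default of 128 threads per block is an arbitrary assumption.

```python
def launch_config(n, threads_per_block=128):
    """Return (blocks, threads_per_block) for a 1D grid covering n elements."""
    if n <= 0 or threads_per_block <= 0:
        raise ValueError("n and threads_per_block must be positive")
    # Ceiling division: enough blocks that blocks * threads_per_block >= n.
    blocks = (n + threads_per_block - 1) // threads_per_block
    return blocks, threads_per_block


blocks, threads = launch_config(20, 8)
# With Numba this would be launched as kernel_gpu[blocks, threads](inp, out).
```

For 20 elements and 8 threads per block, this launches 3 blocks (24 threads), so the last 4 threads must do nothing. That is why the guard is needed.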
I can also apply this kernel directly to JAX arrays through the standard Numba API. However, as JAX arrays are not writable and kernels can't return anything, I'm unable to get any output using this method. I've tested using a NumPy array as the output and a JAX array as the input, and that seems to work fine.
The step required to make this all work together is invoking the GPU kernel from CustomCall. This XLA page gives a good overview of the basic structure, and this cusolver source file in the JAX source is an example implementation.
I'm not sure what the best next step is here. I presume that the kernel invocation is going to need to be implemented in C++ proper rather than Numba, although I'm curious as to whether I can make Cython work. The only reason for my reluctance to do this properly is that I've done this all in Colab so far, so getting multiple files and compilers involved is a (probably necessary) step up in project complexity.
The cusolver.cc reference makes it all look quite complicated. There's manual device memory copying and stuff going on and I was hoping that things wouldn't need to get that complex, although maybe they do. I haven't grokked that file in detail yet. I was hoping that there may be a way to hijack the fact that Numba already does most of what we want under the hood - it already knows how to call the kernel with JAX inputs.
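One plausible way to reuse that machinery: inside a GPU custom call the kernel only receives raw device pointers in `buffers`. Numba accepts foreign device memory from any object that exposes the CUDA Array Interface (`__cuda_array_interface__`), for example through `numba.cuda.as_cuda_array`. The `DeviceBuffer` class below is a hypothetical wrapper that sketches the dictionary such an object would need to provide. It doesn't touch the GPU, and getting the pointers out of XLA in the first place is still the open problem.

```python
import numpy as np


class DeviceBuffer:
    """Hypothetical wrapper describing raw device memory via the CUDA Array Interface."""

    def __init__(self, ptr, shape, dtype, readonly=False):
        self._ptr = int(ptr)
        self._shape = tuple(int(s) for s in shape)
        self._typestr = np.dtype(dtype).str  # e.g. '<f4' for float32 on little-endian hosts
        self._readonly = bool(readonly)

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,
            "typestr": self._typestr,
            "data": (self._ptr, self._readonly),
            "strides": None,  # None means C-contiguous
            "version": 2,
        }


# 0x1000 is a made-up address; a real one would come from the custom call's buffers.
inputs = DeviceBuffer(0x1000, (20,), np.float32, readonly=True)
cai = inputs.__cuda_array_interface__
```

The wrapper does not own the memory, so it would have to be used only for the duration of the custom call.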
In theory all we need to implement is the equivalent of this sample from the XLA page that I linked:
```cpp
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64 block_dim = 64;
  const int64 grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
```
This makes it seem like it may actually be quite easy! However, that sample has to be compiled with a CUDA compiler such as NVCC because of the `<<<...>>>` launch syntax, and that's getting a bit much for my current approach of working out of a Jupyter notebook! I think we're probably at the point where this needs to be a JAX source modification to be practical, although I'm not sure.
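The `opaque`/`opaque_len` pair in that sample is how a GPU custom call receives static parameters. It would be the natural home for the element count that the CPU version passes as an extra operand (the TODO in `kernel_cpu`). As far as I can tell, jaxlib's cusolver wrappers build a packed descriptor and pass it as the `opaque` argument of `xc.ops.CustomCall`. Below is a minimal sketch of that round trip using `struct`. It assumes a hypothetical layout of one little-endian int64, which the C++ side would read back by copying the bytes into an `int64_t`.

```python
import struct

# Hypothetical descriptor layout: a single little-endian int64 element count.
_DESCRIPTOR = struct.Struct("<q")


def pack_descriptor(size):
    """Host side: bytes to pass as the custom call's `opaque` argument."""
    return _DESCRIPTOR.pack(size)


def unpack_descriptor(opaque, opaque_len):
    """Python version of what the C++ side would do with opaque/opaque_len."""
    if opaque_len != _DESCRIPTOR.size:
        raise ValueError(f"expected {_DESCRIPTOR.size} bytes, got {opaque_len}")
    (size,) = _DESCRIPTOR.unpack(opaque[:opaque_len])
    return size


opaque = pack_descriptor(2048)
size = unpack_descriptor(opaque, len(opaque))
```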
I'd appreciate feedback from knowledgeable people about what the best direction to take this in next is! Is it feasible to write a small C++ utility to handle the kernel invocation? Does this machinery already exist in JAX somewhere? Do I need to get NVCC involved? Is this something that it makes sense to attempt externally as I'm doing, or would it be much smoother to patch this into the JAX source?
This may be as far as I can justify taking this distraction for now, so I thought I would write up what I've done to at least inspire somebody else to take it further. I might do more if it feels approachable, although I'm quite far off track from the model I'm supposed to be training!
Hello all,
I have coincidentally also been working on this, and just pushed the WIP to my fork:
https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py
Would it make sense for us to collaborate to avoid duplicate work? There are quite a few things that could still be done (better): for example, better automatic batching, working out whether the interface can be improved, and, quite importantly, CUDA support, which I have not touched at all but where you have made some progress.
That's interesting @josipd! It looks like we've done roughly the same thing, although you've done a bit more towards an API - mine was only a technical exploration.
I have some vague ideas for what a nice, slightly higher-level API might look like, but I haven't put them together or tested them yet. I have this idea in the back of my mind that for certain types of kernels, particularly pointwise and pointwise-like ones, it may be possible to share an implementation between CPU and GPU. I'd like to play with that, but I haven't got that far in my explorations yet.
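The shared-implementation idea can be made concrete with a plain-Python sketch; there is no Numba here, and all names are hypothetical. The per-element body is written once, and two drivers call it. One is a sequential loop, as in `kernel_cpu`. The other emulates a 1D CUDA grid, where each (block, thread) pair computes `i = block * threads_per_block + thread` and skips out-of-range indices. With Numba, the body would presumably be compiled as a `numba.cuda.jit(device=True)` function on the GPU side and called from a `@cfunc` on the CPU side. Whether that works cleanly is exactly what would need testing.

```python
import math


def add_one(x):
    # Per-element body, written once and shared by both drivers.
    return x + 1.0


def run_sequential(body, inp, out):
    # CPU-style driver: one loop over all elements, like kernel_cpu.
    for i in range(len(inp)):
        out[i] = body(inp[i])


def run_grid_emulated(body, inp, out, threads_per_block=4):
    # GPU-style driver, emulated serially: every (block, thread) pair of a
    # 1D grid handles one index and skips indices past the end.
    n = len(inp)
    blocks = math.ceil(n / threads_per_block)
    for block in range(blocks):
        for thread in range(threads_per_block):
            i = block * threads_per_block + thread
            if i < n:
                out[i] = body(inp[i])


inp = [float(v) for v in range(10)]
out_cpu = [0.0] * len(inp)
out_gpu = [0.0] * len(inp)
run_sequential(add_one, inp, out_cpu)
run_grid_emulated(add_one, inp, out_gpu)
```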
For me this is only actually useful if it runs on GPU. My CPU implementation was only an exploratory stepping stone to that. My CUDA progress so far is minimal - I've basically just read some documentation.
I'm not familiar with how NVCC works and that whole side of setting things up, or indeed if that's actually necessary. That became a significant hurdle for me building an independent implementation, but if we went down the route of forking it may turn out to be very simple to just add extra CUDA code into the existing project infrastructure.
I'm on a Windows machine and that's not supported for building, which was another reason I stopped where I did.
I'm interested in collaborating if we're both working on this, but I haven't yet decided to commit more time to it.
On our side, we are quite interested in CPU support. One of our motivations for a Numba / JAX bridge was that Numba can generate very efficient code for nested loops (typically, code like this or this). Also, Numba supports native Python for loops, whereas in JAX one would need to use jax.lax or jax.experimental.loops. Is your main use case to write low-level CUDA kernels in Numba?
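Dynamic time warping is a typical example of this kind of loop-heavy code (chosen only for illustration, not the code linked above). It is a dynamic program over a 2D table with a data-dependent `min` in the inner loop. Written as plain Python loops, it compiles as-is with `numba.njit`. In JAX, it would need nested `lax.fori_loop`s or a scan-based reformulation. The sketch falls back to plain Python if Numba isn't installed. The absolute-difference cost is an arbitrary choice.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python when Numba isn't installed
    def njit(fn):
        return fn


@njit
def dtw(x, y):
    # R[i, j] = cost of the best alignment of x[:i] with y[:j].
    n, m = x.shape[0], y.shape[0]
    R = np.full((n + 1, m + 1), np.inf)
    R[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = abs(x[i - 1] - y[j - 1])
            R[i, j] = cost + min(R[i - 1, j - 1], R[i - 1, j], R[i, j - 1])
    return R[n, m]


d = dtw(np.array([0.0, 1.0]), np.array([0.0, 0.0, 1.0]))
```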
That's interesting @mblondel. I was almost wondering whether CPU support is even worth doing, but clearly from your use case it is. It sounds like it's definitely worth tidying my and @josipd's experiments up into a nice API.
For me, my main interest in custom kernels is performance. There are situations where JAX's JIT is significantly slower than what I can hand code and I'd like to be able to experiment with throwing a custom solution together. These would usually be quite small kernels that represent custom layers. I'd be defining them inline in a notebook as part of a model definition so brevity and ease are important. That's what makes Numba ideal.
I think that I may have actually made some progress on GPU support! I've got some separate parts that I think should work in theory, but I haven't put them all together and tested it yet, so we'll see. It's probably too much of a hack to be used as an official solution; I'm patching into Numba's internal kernel invocation machinery and I don't think it's considered a public or stable API. It may be worth publishing as a plugin until somebody can come along and put a more official solution together in JAX itself. I'll update this thread with details when I've taken it a little further.
I'm currently blocked on this Numba issue and a few other fiddly bits trying to wrangle Numba into doing what I want. This hack may prove more trouble than it's worth. I think it's clear that the proper way to do this is to integrate the kernel launching code into JAX properly but I don't think I'm the right person for that task. It's probably quite easy, so if somebody from the team with the right experience to pull it off would like to add that to jaxlib that would be great! Although...
Even the process of obtaining the kernel handle from Numba is proving problematic. The handle = kernel_gpu[1,10]._func.get().handle method I posted above breaks when upgrading Numba from 0.48.0 to 0.52.0. Clearly that's not a stable API. I think we would need Numba to expose something for us.
If we're relying on getting changes made to Numba anyway, there may be a smoother way forward. About a year ago, @seibert mentioned here that they would like to enable Numba cfuncs to launch Numba CUDA kernels. I believe that would cover everything that we need to do this. The CustomCall could go into a standard Numba cfunc in the same way as I've done for the CPU, and that cfunc would call the Numba CUDA kernel. It's neat and allows this to be implemented in JAX with only Python changes.
@seibert, if you'd still like to get that put in then I think this becomes very easy. I don't think there's a feature request open for that on the Numba github yet so I'll probably add one soon and link back to here.
If that Numba change happened, I would probably be able to take on putting in a JAX PR for a nice wrapper API if we want that, as well as documentation. That appears to be what @josipd is also working on, so maybe we could collaborate on that.