Jax: Prevent custom calls with side effects to be optimized out

Created on 23 Jul 2020 · 22 comments · Source: google/jax

I am currently experimenting with implementing MPI send / recv as custom XLA calls.

It works fine in most cases, but a function like this leads to a deadlock:

@jax.jit
def send_recv(x):
    if rank == 0:
        x = Recv(x, comm=comm)
    else:
        Send(x, 0, comm=comm)
        # works if doing x = Send(x, 0, comm=comm)
    return x

I guess this is because the return value of Send is not used in the computational graph, so the whole call is optimized away, despite having side effects.

Is there a way to prevent this?

enhancement

Most helpful comment

Yes, that's right for now, but with #3370 neither tokens nor tie_in will be necessary.

All 22 comments

I think the bigger issue is that jit can't stage out MPI calls, not only for tracing reasons but more importantly because those calls don't exist in XLA.

The clearest path forward would be to jit smaller functions, and leave the MPI calls outside the jitted functions so they execute in Python. Is that an option?
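The split being suggested might look roughly like this (a minimal sketch: local_step and the comm argument are hypothetical names, not part of any library in this thread, and the mpi4py call is only indicated in a comment):

```python
import jax
import jax.numpy as jnp

# Sketch of the suggested split: jit only the local numerical work and
# perform MPI communication eagerly in Python between the jitted calls.

@jax.jit
def local_step(x):
    return x * 2.0  # stand-in for the local compute


def step_with_exchange(x, comm=None):
    x = local_step(x)          # staged out and compiled by XLA
    if comm is not None:
        # An mpi4py exchange would go here, executed in Python, e.g.:
        # comm.Send(np.asarray(x), dest=...)
        pass
    return local_step(x)       # back into compiled code
```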

@mattjj Nope, that's not a problem.

I managed to write a small package defining XLA custom ops in Cython for MPI operations.
At the moment the only supported one is Allreduce, which works very well.

We'd like to implement also other stuff like send/recv, as @dionhaefner pointed out above.

(The rationale behind this is that in some complicated distributed batched MatMul code, I see remarkable speed-ups with this approach compared to jitting only parts of the functions. Also, if you want to apply AD to those functions or run them through a linear solver like JAX's CG, you have no choice but to jit the whole function.)

It actually turns out that XLA already supports this, but we haven't plumbed it through as part of our Python bindings. It should be an easy fix, but will require a jaxlib rebuild.

Is it something you'd consider doing?
Or if not, could you point us in the right direction? I cannot find any mention of this in XLA's documentation.

I'm preparing a change to add this now.

Wow, awesome work on mpi4jax! Thanks for clarifying; I missed that point in @dionhaefner's original message, though I see it now.

@hawkinsp in addition to the XLA update, we still have a tracing issue to fix, right? That is, we basically need #3370 to land, or else some other way to ensure the Send is actually staged out from JAX.

We don't necessarily need #3370; if nothing else we can use XLA tokens here.

Right, we need to fix the tracing issue, which could be #3370 or could be adding tokens to the code. But to support the code as written in the OP, i.e. without tokens, we need #3370.

I'm sorry but I'm not sure I follow.
What are tokens? What is the issue you are referring to?

Because for Allreduce (an operation that always has an output) our approach already works.
The issue here is that operations with no (local) side effect such as send can be optimised out of the IR.

What are tokens? What is the issue you are referring to?
The issue here is that operations with no (local) side effect such as send can be optimised out of the IR.

There are essentially two places that dead code elimination (DCE) may happen, leading to pruning of operations that (as far as the system currently understands) have no effect on the result of the computation. One is the Python tracing mechanism and has nothing to do with XLA or jaxlib:

In [1]: def f(x):
   ...:     y = x + x
   ...:     return x + 1
   ...:

In [2]: from jax import make_jaxpr

In [3]: make_jaxpr(f)(2)
Out[3]:
{ lambda  ; a.
  let b = add a 1
  in (b,) }

The other is XLA, which will prune operations unless (a) the value of the computation has a data dependence on the operation, or (b) the operation is marked as side-effectful.

I believe the fix @hawkinsp alluded to is about the latter, essentially adding a way to label XLA CustomCalls as side-effecting so that XLA doesn't prune them.

My point above is that we're still left with the former, i.e. the JAX tracing issue. The way to solve that on the current master branch is to add 'token' values which we thread into and out of side-effecting operations and on which the final result has a fake data dependence; there are more details to unpack, but you can see an example in the infeed tests. On the #3370 branch we don't need tokens for the JAX side anymore, as you can see on that branch's version of infeed_test.py. This all has to do with the JAX Python side; in any case we needed the fix Peter added in XLA.

Does that make sense?

#3834 adds a new has_side_effects argument to CustomCall, which handles the XLA side of this.

You can also work around both the XLA and JAX-side issues by adding a dummy output to your operator (e.g., make it return a scalar that you then return from your function).
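As a rough sketch of that workaround (the primitive name here is made up, and a real version would lower to a side-effecting XLA custom call rather than having no implementation), giving the primitive an output and threading it into the result defeats DCE on the JAX side:

```python
import jax
import jax.numpy as jnp

# Primitive moved to jax.extend.core in newer JAX releases.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical stand-in for an MPI send; a real implementation would
# lower to a side-effecting XLA CustomCall.
send_p = Primitive('send_dummy')
send_p.def_abstract_eval(
    lambda x: jax.core.ShapedArray(x.shape, x.dtype))


def Send(x):
    # Return the primitive's output so callers can keep it alive.
    return send_p.bind(x)


def f(x):
    return Send(x)  # threading the result through keeps the op


print(jax.make_jaxpr(f)(jnp.zeros(3)))
# the jaxpr now contains send_dummy instead of an empty body
```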

By the way, tokens are also useful for sequencing things. That is, JAX jit tracing may reorder operations when there is no data dependence between them (until #3370, which preserves Python execution order), and similarly XLA may reorder operations with no data dependence (one must use tokens for that AIUI).

I think that's all that is needed.
We already define a custom JAX primitive, and when you call Send it is bound to the input. The only problem was that this primitive has no output (or we don't use it), so XLA was removing it...

Thanks a lot!
Any chance we can get a new release of Jaxlib somewhat soon?


We just made one yesterday :-( I suggest building from source for now?

I think that's all that is needed.
We already define a custom JAX primitive, and when you call Send it is bound to the input.

Actually no, jit tracing will drop primitives that are bound if the output of the jitted computation doesn't have a data dependence on the result of the primitive. That's what my example above with make_jaxpr was meant to show: even though we call two adds, we only build a jaxpr (and then an XLA computation) containing the one that affected the output.

Concretely, the code in the OP will still drop the Send call. It'll get dropped on the JAX side, before XLA even has a chance to see it.

Here's a more self-contained example if you prefer:

from jax.core import Primitive

def Send(x):
  return send_p.bind(x)

send_p = Primitive('send')
send_p.def_abstract_eval(lambda x: None)


def f(x):
  Send(x)

from jax import make_jaxpr
print(make_jaxpr(f)(2))
{ lambda  ; a.
  let
  in () }

You can change the make_jaxpr to a jit and add a print(built.as_hlo_text()) after this line in xla.py if you want to convince yourself that XLA will never see the bound primitive.

Thanks a bunch for the explanation and the quick fix! I will try this ASAP and report back.

OK, so it seems like there is currently no way to preserve the Send call in nested JIT calls?

I.e.

def Send(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    token = lax.create_token(x)
    token = _Send(token, x, dest, tag, comm)
    return lax.tie_in(token, x)

def Send_nested(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    Send(x, dest, tag, comm)

print(jax.make_jaxpr(Send)(jnp.zeros(1), 0))
print(jax.make_jaxpr(Send_nested)(jnp.zeros(1), 0))

gives

{ lambda  ; a b.
  let c = create_token a
      d = send_mpi[ comm=<mpi4py.MPI.Intracomm object at 0x10928e5d0>
                    dest=Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=0/0)>
                    tag=0 ] c a
      e = tie_in d a
  in (e,) }

{ lambda  ; a b.
  let
  in () }

So all user code that uses our Send implementation would have to include some boilerplate with create_token and tie_in.

Yes, that's right for now, but with #3370 neither tokens nor tie_in will be necessary.

Looking forward to that! Do you have any estimate when omnistaging is going to hit master?

This works beautifully with omnistaging, thanks!

I can also confirm that it does not work without doing has_side_effect=True, so both changes were necessary. Great job @mattjj and @hawkinsp!
