Jax: xla in pmap fails (i.e. jit-of-pmap or lax.scan with collectives)

Created on 3 Jun 2019 · 5 Comments · Source: google/jax

The parallel XLA interpreter (pxla) currently doesn't properly support nested jit compilation.
A practical example of this issue is using psum inside scan:

pmap(partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.), axis_name="i")(np.ones((8, 4)))

scan needs to compile the body of the loop using XLA, which fails because psum is only defined in the context of the pxla interpreter.
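For reference, here is a self-contained sketch of the same pattern that adapts to however many local devices are present instead of hard-coding 8. The names `body`, `run`, and `n_dev` are illustrative and not from the original report. It assumes a JAX version in which collectives inside lax.scan are supported; on the version this issue was filed against, the pmap call is expected to fail as described above.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()  # size of the mapped axis "i"

def body(carry, x):
    # Sum x across all devices along the pmapped axis "i".
    total = lax.psum(x, "i")
    # New carry is the running sum; the per-step output is the old carry.
    return carry + total, carry

def run(xs):
    # xs is the per-device slice of shape (4,).
    init = jnp.zeros((), dtype=xs.dtype)
    return lax.scan(body, init, xs)

xs = jnp.ones((n_dev, 4))
final, history = jax.pmap(run, axis_name="i")(xs)
print(final, history)
```

Each device scans over its own 4 elements, and every step adds the cross-device sum of that step's element to the carry.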

bug

All 5 comments

Thanks for raising this.

I tweaked the issue title for a canonical "jit-of-pmap" name. This is actually a case of jit-of-pmap because lax.scan does something like jit in its implementation, which is why it doesn't know about parallel collectives.
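To illustrate the "something like jit" point, a small sketch using `jax.make_jaxpr`: the scan body is traced once and staged out as a nested jaxpr rather than executed step by step in Python. The function names `body` and `cumulative` are hypothetical, and the exact printed form of the jaxpr varies between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    # Running sum; emit the carry value from before the update.
    new_carry = carry + x
    return new_carry, carry

def cumulative(xs):
    init = jnp.zeros((), dtype=xs.dtype)
    return lax.scan(body, init, xs)

# The printed jaxpr contains a single `scan` equation whose body
# appears as a nested jaxpr, i.e. the body has been staged out.
jaxpr = jax.make_jaxpr(cumulative)(jnp.ones(4))
print(jaxpr)
```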

jit can be thought of as a special case of pmap (mapping over a nonexistent singleton axis), and if we implemented it that way then this would all work automatically. But at the moment, because the jit implementation predates pmap, it needs to learn how to handle parallel primitives like psum (and pmap itself).
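A rough sketch of the "jit as a special case of pmap" idea: adding a leading axis of size 1, mapping over it with pmap, and dropping it again should give the same values as jit. This only compares results from the user's side and says nothing about how either transform is implemented; `f` is an arbitrary illustrative function.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.arange(4.0)

# Ordinary jit compilation.
via_jit = jax.jit(f)(x)

# pmap over an added singleton axis, then remove that axis again.
via_pmap = jax.pmap(f)(x[None])[0]

print(via_jit, via_pmap)
```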

For completeness, there is a related issue with pmap-of-jit(fori_loop) and jit-of-jit(fori_loop) :)

from jax import jit, lax

def f(x):
    # Inner jit wraps fori_loop; the loop body (argument 2) is marked static.
    return jit(lax.fori_loop, static_argnums=(2,))(0, 10, lambda *args: args[1], x)

jit(f)(1.)

which throws

TypeError: No constant handler for type: <class 'jaxlib.xla_extension.XlaOp'>

Actually, that looks like a separate issue (and an easy one!); I'll open a new issue for it.

Fixed the issue @fehiepsi raised in #832.

Thanks @mattjj! I didn't expect it to be fixed so fast. ^^
