Jax: report: jit-compilation on list of matrices

Created on 6 Jan 2019 · 16 comments · Source: google/jax

@mattjj this is a follow-up from our video chat last week. Just wanted to report back on this.

You were indeed right about JIT-compiling multiple dot products between zipped lists of matrices.

I wrote a forward computation, which looks like this:

import jax.numpy as np

def identity(x):
    return x

def mp(params, As, Fs, nonlin=identity):
    """my hack for doing message passing on multiple graphs in one logical forward computation step."""
    outputs = []
    for a, f in zip(As, Fs):
        output = np.dot(a, f)
        output = np.dot(params['w'], output) + params['b']
        output = nonlin(output)
        outputs.append(output)
    return outputs

With 5000+ graphs (i.e. 5000+ pairs of As and Fs), the initial compilation took 3 hours. Subsequent forward passes took under a few hundred milliseconds.

Nothing more here w.r.t. this issue, just wanted to let you know about it. Please feel free to close whenever :smile:.

application

All 16 comments

Thanks for this report! It's a very interesting use case. The compilation is slow (I presume) because we're unrolling a big computation graph and handing it to XLA, which does a lot of whole-program analysis for heavy optimization but, as a drawback, ends up spending a long time in that analysis for very big unrolled graphs.

One way to improve that on the JAX side would be to use a loop construct instead of unrolling everything before handing it to XLA. Our loop constructs are still unpolished, but it might be worth trying them out here.
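
Roughly, a loop-construct version might look something like the sketch below (illustrative only, and it assumes the pairs get padded and stacked to one common shape first; the names and shapes here are assumptions, not your code):

import jax.numpy as np
from jax import jit, lax

@jit
def mp_loop(params, A_stack, F_stack):
    # A_stack: (B, n, n), F_stack: (B, n, d). XLA compiles one small loop body
    # instead of B unrolled dot ops.
    def body(i, outputs):
        out = np.dot(A_stack[i], F_stack[i])
        out = np.tanh(out)   # nonlinearity; the params matmul is omitted here
        return outputs.at[i].set(out)
    return lax.fori_loop(0, A_stack.shape[0], body, np.zeros_like(F_stack))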

Any chance we could make some synthetic As and Fs so that I can build a model of this computation and try out a loop construct to reduce the compile times? What are the shapes like here?

Also, out of curiosity, how does a few hundred milliseconds compare to what you might expect for evaluation time? Is that fast enough to be useful, once we bring down that crazy compile time?

For a sense of scale, a few hundred milliseconds is insanely awesome! :smile: My usual forward compute times are on the order of seconds, so this is at least 10X faster.

I'm happy to try out the loop constructs with the same datasets. UPDATE: I couldn't find docs in the README; where can I find them?

A simple way to generate graphs is with NetworkX. (I'm writing this from memory, so pardon any syntax errors; I'm off to watch a dance performance in a few minutes!)

import networkx as nx
from random import randint
import numpy as np

As = []
Fs = []
for i in range(2000):  # make 2000 graphs
    n = randint(10, 20)   # make graph with anywhere between 10 and 20 nodes
    G = nx.erdos_renyi_graph(n, p=0.3)
    F = np.random.normal(size=(n, 4))  # make `n` node features, 4 columns.
    A = nx.to_numpy_array(G)
    As.append(A)
    Fs.append(F)

@mattjj just to confirm, are the loop constructs in lax.py?

Yes, but I wouldn't recommend trying them yet. There's lax.fori_loop and lax._while_loop, but those versions are meant for internal use only. We're working on new versions.

Do the a, f values all have the same shapes? If you are on GPU, you will be far happier if you can use a batched matrix multiplication (e.g., via np.matmul) rather than an explicit Python for loop, which will build a giant XLA computation.

On CPU, batch matmul won't help yet because we end up just lowering batched matrix multiplication to one matmul per batch element, although there's work in progress to fix that (perhaps landing sometime this week or next.)
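
For reference, a minimal sketch of the batched form (assuming, hypothetically, that every pair shares one shape; the function name here is just illustrative):

import jax.numpy as np
from jax import jit

@jit
def batched_mp(A_stack, F_stack):
    # A_stack: (B, n, n), F_stack: (B, n, d). One batched matmul replaces B
    # separate np.dot calls, so the XLA graph stays small.
    return np.tanh(np.matmul(A_stack, F_stack))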

They don't all have the same size, at least not in the example code; the size is determined by that randint call.

I tried turning it into a batch matmul yesterday, but on the CPU backend the compilation still took a while, which I found really surprising:

import numpy as onp      # plain NumPy, used for static shape arithmetic
import jax.numpy as np
from jax import jit, lax

def mp(params, As, Fs):
  A = pad_stack(As, (20, 20))
  F = pad_stack(Fs, (20, 20))
  stacked_outputs = stacked_mp(params, A, F)
  outputs = [out[:a.shape[0], :f.shape[1]]
             for out, a, f in zip(stacked_outputs, As, Fs)]
  return outputs

@jit
def stacked_mp(params, A, F):
  output = np.matmul(A, F)
  # output = np.matmul(params['w'], output) + params['b']  # TODO params shape?
  output = np.tanh(output)
  return output

def pad_stack(lst, shape=None):
  # Pad every array in `lst` to a common shape (the elementwise max by default)
  # and stack them along a new leading batch dimension.
  if shape is None:
    shape = onp.max(list(map(np.shape, lst)), 0)
  return np.stack([pad_to(elt, shape) for elt in lst])

def pad_to(x, shape):
  # lax.pad takes a (low, high, interior) padding config per dimension;
  # here we only pad at the high end, out to the target shape.
  end_pads = tuple(onp.subtract(shape, x.shape))
  pads = [(0, end_pad, 0) for end_pad in end_pads]
  return lax.pad(x, onp.array(0, x.dtype), pads)
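
Hypothetical usage with the synthetic As/Fs generated earlier (params is just a placeholder, since the weight/bias matmul is still a TODO):

params = {'w': None, 'b': None}   # placeholder; shapes still TBD
outputs = mp(params, As, Fs)      # list of arrays; outputs[i] has shape (n_i, 4)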

On the CPU backend it will be expanded into the unrolled form (but not on the GPU backend), so you'd expect the same compile-time behavior as the unrolled version.

The relevant code is:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/dot_decomposer.h

Try on GPU; I believe @sanjoy is actively working on improving the CPU case.

@mattjj thanks for trying this out! I noticed you padded the arrays: is this to ensure that all of the A matrices have the same shape? Is this the trick that enables batched matrix multiplication? (Sorry if my question seems dumb here; I have a sense I'm still in learning mode w.r.t. these computational backend things.)

Yes, exactly! In this case I did it to reduce the number of nodes in the compiled graph (now in the @jit function there's only one input array for the padded-and-stacked A's and padded-and-stacked F's). But @hawkinsp seems to be saying that XLA:CPU ends up just expanding it out into the big graph anyway, just because batch matmul hasn't been implemented yet. (We could roll our own batch matmul using a loop construct, though it's probably better just to wait for improvements if they're on the way.)

Although the "map-using-loop" widget would have other uses, too. For example, currently the LAPACK routines mostly don't support batch dimensions. If we had a helper that did this, we could use it there too.

In principle the HLO Map operator does what we want, but in practice its implementation is too limited.

Cf. #212, which became a feature request for a "map-using-loop" widget.
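
As a purely illustrative sketch (not the design in #212), such a helper could be built on a scan-style loop:

from jax import lax

def loop_map(f, xs):
    # Apply f to each leading-axis slice of xs inside a single compiled loop,
    # rather than unrolling one op per element.
    def step(carry, x):
        return carry, f(x)
    _, ys = lax.scan(step, None, xs)
    return ys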

PR https://github.com/google/jax/pull/229 improves the compilation time of batched matrix multiplication on the CPU backend; compilation for large batch sizes should be quick now. (Credit goes to @sanjoy )

@hawkinsp woohoo!!!! I can't wait to try it out! Thank you for the work, @sanjoy!

You'll need to rebuild jaxlib (or wait for us to do it and update the wheels) for this change to take effect, I think.

FYI I just updated the jaxlib linux wheels (to 0.1.4), both on pypi (for the non-cuda versions) and in our cloud bucket (for the cuda versions). I haven't updated the mac ones yet.
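
For the pypi (non-cuda) case, picking up the new wheel should just be something like:

pip install --upgrade jaxlib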

