Pyro: Parallelize ELBO computation over num_particles

Created on 20 Feb 2018 · 10 Comments · Source: pyro-ppl/pyro

ELBO implementations currently loop over num_particles particles sequentially. Let's also support evaluating them in parallel.

Why?

This is most useful for speeding up our gradient tests. If we had cheap parallelized estimates of the ELBO, then we could replace many of Pyro's expensive integration tests with cheaper gradient tests. This in turn would allow us to test over much larger grids of parameter settings rather than over a handful of carefully crafted models.

How?

This can follow logic similar to the parallel discrete enumeration in #776. We might pass num_particles=1000, particles="parallel" as arguments to the ELBO.
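For reference, the quantity being estimated is the usual Monte Carlo ELBO estimate, with K = num_particles and each z_k drawn from the guide q. A sequential implementation evaluates one bracketed term per loop iteration. A parallel one draws all K samples at once along an extra leading batch dimension and averages over that dimension, so the two agree given the same samples.

```latex
\hat{\mathcal{L}} = \frac{1}{K} \sum_{k=1}^{K} \Big[ \log p(x, z_k) - \log q(z_k) \Big],
\qquad z_k \sim q(z), \quad k = 1, \dots, K
```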

Note that a kludgey way to do this that already works in Trace_ELBO is to add a fake Categorical sample site at the beginning of both the guide and the model, enumerate it in parallel, and then use it to broadcast all remaining samples.

import torch
import pyro
import pyro.distributions as dist

num_particles = 1000
def guide():
    pyro.sample("parallelize",
                dist.Categorical(torch.ones(num_particles)),
                infer={"enumerate": "parallel"})
    # ...rest of guide code here...
Labels: enhancement, testing

All 10 comments

Aside from internal reasons (e.g., testing speed), this is a great way to reduce gradient variance at low time cost (at least when things fit on a GPU).

I was trying to think about this for a few toy models that I was interested in.

I thought one way to do this would be to add an extra batch dimension of size num_particles to the topmost nodes in the computation graph (how to identify them ahead of time is another question); e.g., in the model below, that would be p. The remaining nodes that depend on the top nodes wouldn't need to change, since they would pick up the new dimension through broadcasting. All enumerated nodes would also be left as is. These continuous latent sites would then be queued like the discrete ones in iter_discrete_traces, except that when constructing the partial trace they would be stitched together in sequence rather than combined as a cross product. The remaining machinery to construct the full trace could just follow from #776.
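A toy illustration of the difference between the two ways of combining sites, in plain Python with made-up values; all site names and draws here are hypothetical. Enumerated discrete sites contribute one term per joint assignment (a cross product), while particle draws of continuous sites are paired by particle index (a zip), and each particle is then still combined with every discrete assignment.

```python
import itertools

# Hypothetical discrete sites enumerated in parallel: every combination of
# their values is a separate term, so they combine as a cross product.
site_a_values = [0, 1]
site_b_values = [0, 1, 2]
discrete_assignments = list(itertools.product(site_a_values, site_b_values))

# Hypothetical continuous sites with num_particles draws each. Particle k of
# one site belongs with particle k of the other, so they are zipped, not crossed.
num_particles = 4
p_draws = [0.1, 0.4, 0.7, 0.9]
q_draws = [1.5, -0.2, 0.3, 2.0]
assert len(p_draws) == len(q_draws) == num_particles
particles = list(zip(p_draws, q_draws))

# Each particle is still paired with every discrete assignment.
joint_terms = list(itertools.product(particles, discrete_assignments))

print(len(discrete_assignments), len(particles), len(joint_terms))  # 6 4 24
```

Crossing the particle draws instead would produce num_particles ** 2 terms that mix draws from different particles, which is not what a Monte Carlo average over particles wants.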

@fritzo - Does that seem reasonable, and is that the approach you had in mind? Would love to pick your brains on this, when you have some time.

import torch
import pyro
import pyro.distributions as dist

def model(data):
    alpha = pyro.param("alpha", torch.tensor([1.1, 1.1]))
    beta = pyro.param("beta", torch.tensor([1.1, 1.1]))
    p_latent = pyro.sample("p", dist.Beta(alpha, beta))
    with pyro.iarange("data", data.shape[0]):
        pyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
    return p_latent

Can't we just wrap the whole ELBO sample step in an iarange?

@neerajprad I believe something like EnumerateMessenger will be needed to call .expand_by([num_particles]) on all sample sites that aren't already expanded. This is needed to address exactly the difficulty you point out: the only general way to identify sample sites at the top of the compute graph is to look at all sample sites and check whether they have been broadcast yet. I suggest forking EnumerateMessenger and modifying its logic. I'm happy to pair code.
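A minimal sketch of the shape rule such a messenger might apply, written over plain shape tuples rather than Pyro distributions. The helper name particle_expanded_shape, the num_plate_dims parameter, and the convention of placing the particle dim just left of the plate dims are all assumptions for illustration, not the EnumerateMessenger API.

```python
def particle_expanded_shape(batch_shape, num_particles, num_plate_dims):
    """Hypothetical helper: batch shape of a sample site after adding a
    particle dim at position -(num_plate_dims + 1).

    Shapes that already carry the particle dim are returned unchanged,
    which is what prevents double-broadcasting.
    """
    batch_shape = tuple(batch_shape)
    particle_pos = -(num_plate_dims + 1)
    if len(batch_shape) > num_plate_dims:
        if len(batch_shape) == num_plate_dims + 1 and batch_shape[particle_pos] == num_particles:
            return batch_shape  # already broadcast upstream: leave it alone
        raise ValueError(f"unexpected batch dims left of the plates: {batch_shape}")
    # Left-pad the plate dims with 1s, then prepend the particle dim.
    padded = (1,) * (num_plate_dims - len(batch_shape)) + batch_shape
    return (num_particles,) + padded


print(particle_expanded_shape((), 100, 1))         # (100, 1): top-level site
print(particle_expanded_shape((50,), 100, 1))      # (100, 50): site inside a plate
print(particle_expanded_shape((100, 50), 100, 1))  # (100, 50): already expanded
```

A shape-only check like this has a blind spot: a site whose own leftmost batch dim happens to equal num_particles would be mistaken for one that was already expanded.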

@ngoodman It currently won't work to simply wrap the entire model and guide in an iarange, since Pyro does not automatically broadcast everything inside an iarange. We could conceivably implement implicit batching. It's unclear whether implicit batching would be easier or more difficult to use than explicit .expand_by().

That makes sense, Fritz. It feels like there's something a little janky in our conventions; maybe all sites should be broadcast? It would be interesting to understand how hard automatic/implicit broadcasting would be. Anyhow, that's a discussion for a different issue.

maybe all sites should be broadcast?

The difficulty in this case is to avoid double-broadcasting, so I'm not sure there's a way around what @fritzo is suggesting.

@fritzo - Thanks for the pointers! I might take a stab at this later this week, and book some time with you to pair code, depending on how far I get. :)

I think once we have this, we can also reuse this to do parallel HMC chains with only a minor tweak to the integrator.
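A minimal sketch of what parallel chains would ask of the integrator, assuming a toy standard-normal target U(q) = q²/2 and using plain Python lists where real code would use a tensor with a leading chain dim. The function names are hypothetical, not Pyro's HMC API. The tweak is that every quantity, including the energy used in the Metropolis test, stays per chain instead of being summed across chains.

```python
import math


def grad_potential(qs):
    # Toy target: standard normal, U(q) = q**2 / 2, so dU/dq = q for every chain.
    return list(qs)


def leapfrog(qs, ps, step_size, num_steps):
    """Run leapfrog on all chains at once; chains never interact."""
    for _ in range(num_steps):
        ps = [p - 0.5 * step_size * g for p, g in zip(ps, grad_potential(qs))]
        qs = [q + step_size * p for q, p in zip(qs, ps)]
        ps = [p - 0.5 * step_size * g for p, g in zip(ps, grad_potential(qs))]
    return qs, ps


def energies(qs, ps):
    # One energy per chain; summing these would couple the accept/reject steps.
    return [0.5 * q * q + 0.5 * p * p for q, p in zip(qs, ps)]


qs0, ps0 = [1.0, -2.0, 0.5], [0.0, 1.0, -1.0]
qs1, ps1 = leapfrog(qs0, ps0, step_size=0.1, num_steps=10)
accept_probs = [min(1.0, math.exp(h0 - h1))
                for h0, h1 in zip(energies(qs0, ps0), energies(qs1, ps1))]
print(accept_probs)  # each close to 1 for this small step size
```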

With regard to explicit broadcasting via .expand_by, I think the issue is just the mental bookkeeping needed to account for broadcasting by additional iaranges, parallel enumeration, and which dims to designate as .independent, which becomes tough for moderately sized models. Even for the simple model above, which we can hand-parallelize as below, note how we need to expand by [100, 1] to account for the following "data" iarange. It seems like doing this intelligently via implicit broadcasting has the same issues that we discussed above, so I am not sure if that would help.

import torch
import pyro
import pyro.distributions as dist

def model(data):
    with pyro.iarange("num_particles", 100):
        alpha = pyro.param("alpha", torch.tensor([1.1, 1.1]))
        beta = pyro.param("beta", torch.tensor([1.1, 1.1]))
        # the extra 1 in [100, 1] accounts for the "data" iarange below
        p_latent = pyro.sample("p", dist.Beta(alpha, beta).expand_by([100, 1]).independent(2))
    with pyro.iarange("data", data.shape[0]):
        pyro.sample('obs', dist.Bernoulli(p_latent).independent(2), obs=data)
    return p_latent
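A minimal sketch of why the trailing 1 in [100, 1] matters, using a hypothetical pure-Python broadcast_shapes helper that follows NumPy/PyTorch broadcasting rules and, for simplicity, a scalar p rather than the 2-vector in the model above. With 100 particles and 50 data points, a particle shape of (100, 1) lines up against the data iarange's dim -1, whereas a bare (100,) collides with it.

```python
def broadcast_shapes(*shapes):
    """Hypothetical helper implementing NumPy/PyTorch-style broadcasting."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


particles_with_room = (100, 1)  # p expanded by [100, 1]
particles_no_room = (100,)      # p expanded by [100] only
data_dim = (50,)                # one entry per datum in the "data" iarange

print(broadcast_shapes(particles_with_room, data_dim))  # (100, 50)
try:
    broadcast_shapes(particles_no_room, data_dim)
except ValueError as err:
    print(err)
# Worse: with exactly 100 data points the mistake is silent, pairing
# particle i with datum i instead of scoring every particle on all data.
print(broadcast_shapes(particles_no_room, (100,)))  # (100,)
```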

It seems like doing this intelligently via implicit broadcasting has the same issues that we discussed above, so I am not sure if that would help.

This is what I was working towards in #950 and the factor graph design doc, although I didn't quite figure out the bookkeeping then and tried to put too much information into site["cond_indep_stack"]. Basically, you need to open an iarange every time you do a new broadcasting operation at a sample site, and you need to be able to automatically detect when you're outside an iarange by looking at the shape of a site. With those two operations implemented correctly, it would be fairly straightforward to implement this, parallel chains in HMC, #825, parallel enumeration, #811 etc. without having to think too hard about broadcasting as a user.
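A minimal sketch of the second operation, detecting from shape alone that a site has been broadcast outside any iarange. The function name and the representation of active iaranges as a set of negative dims are assumptions for illustration, not the cond_indep_stack format. Any dim of size greater than 1 that no active iarange claims is a dim that would need a new iarange opened for it.

```python
def unclaimed_dims(batch_shape, active_iarange_dims):
    """Hypothetical helper: negative dims of batch_shape with size > 1
    that are not covered by any currently active iarange."""
    return [dim
            for dim, size in zip(range(-len(batch_shape), 0), batch_shape)
            if size > 1 and dim not in active_iarange_dims]


# A site inside a "data" iarange at dim -1:
print(unclaimed_dims((50,), {-1}))      # []
print(unclaimed_dims((1, 50), {-1}))    # []: a size-1 dim is not a real batch dim
# The same site after being broadcast against 100 particles:
print(unclaimed_dims((100, 50), {-1}))  # [-2]: open an iarange at dim -2
```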

Thanks @eb8680. I will look at your PR too; are you planning to resurrect it? @fritzo and I were discussing the two ways to do this - one that is outlined above (that I think I have a good handle on), and the other being the approach that you mentioned that uses the cond_indep_stack. I still need to stare at the code a bit to fully grok this though. :)

the two ways to do this - one that is outlined above (that I think I have a good handle on), and the other being the approach that you mentioned that uses the cond_indep_stack.

I think the difficult part of both is the same - deciding whether a site is outside of an iarange based on its shape. I'd advocate figuring that out and using it to implement the first version, at which point resurrecting the more general/abstract version begun in #950 should be fairly straightforward.
