Pyro: Batching and plating

Created on 9 Dec 2019 · 15 Comments · Source: pyro-ppl/pyro

co-authored with @iffsid

Statistically, batching lets us trade the speed of taking gradient steps for lower-variance gradient estimators. This can be done in a for loop. However, batching is most useful for parallelizing the computation of the log probs of the model and the guide by using low-level numerical optimizations.

Plating is a useful construct for defining conditionally independent but identical random variables. While this is useful for defining hierarchical models, it is unsuited for batching since it changes the definition of the model. This leads to issues like IWAE-based algorithms failing silently, additional multiplicative factors in ELBO-based gradients, and the ability to define nonsensical guides.

We propose defining a special keyword for plates that are supposed to indicate batching which circumvents the above problems.

More details

Consider a model p(z)p(x|z) where z and x can have multiple sites and each site can have different shapes, depending on the distribution batch shapes but also the plates internal to the model. Given a dataset of training xs, we often want to take the gradient of log p(x), averaged over a batch (subset of xs), as an estimator of E_{xs}[grad log p(x)]. In practice, it is suggested that we change the model by adding a batching plate so that the model becomes p(z_{1:B}, x_{1:B}) = prod_b p(z_b)p(x_b | z_b).

Problem 1: IWAE-based objectives fail silently

In IWAE-based objectives like RenyiELBO and RWS, we want to compute the model objective as a logsumexp over the particle dimension and then average over the batch dimension. However, since there is no distinction between a batch dimension and a dimension of an actual plate, IWAE-based models are forced to sum over the batch dimension before taking the logsumexp over the particle dimension (here and here). This is wrong since the order of these operations matters.
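To make the ordering issue concrete, here is a minimal pure-Python sketch with made-up log importance weights for 3 particles and 2 data points. It is not Pyro code; the names `log_weights`, `per_datapoint_sum` and `plate_as_batch` are only for illustration.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))) over a list of floats.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_mean_exp(values):
    return logsumexp(values) - math.log(len(values))


# log_weights[k][b]: log importance weight of particle k for data point b.
# These numbers are invented for the illustration.
log_weights = [
    [-1.0, -5.0],
    [-4.0, -2.0],
    [-3.0, -3.0],
]
num_particles = len(log_weights)
batch_size = len(log_weights[0])

# Intended order: log-mean-exp over particles for each data point,
# then combine over the batch.
per_datapoint = [
    log_mean_exp([log_weights[k][b] for k in range(num_particles)])
    for b in range(batch_size)
]
per_datapoint_sum = sum(per_datapoint)
per_datapoint_mean = per_datapoint_sum / batch_size

# Order forced by treating the batch as an ordinary plate: sum over the
# batch inside each particle, then log-mean-exp over particles.
per_particle = [sum(row) for row in log_weights]
plate_as_batch = log_mean_exp(per_particle)

print(per_datapoint_sum, per_datapoint_mean, plate_as_batch)
```

Even before dividing by the batch size, the two orders give different values, so the difference is not just a constant factor.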

Problem 2: ELBO-based gradients have wrong multiplicative factors

We're interested in estimating E_{xs}[grad log p(x)] using a Monte Carlo estimator, which boils down to taking an average of log p(x) over xs. However, if we change the model from p(z, x) to p(z_{1:B}, x_{1:B}) using a plate in order to accomplish this, the resulting estimator will instead be a sum of log p(x) over xs, which is off by a multiplicative factor of B. While this is fine if we use adaptive optimizers like Adam, it is wrong for non-adaptive ones like SGD.
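A small pure-Python sketch of the factor of B, assuming a toy model log p(x; mu) = log Normal(x | mu, 1) so that the gradient with respect to mu is simply x - mu. The model, the learning rate and the variable names are assumptions for the illustration, not part of Pyro.

```python
def grad_log_p(x, mu):
    # d/dmu log Normal(x | mu, 1) = x - mu
    return x - mu


batch = [9.5, 9.1, 9.2]
mu = 8.5
B = len(batch)

grads = [grad_log_p(x, mu) for x in batch]
grad_mean = sum(grads) / B  # Monte Carlo estimate of E_{xs}[grad log p(x)]
grad_sum = sum(grads)       # what a plate over the batch produces

# With plain SGD the update is lr * gradient, so the step is B times larger
# when the batch is summed rather than averaged.
lr = 0.1
sgd_step_mean = lr * grad_mean
sgd_step_sum = lr * grad_sum

print(grad_mean, grad_sum, sgd_step_mean, sgd_step_sum)
```

With SGD, the summed version behaves like the averaged version with a learning rate multiplied by B, so the effective step size silently depends on the batch size.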

Problem 3: Ability to define nonsensical guides

Consider the case of VAEs where latent vectors have dimension D_z and data has dimension D_x. If we change the model to p(z_{1:B}, x_{1:B}), we have the ability to write the guide to take in [B, D_x] observations and output a distribution over [B, D_z] which is not necessarily independent in the batch dimension, e.g. a multivariate normal distribution over B * D_z dimensions. This is clearly nonsensical for the VAE model, where we just want the guide to map from data of shape [D_x] to distributions on vectors of shape [D_z]. While this is not as big a problem as the previous ones, it is an indication that changing the model definition in order to do batching is unnatural.

Proposed solution: Use a plate with a special name for batching

Users will still deal with batching as before, by manually adding a plate around the body of the model, but they are forced to name this plate using a string with a special keyword value. This keyword is then used internally by IWAE-based and ELBO-based objectives to compute the correct gradient estimators. This doesn't solve the third problem.

discussion

All 15 comments

cc: @martinjankowiak @fritzo

This is something we came across in writing RWS for pyro.

Re RenyiELBO: I think the issue here is that you want to compute log p(x) as an average over data points. I am not sure whether RenyiELBO still works if you do so. As with the ELBO, we use the objective log p(X) instead of E[log p(x)].

In the above, x and z mean different things for different models.

If we're interested in models of the form p(theta, x_{1:N}) = p(theta) \prod_n p(x_n | theta) (like in equation 2 of https://arxiv.org/pdf/1602.02311.pdf), then theta is z and x_{1:N} is x. Thus, maximizing the ELBO maximizes log p(x_{1:N}).

The batching issue only comes into play when you want to amortize inference in this model over multiple x_{1:N}s. I think we want to average over ELBOs instead of sum.

IIUC, you want to train multiple models (with the same formulation) on different datasets (of the same size) using a single SVI run (whose objective is the average of the objectives of the single models). I am not sure if Pyro supports that... so I would like to hear more opinions. In particular, RenyiELBO is not implemented to support this kind of inference. If the number of datasets is small, you can probably just loop over those datasets. :D

Hmmm maybe we could meet in person at NeurIPS to discuss?

Yup, I'll be around there from Tuesday. Happy to chat about this!

@iffsid and I were trying to implement an example with sequential batching. We don't want to take gradients of ELBOs of single data points (because this results in high gradient variance) but at the same time don't want to vectorize the ELBO computations. Examples of when we might want to do this are models with a dynamic number of latent variables, like PCFGs.

The pyro.plate examples suggest using a for i in pyro.plate(...) construct together with a control flow (if z[i]) in order to prevent vectorization.

We tried this in an example below but we find that the i doesn't loop over all data points. The code below only prints zeros.

Is this the correct way to do sequential batching?

import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

pyro.set_rng_seed(101)


def scale(guess_init, guess_scale, obs_scale, observations=None):
    obss = torch.tensor(observations['measurement'])
    guess_loc = pyro.param('guess', torch.tensor(guess_init))
    for i in pyro.plate('obss', len(obss)):
        if obss[i]:
            print(i)
            weight = pyro.sample('weight_{}'.format(i), dist.Normal(guess_loc, guess_scale))
            return pyro.sample('measurement_{}'.format(i), dist.Normal(weight, obs_scale),
                               obs=torch.tensor(obss[i]))


def scale_parametrized_guide(guess_init, guess_scale, obs_scale, observations=None):
    obss = torch.tensor(observations['measurement'])
    loc_mult = pyro.param('loc_mult', torch.tensor(1.))
    loc_add = pyro.param('loc_add', torch.tensor(0.))
    log_scale = pyro.param('log_scale', torch.tensor(0.))
    for i in pyro.plate('obss', len(obss)):
        if obss[i]:
            print(i)
            return pyro.sample('weight_{}'.format(i), dist.Normal(loc_mult * obss[i] + loc_add,
                                                                  torch.exp(log_scale)))


if __name__ == '__main__':
    guess_init = 8.5
    guess_scale = 1.0
    obs_scale = 0.75
    obss = [9.5, 9.1, 9.2]

    num_particles = 100
    vectorize = True

    pyro.clear_param_store()
    svi = pyro.infer.SVI(model=scale,
                         guide=scale_parametrized_guide,
                         optim=pyro.optim.Adam({'lr': 0.1}),
                         loss=pyro.infer.ReweightedWakeSleep(num_particles=num_particles,
                                                             vectorize_particles=vectorize,
                                                             insomnia=1.))

    num_steps = 10000
    for t in range(num_steps):
        theta_loss, phi_loss = svi.step(guess_init, guess_scale, obs_scale,
                                        observations={'measurement': obss})

Your return statement is exiting the loop early.

You're right, thank you.
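For readers hitting the same symptom, the effect of a return inside the loop body can be reproduced without Pyro. The sketch below uses hypothetical helper names (`visit_with_early_return`, `visit_all`) and a plain `range` in place of `pyro.plate`; it only illustrates the control flow, not the inference code.

```python
def visit_with_early_return(observations):
    visited = []
    for i in range(len(observations)):
        if observations[i]:
            visited.append(i)
            # Returning here leaves the function during the first truthy
            # iteration, so later indices are never reached.
            return visited
    return visited


def visit_all(observations):
    visited = []
    for i in range(len(observations)):
        if observations[i]:
            visited.append(i)
    # Returning after the loop lets every index be visited.
    return visited


obss = [9.5, 9.1, 9.2]
early = visit_with_early_return(obss)
full = visit_all(obss)
print(early, full)
```

In the model and guide above, the analogous change would be to drop the `return` in front of the `pyro.sample` calls inside the loop.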

Is this still a problem? Is there any plan to fix it? I am quite keen to try out these RWS type objectives on a model which I have been working on which is currently largely implemented in pyro, so I would be willing to help out writing a pull request if this is something that people would consider adding to pyro.
As a temporary workaround, I'm also curious how bad the failure currently is if we just ignore the conceptual issues and do minibatching anyway - does exchanging the order of the logsumexp and sum have a big impact in practical terms? Did you try this experimentally?

Hi @lsgos if you need something right now, I started implementing this in https://github.com/tuananhle7/pyro/tree/batching but am currently busy with other stuff so won't be able to submit a PR soon.

cc @iffsid @fritzo @martinjankowiak

When we have a model with dynamic traces, we're forced to use a sequential plate construct for batching as well as vectorize_particles=False for the particles in RenyiELBO/RWS.

In the current implementation of RenyiELBO/RWS, this results in the batch dimension being summed out before doing the logsumexp on the particle dimension because the per-particle trace includes all the elements in the batch. (Instead, we want to do logsumexp over the particle first and then sum/mean over the batch.)

In particular, these two loops sum out over all the batch elements in a per-particle trace because nodes in model_trace.nodes are flat.

Below is a minimal example showcasing this. If you run this with a breakpoint in this line, and look at elbo_particles, it's a list of num_particles 0-dim tensors whereas it should be of shape num_particles x batch_size.

What would be a good way to fix this issue?

import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist


def geometric(batch_id, geometric_id, logits):
    if pyro.sample(f"{batch_id}_{geometric_id}", dist.Bernoulli(logits=logits)) == 1:
        return 0
    else:
        return geometric(batch_id, geometric_id + 1, logits) + 1


def model(obs):
    logits = pyro.param("model_logits", torch.tensor(0.0))
    for batch_id in pyro.plate("batch", len(obs)):
        geometric_value = geometric(batch_id, 0, logits)
        pyro.sample(
            f"obs_{batch_id}",
            dist.Bernoulli(probs=torch.tensor(0.99)),
            obs=(geometric_value == obs[batch_id]).float(),
        )


def guide(obs):
    logits = pyro.param("guide_logits", torch.tensor(0.0))
    for batch_id in pyro.plate("batch", len(obs)):
        geometric(batch_id, 0, logits)


if __name__ == "__main__":
    obs = torch.tensor([8, 9, 10])  # batch_size = 3 in this case
    num_particles = 4
    vectorize = False

    pyro.clear_param_store()
    svi = pyro.infer.SVI(
        model=model,
        guide=guide,
        optim=pyro.optim.Adam({"lr": 0.1}),
        loss=pyro.infer.RenyiELBO(num_particles=num_particles, vectorize_particles=vectorize,),
    )

    num_steps = 10000
    for t in range(num_steps):
        svi.step(obs)

@tuananhle7 Sequential plates generate metadata at sample sites indicating which "slice" the site is in (the .counter attribute of CondIndepStackFrame).

Here's a helper function that uses this metadata to take a trace of a single particle and the name of the batch plate (which you could either hardcode or read off from the trace using the fact that it must appear at every site) and group its log_prob tensors into separate lists for each batch index:

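import collections


def group_batch_log_probs(particle_trace, plate_name):
    # Assumes particle_trace.compute_log_prob() has already been called,
    # so that every sample site carries a "log_prob" entry.
    batch_log_probs = collections.defaultdict(list)
    for node in particle_trace.nodes.values():
        if node["type"] != "sample":
            continue
        # Find the frame of the batch plate; its .counter identifies the batch index.
        batch_frame = next(f for f in node["cond_indep_stack"] if f.name == plate_name)
        batch_log_probs[batch_frame].append(node["log_prob"])

    # return a list of (batch frame, list of log_prob factors) pairs
    return list(batch_log_probs.items())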

You should be able to use this function in a fork of RenyiELBO to reorder operations appropriately, although I imagine doing everything sequentially will be quite slow.
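As a rough sketch of what "reordering operations" could look like once the factors are grouped, the pure-Python code below takes per-particle groupings for the model and the guide, forms one log weight per (particle, batch index), applies log-mean-exp over particles, and only then averages over the batch. The input layout (a list over particles of dicts from batch index to a list of scalar log_prob values) and the function name `iwae_bound_per_batch` are assumptions for the illustration. It is not the RenyiELBO implementation and it ignores gradients entirely.

```python
import math


def log_mean_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def iwae_bound_per_batch(model_groups, guide_groups):
    # model_groups[k][b] and guide_groups[k][b] are lists of scalar log_prob
    # factors for particle k and batch index b (hypothetical layout).
    num_particles = len(model_groups)
    batch_indices = sorted(model_groups[0])
    bounds = []
    for b in batch_indices:
        log_weights = [
            sum(model_groups[k][b]) - sum(guide_groups[k][b])
            for k in range(num_particles)
        ]
        # logsumexp-style reduction over particles first ...
        bounds.append(log_mean_exp(log_weights))
    # ... and only then the average over the batch.
    return sum(bounds) / len(bounds)


# Made-up factors; traces may have different lengths per particle.
model_groups = [
    {0: [-1.0, -0.5], 1: [-2.0]},
    {0: [-0.7], 1: [-1.0, -1.0, -0.2]},
]
guide_groups = [
    {0: [-0.5], 1: [-1.0]},
    {0: [-0.7], 1: [-0.2, -1.0]},
]
bound = iwae_bound_per_batch(model_groups, guide_groups)
print(bound)
```

Because each (particle, batch index) cell is reduced independently, variable-length traces are not a problem here; the cost is that everything runs sequentially.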

Thanks @eb8680 :)

We're not that worried about speed right now; more that there is a class of models we'd like to target that isn't easy to learn because using RenyiELBO/RWS necessitates batch_size = 1 right now; i.e. DreamCoder-style programmatic models.

Do you guys think it might be useful in general to look into 'padded' traces as a special case (where conditional independence holds, and vectorisation isn't possible only because of stochastic-length traces) to help with the speed issue? Assuming of course that you've not already looked into it and decided it's more pain than it's worth?

it might be useful in general to look into 'padded' traces

While I haven't thought about this in detail, I'm inclined to implement this as some sort of transformation or rewriting, either of Trace objects or of collections of log_probs extracted from those trace objects. One possible implementation would be to transform a sequential Trace object to a vectorized Trace object by concatenating some of the sites. Another possibility would be to extract the log_prob terms as a data structure of Funsors and do the rewriting with funsor machinery. Probably the simplest possibility is to formalize the assumptions of @eb8680's helper above.
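One possible reading of the 'padded' idea, sketched in pure Python on already-extracted log_prob values: pad each batch element's list of factors to a common length with zeros and keep a mask, so that a rectangular (and therefore vectorizable) reduction gives the same per-element sums. The helper names `pad_log_probs` and `masked_row_sums` are hypothetical, and this is not an existing Pyro or funsor feature.

```python
def pad_log_probs(ragged, pad_value=0.0):
    # ragged: one list of scalar log_prob factors per batch element,
    # with possibly different lengths (stochastic-length traces).
    max_len = max(len(row) for row in ragged)
    padded = [row + [pad_value] * (max_len - len(row)) for row in ragged]
    mask = [[1.0] * len(row) + [0.0] * (max_len - len(row)) for row in ragged]
    return padded, mask


def masked_row_sums(padded, mask):
    # Padded entries are multiplied by 0.0 and so contribute nothing.
    return [
        sum(value * keep for value, keep in zip(row, mask_row))
        for row, mask_row in zip(padded, mask)
    ]


# Made-up factors for three batch elements with different trace lengths.
ragged = [[-0.7, -0.7, -0.7], [-0.7], [-0.7, -0.7]]
padded, mask = pad_log_probs(ragged)
row_sums = masked_row_sums(padded, mask)
print(padded, mask, row_sums)
```

With tensors instead of lists, `padded * mask` followed by a sum over the last dimension would play the same role, which is what would make the batch dimension vectorizable.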
