Pyro: [FR] Conjugate distributions

Created on 8 Jan 2019 · 5 Comments · Source: pyro-ppl/pyro

It would be nice to have some basic conjugate distributions. These are especially useful when all local variables can be integrated out. Note that samplers can use the existing distributions, and we need only implement some new math for the .log_prob() methods.
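As a rough illustration of the "new math" a conjugate pair needs in .log_prob(), here is a minimal standard-library sketch of a Beta-Binomial marginal log-probability. The function names and argument order are assumptions made for this example and are not Pyro's API.

```python
from math import exp, lgamma


def log_beta(a, b):
    """Log of the Beta function B(a, b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def beta_binomial_log_prob(k, total_count, alpha, beta):
    """Log-probability of k successes in total_count trials, with the
    success probability p ~ Beta(alpha, beta) integrated out.

    Hypothetical helper for illustration; not Pyro's implementation.
    """
    # log of the binomial coefficient C(total_count, k)
    log_binom = (lgamma(total_count + 1) - lgamma(k + 1)
                 - lgamma(total_count - k + 1))
    # log of B(k + alpha, n - k + beta) / B(alpha, beta)
    return (log_binom
            + log_beta(k + alpha, total_count - k + beta)
            - log_beta(alpha, beta))


# Sanity check: the marginal probabilities over k = 0..10 sum to one.
total = sum(exp(beta_binomial_log_prob(k, 10, 2.0, 3.0)) for k in range(11))
```

Sampling needs no new math: draw p from the Beta, then draw the count from a Binomial with that p.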

It would also be nice to be able to sample the marginalized variable from its complete conditional, e.g.

def sample_hidden(self, value, sample_shape=()):
    # sample hidden variable conditioned on params and value
    raise NotImplementedError  # proposed interface only
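For the Beta-Binomial case, sampling from the complete conditional could look roughly like the sketch below. It uses only the standard library, and the free-function signature (including the `rng` argument) is an assumption for illustration rather than the proposed method.

```python
import random


def posterior_params(value, total_count, alpha, beta):
    """Complete conditional of p given value successes in total_count trials:
    p | value ~ Beta(alpha + value, beta + total_count - value)."""
    return alpha + value, beta + (total_count - value)


def sample_hidden(value, total_count, alpha, beta, num_samples=1, rng=random):
    """Draw the marginalized success probability from its complete conditional.

    Hypothetical stand-in for the method proposed above; not Pyro code.
    """
    a, b = posterior_params(value, total_count, alpha, beta)
    return [rng.betavariate(a, b) for _ in range(num_samples)]


rng = random.Random(0)  # seeded generator so the example is reproducible
draws = sample_hidden(7, 10, 2.0, 2.0, num_samples=5000, rng=rng)
empirical_mean = sum(draws) / len(draws)
posterior_mean = 9.0 / 14.0  # mean of Beta(9, 5)
```

The empirical mean of the draws should land close to the analytic posterior mean.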

In the long-term we'd prefer to automatically detect conjugate relationships from programs without users needing to manually specify conjugate-pair distributions (as pointed out by @martinjankowiak). Thus the pairs in this issue can serve as short-term implementations for numerical testing, and we can work on better conjugacy detection that compiles down to these pairs.

Tasks

  • [x] #1708 BetaBinomial (needed by @fritzo, @neerajprad )
  • [x] GammaPoisson
  • [x] #1734 DirichletMultinomial
  • others?
enhancement

All 5 comments

Could I work on NormalNormal?

Thank you @fritzo. I will have an implementation available in a few hours for the loc conjugated normal.

Closing this issue until we work out design issues in #1723 .

Here are the reasons I am closing this issue:

  1. We're still working out the conjugacy interface, and we welcome contributions to the design doc, which is still in progress.
  2. We're implementing a library to enable Birch-style delayed sampling in Pyro. That library will eventually wrap implementations of conjugate math in Pyro and PyTorch, but we'd like to implement more of that library so we know exactly what interfaces to create.
  3. Regarding specific distributions, the BetaBinomial, DirichletMultinomial, and GammaPoisson have the property that multiple observations can be trivially combined into a single observation (by adding the observations and adding total_counts). Other distributions, e.g. NormalNormal, have nontrivial sufficient statistics. It may make sense for e.g. NormalNormal to be a multivariate distribution with event dim.
  4. @neerajprad is still working out issues with plate nesting. The problem we're seeing is that the current pyro.distributions.conjugate implementations cannot handle plate nesting like
        concentration = pyro.param(...)
        with pyro.plate("groups", ...):
            probs = pyro.sample("probs", Dirichlet(concentration))
            with pyro.plate("sessions", ...):
                pyro.sample("obs", Multinomial(probs=probs), obs=data)
    The issue is that there are really two batch_shapes involved, and the current distribution interface does not recognize the distinction.
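To make point 3 concrete for the NormalNormal case, here is a minimal sketch assuming a Normal prior on an unknown mean and a known observation scale. The function name and signature are hypothetical. The posterior depends on the data only through the count and the sum of the observations, so a single batch update matches a sequence of one-at-a-time updates.

```python
def normal_normal_posterior(prior_loc, prior_scale, obs_scale, n, sum_x):
    """Posterior (loc, scale) over an unknown mean, given n observations
    with known obs_scale. The data enter only through (n, sum_x).

    Hypothetical helper for illustration; not Pyro's API.
    """
    prior_prec = 1.0 / prior_scale ** 2
    obs_prec = 1.0 / obs_scale ** 2
    post_prec = prior_prec + n * obs_prec
    post_loc = (prior_prec * prior_loc + obs_prec * sum_x) / post_prec
    return post_loc, post_prec ** -0.5


data = [0.3, -1.2, 2.5, 0.9]

# One update from the sufficient statistics (n, sum of observations).
batch = normal_normal_posterior(0.0, 2.0, 1.0, len(data), sum(data))

# The same posterior, built up one observation at a time.
loc, scale = 0.0, 2.0
for x in data:
    loc, scale = normal_normal_posterior(loc, scale, 1.0, 1, x)
sequential = (loc, scale)
```

Unlike the count-based pairs, the observations cannot simply be added into one observation of the same distribution, and their joint marginal does not factorize across observations, which may be one reason a multivariate distribution with an event dim was suggested.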

@fritzo Is there any update on a best practice for handling plate nesting (point 4 above) when using pyro.distributions.conjugate to integrate out local variables exactly? I am trying to run SVI on exactly a Dirichlet-Multinomial conjugate pair.

@chanjed as mentioned in point 3 above, Dirichlet-Multinomial models permit fusion along a plate dimension, i.e. the following two models are equivalent, and the latter already works in Pyro:

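# Model 1.
p = pyro.sample("p", Dirichlet(concentration))
with pyro.plate("plate", 100, dim=-1):
    # Use the sampled p as the Multinomial probabilities.
    pyro.sample("x", Multinomial(probs=p),
                obs=x)

# Model 2.
# p is integrated out; the observation is x summed over the plate dimension.
p = pyro.sample("p", DirichletMultinomial(concentration),
                obs=x.sum(-2, keepdim=True))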
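
The fusion claim can be checked numerically with a small standard-library sketch. The helpers below are written for this example and are not Pyro's implementation. They compare the joint marginal likelihood of several count vectors sharing one p against the Dirichlet-Multinomial log-probability of the summed counts. The two differ only by a combinatorial term that does not depend on the concentration, so they yield the same gradients with respect to it.

```python
from math import lgamma


def log_multinomial_coeff(counts):
    """Log of the multinomial coefficient for a vector of counts."""
    return lgamma(sum(counts) + 1) - sum(lgamma(c + 1) for c in counts)


def log_dirichlet_norm(alpha):
    """Log of the multivariate Beta function B(alpha)."""
    return sum(lgamma(a) for a in alpha) - lgamma(sum(alpha))


def dirichlet_multinomial_log_prob(counts, alpha):
    """Log-probability of one count vector with p ~ Dirichlet(alpha) integrated out."""
    post = [a + c for a, c in zip(alpha, counts)]
    return (log_multinomial_coeff(counts)
            + log_dirichlet_norm(post) - log_dirichlet_norm(alpha))


def joint_log_marginal(rows, alpha):
    """Log marginal likelihood of several count vectors that share a single p."""
    summed = [sum(col) for col in zip(*rows)]
    post = [a + c for a, c in zip(alpha, summed)]
    return (sum(log_multinomial_coeff(r) for r in rows)
            + log_dirichlet_norm(post) - log_dirichlet_norm(alpha))


rows = [[1, 0, 2], [0, 3, 1], [2, 2, 0]]   # three observations in one plate
fused = [sum(col) for col in zip(*rows)]    # counts summed over the plate


def gap(alpha):
    """Difference between the unfused and fused log-likelihoods."""
    return joint_log_marginal(rows, alpha) - dirichlet_multinomial_log_prob(fused, alpha)
```

Evaluating `gap` at different concentrations should return the same value, which is the sense in which the two models are equivalent for inference over the concentration.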