It would be nice to have some basic conjugate distributions. These are especially useful when all local variables can be integrated out. Note that samplers can use the existing distributions, and we need only implement some new math for the .log_prob() methods.
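As a rough illustration of that new math, here is a minimal sketch of the collapsed `log_prob` for two of the pairs discussed below. It uses only the Python standard library, and the function names are hypothetical, not Pyro's API. Integrating out the Beta prior of a Binomial gives the beta-binomial pmf. Integrating out a Gamma(shape, rate) prior on a Poisson rate gives a negative binomial pmf.

```python
from math import lgamma, log, exp


def log_beta(a, b):
    """log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def log_binom(n, k):
    """Log of the binomial coefficient C(n, k)."""
    return lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)


def beta_binomial_log_prob(k, n, alpha, beta):
    """log p(k) for p ~ Beta(alpha, beta), k ~ Binomial(n, p), with p integrated out."""
    return log_binom(n, k) + log_beta(k + alpha, n - k + beta) - log_beta(alpha, beta)


def gamma_poisson_log_prob(k, shape, rate):
    """log p(k) for lam ~ Gamma(shape, rate), k ~ Poisson(lam), with lam integrated out.

    The result is a negative binomial log pmf.
    """
    return (lgamma(k + shape) - lgamma(shape) - lgamma(k + 1)
            + shape * log(rate / (1 + rate)) - k * log(1 + rate))
```

As a sanity check, with a uniform Beta(1, 1) prior the beta-binomial is uniform over 0..n, so `exp(beta_binomial_log_prob(3, 10, 1.0, 1.0))` equals 1/11 up to rounding.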
It would also be nice to be able to sample the marginalized variable from its complete conditional, e.g.
```py
def sample_hidden(self, value, sample_shape=()):
    # Sample the hidden variable conditioned on params and value.
    raise NotImplementedError
```
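For a BetaBinomial pair, for instance, the complete conditional of the hidden probability given an observed count is again a Beta distribution: Beta(alpha + value, beta + total_count - value). The following is a minimal, standard-library-only sketch of how `sample_hidden` might look under that assumption. `BetaBinomialSketch`, `posterior_params`, and the attribute names are hypothetical, and `random.betavariate` stands in for a batched tensor sampler.

```python
import random


class BetaBinomialSketch:
    """Hypothetical stand-in for a collapsed Beta-Binomial distribution."""

    def __init__(self, alpha, beta, total_count):
        self.alpha = alpha
        self.beta = beta
        self.total_count = total_count

    def posterior_params(self, value):
        # Conjugate update: successes add to alpha, failures add to beta.
        return self.alpha + value, self.beta + self.total_count - value

    def sample_hidden(self, value, sample_shape=()):
        # Sample the marginalized probability from its complete conditional,
        # Beta(alpha + value, beta + total_count - value).
        a, b = self.posterior_params(value)
        if sample_shape == ():
            return random.betavariate(a, b)
        # Sketch only: a 1-D sample_shape such as (n,) yields a flat list.
        (n,) = sample_shape
        return [random.betavariate(a, b) for _ in range(n)]
```

A real implementation would return tensors with shape `sample_shape + batch_shape`; the list here just keeps the sketch dependency-free.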
In the long term we'd prefer to automatically detect conjugate relationships in programs, without users needing to manually specify conjugate-pair distributions (as pointed out by @martinjankowiak). Thus the pairs in this issue can serve as short-term implementations for numerical testing, while we work on better conjugacy detection that compiles down to these pairs.
- BetaBinomial (needed by @fritzo, @neerajprad)
- GammaPoisson
- DirichletMultinomial

Could I work on NormalNormal?
Thank you @fritzo. In a few hours I will have an implementation available for the Normal likelihood with a conjugate Normal prior on loc.
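For reference, here is the math such an implementation would need, sketched under the assumption that the model is loc ~ Normal(loc0, scale0) and x ~ Normal(loc, scale) with scale known. The function names are hypothetical, and this is not the implementation mentioned above. Integrating out loc gives x ~ Normal(loc0, sqrt(scale0**2 + scale**2)), and the complete conditional of loc given observations is again Normal.

```python
from math import log, pi, sqrt


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - log(scale) - 0.5 * log(2 * pi)


def normal_normal_log_prob(x, loc0, scale0, scale):
    """log p(x) for loc ~ Normal(loc0, scale0), x ~ Normal(loc, scale), loc integrated out."""
    return normal_log_prob(x, loc0, sqrt(scale0 ** 2 + scale ** 2))


def normal_normal_posterior(xs, loc0, scale0, scale):
    """Complete conditional of loc given observations xs, as (mean, std).

    Only len(xs) and sum(xs) enter, so they are sufficient statistics for loc.
    """
    precision = 1.0 / scale0 ** 2 + len(xs) / scale ** 2
    mean = (loc0 / scale0 ** 2 + sum(xs) / scale ** 2) / precision
    return mean, sqrt(1.0 / precision)
```

Note that several observations sharing one loc are not independent under the marginal: their joint is multivariate Normal with covariance scale**2 * I + scale0**2 * ones * ones^T, which is one reason a NormalNormal pair may need to be a multivariate distribution.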
Closing this issue until we work out design issues in #1723.
Here are the reasons I am closing this issue:
- BetaBinomial, DirichletMultinomial, and GammaPoisson have the property that multiple observations can be trivially combined into a single observation (by adding the observations and adding total_counts). Other distributions, e.g. NormalNormal, have nontrivial sufficient statistics. It may make sense for e.g. NormalNormal to be a multivariate distribution with an event dim.
- pyro.distributions.conjugate implementations cannot handle plate nesting like

```py
import pyro
from pyro.distributions import Dirichlet, Multinomial

concentration = pyro.param(...)
with pyro.plate("groups", ...):
    probs = pyro.sample("probs", Dirichlet(concentration))
    with pyro.plate("sessions", ...):
        pyro.sample("obs", Multinomial(probs=probs),
                    obs=data)
```

because two different batch_shapes are involved (that of the latent probs and that of the observations), and the current distribution interface does not recognize the distinction.

@fritzo Is there any update on best practice for handling point 4 above (plate nesting), i.e. using pyro.distributions.conjugate with nested plates in order to integrate out local variables exactly? I am trying to run SVI on exactly such a Dirichlet-Multinomial conjugate pair.
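@chanjed as mentioned in point 3 above (trivially combining observations), Dirichlet-Multinomial models permit fusion along a plate dimension, i.e. the following two models are equivalent (their likelihoods of x differ only by a multinomial coefficient that does not depend on concentration), and the latter already works in Pyro:

```py
import pyro
from pyro import plate
from pyro.distributions import Dirichlet, DirichletMultinomial, Multinomial

# Model 1. Assumes x holds 100 one-hot rows, shape (100, K).
p = pyro.sample("p", Dirichlet(concentration))
with plate("plate", 100, dim=-1):
    pyro.sample("x", Multinomial(probs=p),
                obs=x)

# Model 2. p is integrated out and the 100 rows are fused into one count vector.
pyro.sample("x", DirichletMultinomial(concentration, total_count=100),
            obs=x.sum(-2, keepdim=True))
```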
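To see why the two models give the same gradients with respect to `concentration`, note that for one-hot rows, the probability of the whole sequence under Model 1 with `p` integrated out is B(concentration + counts) / B(concentration), where B is the multivariate Beta function. The DirichletMultinomial pmf of the summed counts multiplies this by the multinomial coefficient n! / prod(counts!), which does not depend on `concentration`. Below is a standard-library-only check of that claim; the helper names are hypothetical, and Model 1 is evaluated sequentially via the Polya urn predictive rather than through Pyro.

```python
from math import lgamma, log


def log_multibeta(alpha):
    """Log of the multivariate Beta function B(alpha)."""
    return sum(lgamma(a) for a in alpha) - lgamma(sum(alpha))


def dirichlet_multinomial_log_prob(counts, concentration):
    """Model 2: log pmf of the fused count vector, with p integrated out."""
    n = sum(counts)
    log_coef = lgamma(n + 1) - sum(lgamma(c + 1) for c in counts)
    posterior = [a + c for a, c in zip(concentration, counts)]
    return log_coef + log_multibeta(posterior) - log_multibeta(concentration)


def sequence_log_prob(categories, concentration):
    """Model 1 with p integrated out: log p(x_1, ..., x_n) for one-hot rows,
    given as category indices and scored one at a time via the Polya urn."""
    counts = [0] * len(concentration)
    total = sum(concentration)
    result = 0.0
    for i, k in enumerate(categories):
        result += log((concentration[k] + counts[k]) / (total + i))
        counts[k] += 1
    return result
```

For example, with categories `[0, 2, 2, 1, 2]` (counts `[1, 1, 3]`), the difference between the two log-probabilities is log(5! / (1! 1! 3!)) = log(20) for any concentration, so it drops out of gradients.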