Current ways to add a custom log-prob term include (1) subclassing the ELBO (#1059, #958) and (2) using a bogus observe statement:
pyro.sample(name, dist.Bernoulli(logits=custom_term), obs=torch.tensor(1.))
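The bogus observe statement does not add custom_term itself. For Bernoulli(logits=x), log p(1) = x - softplus(x) = log sigmoid(x), so the custom term is squashed through a log-sigmoid and always comes out negative. The pure-Python sketch below checks that identity from the closed form. The helper names are hypothetical, and the sketch does not call Pyro or PyTorch.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_logits_log_prob(logit, obs):
    # log p(obs) for Bernoulli(logits=logit) with obs in {0, 1}:
    # obs * logit - log(1 + exp(logit))
    return obs * logit - softplus(logit)

def log_sigmoid(x):
    # log(1 / (1 + exp(-x))) = -softplus(-x)
    return -softplus(-x)

# What the bogus statement actually adds to the log joint for various custom terms.
for custom_term in [-3.0, 0.0, 2.5, 10.0]:
    added = bernoulli_logits_log_prob(custom_term, 1.0)
    print(f"custom_term={custom_term:5.1f}  term actually added={added:.6f}")
```

Because log sigmoid(x) saturates near 0 for large x, a large positive custom term contributes almost nothing and its gradient, 1 - sigmoid(x), vanishes.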
By contrast, Stan makes it easy to add custom terms to the log density.
Should we add a helper that adds a custom loss term, i.e. an unnormalized log-prob contributed through a hidden observe statement?
Other PPLs have a built-in factor helper, and the RSA example already defines one: https://github.com/uber/pyro/blob/dev/examples/rsa/search_inference.py#L33
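To make the proposed semantics concrete, here is a toy, pure-Python sketch of a factor helper implemented as a hidden observe statement on a "distribution" whose log_prob simply returns the supplied log weight, so the term enters the log joint unchanged. ToyTrace, UnitFactor, sample_observed and factor are hypothetical names for illustration only, not Pyro's API, and the sketch ignores tensors, batch shapes and gradients.

```python
import math

class Normal:
    """Scalar normal distribution with an exact log density."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

class UnitFactor:
    """Hypothetical distribution over a dummy value; log_prob is an
    arbitrary, unnormalized log weight."""
    def __init__(self, log_factor):
        self.log_factor = log_factor

    def log_prob(self, value):
        # The observed value is ignored; only the weight matters.
        return self.log_factor

class ToyTrace:
    """Records the log-probability contribution of every observe site."""
    def __init__(self):
        self.sites = {}

    def sample_observed(self, name, dist, obs):
        if name in self.sites:
            raise ValueError(f"duplicate site name: {name}")
        self.sites[name] = dist.log_prob(obs)
        return obs

    def factor(self, name, log_factor):
        # A factor is just a hidden observe statement on a UnitFactor.
        return self.sample_observed(name, UnitFactor(log_factor), obs=None)

    def log_joint(self):
        return sum(self.sites.values())

def model(trace, x, penalty_weight):
    trace.sample_observed("x", Normal(0.0, 1.0), obs=x)
    # Custom term added verbatim to the log joint (here a negative penalty).
    trace.factor("penalty", -penalty_weight * x * x)

trace = ToyTrace()
model(trace, x=1.0, penalty_weight=0.5)
print(trace.sites)
print(trace.log_joint())
```

Unlike the Bernoulli trick, the weight is not squashed, so the caller controls its sign and magnitude directly.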
We should also add a tutorial on writing custom losses, as suggested in #958.
@eb8680 great! What do you think about moving factor into pyro.primitives, adding docs, and using it in a couple examples (e.g. SS-VAE)?
We should write a custom FactorDistribution with the correct shape semantics and use that to implement factor. I actually don't think it's a good idiom and should be discouraged unless really necessary, so I'd rather not use it in existing examples, but maybe we could add a separate factor graph example?
Agreed, I'm not convinced it is a good idiom.
Hi everybody,
I recently started using Pyro. Thanks for the great SS-VAE tutorial; it really helped me get up to speed.
For an extension of the SS-VAE, I want to add an additional (weighted) MMD regulariser to my ELBO. Since I need samples from my variational posterior, I thought about adding a bogus statement to my model function, e.g.
pyro.sample("MMD", dist.Bernoulli(logits=self.mmd_multiplier*self.calculate_mmd(z, y)), obs=torch.tensor(1.).cuda())
Or is there a better option?
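One caveat with the snippet above: log p(1) under Bernoulli(logits=x) is log sigmoid(x), which increases with x, so maximizing the ELBO would push mmd_multiplier * calculate_mmd(z, y) up. If calculate_mmd returns a nonnegative discrepancy, that rewards a larger MMD rather than penalizing it. A penalty should instead enter the log joint unsquashed as -multiplier * MMD. The pure-Python sketch below computes a biased squared-MMD estimate with a Gaussian (RBF) kernel and the corresponding log weight. rbf_kernel, mmd_squared, mmd_log_factor and bandwidth are hypothetical names, the kernel and estimator are assumptions, and the thread does not say which two sample sets calculate_mmd compares (posterior vs. prior samples are assumed here). In a real model the same arithmetic would run on torch tensors so gradients flow.

```python
import math

def rbf_kernel(a, b, bandwidth=1.0):
    # Gaussian kernel exp(-||a - b||^2 / (2 * bandwidth^2)) on equal-length vectors.
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))

def mean_kernel(xs, ys, bandwidth):
    # Average kernel value over all pairs (x, y).
    return sum(rbf_kernel(x, y, bandwidth) for x in xs for y in ys) / (len(xs) * len(ys))

def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two sample sets."""
    return (mean_kernel(xs, xs, bandwidth)
            + mean_kernel(ys, ys, bandwidth)
            - 2.0 * mean_kernel(xs, ys, bandwidth))

def mmd_log_factor(xs, ys, multiplier, bandwidth=1.0):
    # The ELBO is maximized, so a penalty must enter with a negative sign.
    return -multiplier * mmd_squared(xs, ys, bandwidth)

posterior_samples = [(0.1, -0.2), (0.3, 0.0), (-0.1, 0.2)]
prior_samples = [(1.5, 1.0), (2.0, 1.2), (1.8, 0.7)]

print(mmd_squared(posterior_samples, posterior_samples))
print(mmd_log_factor(posterior_samples, prior_samples, multiplier=10.0))
```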