cc @OptimusLime
Hi @fritzo,
The guide will be a mixture of deltas; however, the gradient is not ELBO-based and instead relies on computing Stein kernels. These kernels, and the resulting gradient of the Kernelized Stein Discrepancy (which gives the gradient-descent direction), are reasonably straightforward to compute, provided one has access to the score function (the gradient of the log joint) over the latents.
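To make "straightforward to compute given the score" concrete, here is a minimal NumPy sketch of the Stein kernel and the squared KSD (as a U-statistic). It assumes an RBF base kernel k(x, y) = exp(-||x - y||² / h) with a fixed bandwidth `h`, and assumes the score ∇ log p has already been evaluated at every sample. The function names are hypothetical and this is not Pyro code.

```python
import numpy as np


def rbf_stein_kernel_matrix(x, score, h):
    """Stein kernel u_p(x_i, x_j) for k(x, y) = exp(-||x - y||^2 / h).

    x: (n, d) samples; score: (n, d) values of grad log p at those samples.
    """
    n, dim = x.shape
    diff = x[:, None, :] - x[None, :, :]              # diff[i, j] = x_i - x_j, shape (n, n, d)
    sq = np.sum(diff ** 2, axis=-1)                   # squared distances, shape (n, n)
    k = np.exp(-sq / h)                               # base kernel k(x_i, x_j)
    grad_x_k = -2.0 / h * diff * k[..., None]         # gradient of k w.r.t. its first argument
    grad_y_k = -grad_x_k                              # gradient of k w.r.t. its second argument
    term1 = (score @ score.T) * k                     # s(x_i)^T s(x_j) k(x_i, x_j)
    term2 = np.einsum("id,ijd->ij", score, grad_y_k)  # s(x_i)^T grad_y k(x_i, x_j)
    term3 = np.einsum("jd,ijd->ij", score, grad_x_k)  # s(x_j)^T grad_x k(x_i, x_j)
    term4 = (2.0 * dim / h - 4.0 * sq / h ** 2) * k   # trace(grad_x grad_y k)
    return term1 + term2 + term3 + term4


def ksd_squared(x, score, h=1.0):
    """Unbiased U-statistic estimate of KSD^2 (diagonal terms dropped)."""
    n = x.shape[0]
    u = rbf_stein_kernel_matrix(x, score, h)
    return (u.sum() - np.trace(u)) / (n * (n - 1))
```

For samples drawn from the target the estimate should hover around zero, and it grows as the samples drift away from the target (e.g. shifting standard-normal samples by 3 while still using the standard-normal score `-x`).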
I could give this a go myself, but I'm not well versed in how the finer points of the back end work. Have you considered adding a method that creates a function for a model's log joint as well as its gradient? Just a kind of "direct" API to this.
Perhaps easier said than done, but it would make Pyro a lot more appealing to researchers who develop new inference techniques and want to prototype and test their approaches.
@robsalomone thanks for your input. Here's one way to get a log-joint function. It takes a dictionary `vars` mapping sample site names to sample values, along with the usual inputs to the corresponding model, and returns a torch scalar containing the log-joint probability, which can be differentiated with `torch.autograd.grad` or `backward`:
```python
import torch
import pyro.poutine as poutine


def make_log_joint(model):
    def _log_joint(vars, *args, **kwargs):
        # condition every sample site on the supplied values and record a trace
        tr = poutine.trace(poutine.condition(model, data=vars)).get_trace(*args, **kwargs)
        # make sure all variables in the model are accounted for
        assert all(node["is_observed"] for node in tr.nodes.values() if node["type"] == "sample")
        # summing log-probs over all sample sites gives the log joint
        return tr.log_prob_sum()
    return _log_joint


log_joint_fn = make_log_joint(model)
logp = log_joint_fn({"x": torch.tensor(0.123, requires_grad=True), ...}, ...)
```
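Since Stein-based methods need the score rather than the log joint itself, here is a small sketch of turning a log-joint callable into a dictionary of gradients with `torch.autograd.grad`. To stay runnable without a Pyro model, it uses a hypothetical hand-written toy log joint (z ~ Normal(0, 1), y ~ Normal(z, 1) with y observed). The `score` helper is also hypothetical; it only assumes a callable mapping a dict of tensors to a scalar, so it should accept a function built by `make_log_joint` too, but treat that as an assumption.

```python
import torch
import torch.distributions as dist


def toy_log_joint(vars, y):
    # hypothetical toy model: z ~ Normal(0, 1), y_i ~ Normal(z, 1)
    z = vars["z"]
    return dist.Normal(0.0, 1.0).log_prob(z) + dist.Normal(z, 1.0).log_prob(y).sum()


def score(log_joint_fn, vars, *args, **kwargs):
    """Return {site name: grad of the log joint w.r.t. that site's value}."""
    names = list(vars)
    # fresh leaf tensors so gradients are taken w.r.t. exactly these values
    leaves = {k: v.detach().requires_grad_(True) for k, v in vars.items()}
    logp = log_joint_fn(leaves, *args, **kwargs)
    grads = torch.autograd.grad(logp, [leaves[k] for k in names])
    return dict(zip(names, grads))


g = score(toy_log_joint, {"z": torch.tensor(0.5)}, torch.tensor([1.0, 2.0]))
# analytic score: -z + sum(y - z) = -0.5 + 2.0 = 1.5
```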
We completely agree with the broader point about better support and documentation of internals for research users, and we're planning to prioritize that much more going forward.
Thanks for that @eb8680 ! I might try to do up a module for computing Stein kernel things down the line. I am interested in research in this direction so it would be great to do these things in Pyro.
What's the status on this?
Recently, I've been playing with the idea of using SVGD (or more generally, a discretized diffusion process, such as SGLD and variants) as part of the guide via a mixture of Deltas.
I could try to work on it if it's possible!
Hi @vicgalle, unfortunately I've been very busy and haven't tried it yet, so it's all yours! I'm happy to help out if you have any issues.
Regarding your idea, not sure if you realise, but there are similar ideas that use SVGD to optimise the parameters of samplers like SGLD; it's called Amortized SVGD (see this paper). It would be great if you could implement that method too!
The SVGD algorithm is pretty simple once you get the gradient of the log joint. I'd also recommend implementing SVGD with Complete Conditional Stein Discrepancy.
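As an illustration of how little is needed once the score is available, here is a minimal NumPy sketch of plain SVGD (not the complete conditional variant). It assumes an RBF kernel k(x, y) = exp(-||x - y||² / h) with the bandwidth set by the median heuristic at every step, and a user-supplied `score_fn` returning ∇ log p for each particle. `svgd_direction` and `svgd` are hypothetical names, not a Pyro API.

```python
import numpy as np


def svgd_direction(x, score_fn):
    """phi(x_i) = 1/n * sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]           # diff[i, j] = x_i - x_j
    sq = np.sum(diff ** 2, axis=-1)                # pairwise squared distances
    h = max(np.median(sq) / np.log(n + 1.0), 1e-8)  # median heuristic bandwidth
    k = np.exp(-sq / h)                            # symmetric RBF kernel matrix
    drive = k @ score_fn(x)                        # kernel-weighted scores: pulls toward high density
    repulse = (2.0 / h) * np.sum(diff * k[..., None], axis=1)  # sum_j grad_{x_j} k: pushes particles apart
    return (drive + repulse) / n


def svgd(x0, score_fn, step=0.05, n_iter=1000):
    """Move particles x0 (shape (n, d)) along the SVGD direction with a fixed step size."""
    x = x0.copy()
    for _ in range(n_iter):
        x = x + step * svgd_direction(x, score_fn)
    return x


# usage: transport standard-normal particles toward N(2, 1), whose score is -(z - 2)
rng = np.random.default_rng(0)
particles = svgd(rng.normal(size=(100, 1)), lambda z: -(z - 2.0), step=0.05, n_iter=2000)
```

Note that nothing here evaluates a loss; each step only applies the update direction, which is why SVGD does not map directly onto an ELBO-style `loss` method.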
Functionality-wise, I think the challenge will be fitting SVGD mechanics into Pyro, as the method doesn't do SVI in the classical sense (there is no explicit objective; you just have approximate gradient directions for optimising over the RKHS).
I believe @OptimusLime is also working on a Pyro implementation of SVGD. IIRC he plans to create a new non-ELBO loss function.
A basic version of SVGD (along with the complete conditional variant) has been implemented in #1991.
Closing this general issue in favor of more targeted issues for any additional feature requests, bug reports, etc.