Hi,
thanks for releasing Pyro.
Any plans to support GAN-like inference in Pyro?
Thanks
Hi @sameerkhurana10 we've discussed adding a discriminator example following Mohamed, Lakshminarayanan (2016) and Tran, Ranganath, Blei (2017), but we have no immediate plans. If you'd like to add something yourself, @eb8680 or @martinjankowiak could suggest where to start.
thanks @fritzo
I could have a go at it. Some pointers would be great.
@sameerkhurana10 most such algorithms should be pretty easy to implement with Pyro's existing tools (though if you find that's not the case, feel free to open another issue!). In my admittedly limited experience, the algorithms themselves (independent of Pyro) are extremely brittle and sensitive to hyperparameters, so there's not a compelling reason for us to add whole algorithms to Pyro instead of just providing tools for implementing them concisely. That said, we'd definitely welcome any contributions of examples or tutorials, or even a contributed library like pyro.contrib.gp.
Here's an almost-complete idiomatic Pyro implementation of a VAE with an implicit variational distribution and an Adversarial Variational Bayes loss optimized with simultaneous gradient descent, which seems like the simplest GAN inference variant:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine
from pyro.infer.util import zero_grads

# only these and the data are missing
decoder = nn.Sequential(...)  # should map z to (loc, scale) of p(x|z)
encoder = nn.Sequential(...)  # should have some internal randomness, e.g. a call to torch.randn()
discriminator = nn.Sequential(...)  # T(x, z): maps cat([z, x]) to one logit per data point

def model(batch_size):
    decoder_net = pyro.module("decoder", decoder)
    with pyro.plate("data", batch_size):
        z = pyro.sample("z", dist.Normal(torch.zeros(10), torch.ones(10)).to_event(1))
        loc_h, scale_h = decoder_net(z)
        return pyro.sample("x", dist.Normal(loc_h, scale_h).to_event(1))

def constrained_model(x):
    # run the model with site "x" conditioned on the observed batch
    return poutine.condition(model, data={"x": x})(x.shape[0])

def guide(x):
    # the encoder's internal randomness is not exposed to Pyro, so q(z|x) is implicit;
    # record its output at site "z" with a Delta so it can be replayed into the model
    z = pyro.module("encoder", encoder)(x)
    with pyro.plate("data", x.shape[0]):
        return pyro.sample("z", dist.Delta(z).to_event(1))

def loss(model, guide, x):
    disc = pyro.module("discriminator", discriminator)
    guide_tr = poutine.trace(guide).get_trace(x)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x)
    prior_tr = poutine.trace(model).get_trace(x)
    z_q = guide_tr.nodes["z"]["value"]  # z ~ q(z|x)
    z_p = prior_tr.nodes["z"]["value"]  # z ~ p(z)
    # main loss: negative AVB objective E_q[log p(x|z) - T(x, z)]
    log_lik = model_tr.nodes["x"]["fn"].log_prob(x).sum()
    main_loss = -(log_lik - disc(torch.cat([z_q, x], -1)).sum())
    # discriminator loss: logistic loss with posterior samples labelled 1 and prior samples 0;
    # z_q is detached so this loss does not backpropagate into the encoder
    t_q = disc(torch.cat([z_q.detach(), x], -1))
    t_p = disc(torch.cat([z_p, x], -1))
    aux_loss = -(F.logsigmoid(t_q).sum() + F.logsigmoid(-t_p).sum())
    return main_loss, aux_loss

main_optim = pyro.optim.Adam({"lr": 0.001})
aux_optim = pyro.optim.Adam({"lr": 0.001})
...  # load data
for batch in data:
    with poutine.trace(param_only=True) as param_capture:
        main_loss, aux_loss = loss(constrained_model, guide, batch)
    params = {name: site["value"].unconstrained()
              for name, site in param_capture.trace.nodes.items()}
    # since encoder/decoder/discriminator are nn.Modules, could also use their .parameters()
    aux_params = [p for name, p in params.items() if name.startswith("discriminator")]
    main_params = [p for name, p in params.items() if not name.startswith("discriminator")]
    zero_grads(main_params + aux_params)
    main_loss.backward()
    main_optim(main_params)
    # main_loss also sent gradients to the discriminator; discard them before its own update
    zero_grads(aux_params)
    aux_loss.backward()
    aux_optim(aux_params)
```
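The reason `log p(x|z) - T(x, z)` can stand in for the ELBO is that a logistic discriminator trained on equally many posterior and prior samples has, at its optimum, logit T = log q(z|x) - log p(z) (the key result of the Adversarial Variational Bayes paper). The dependency-free sketch below checks this numerically; the two 1-D Gaussians standing in for q and p, the quadratic features, and the use of Newton's method are illustrative assumptions, not part of the example above.

```python
import math
import random


def log_normal(z, mean, std):
    """Log density of N(mean, std**2) evaluated at z."""
    return -math.log(std) - 0.5 * math.log(2.0 * math.pi) - 0.5 * ((z - mean) / std) ** 2


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def solve3(a, b):
    """Solve the 3x3 system a @ x = b by Gaussian elimination with partial pivoting."""
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]
    for col in range(3):
        pivot = max(range(col, 3), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, 3):
            f = m[r][col] / m[col][col]
            for c in range(col, 4):
                m[r][c] -= f * m[col][c]
    x = [0.0, 0.0, 0.0]
    for r in (2, 1, 0):
        x[r] = (m[r][3] - sum(m[r][c] * x[c] for c in range(r + 1, 3))) / m[r][r]
    return x


def fit_discriminator(q_samples, p_samples, iters=15):
    """Fit T(z) = w0 + w1*z + w2*z**2 by logistic regression (q-samples labelled 1,
    p-samples labelled 0) using Newton's method on the log-likelihood."""
    data = [(z, 1.0) for z in q_samples] + [(z, 0.0) for z in p_samples]
    w = [0.0, 0.0, 0.0]
    for _ in range(iters):
        grad = [0.0, 0.0, 0.0]
        hess = [[0.0] * 3 for _ in range(3)]
        for z, y in data:
            phi = (1.0, z, z * z)
            s = sigmoid(w[0] + w[1] * z + w[2] * z * z)
            for i in range(3):
                grad[i] += (y - s) * phi[i]
                for j in range(3):
                    hess[i][j] += s * (1.0 - s) * phi[i] * phi[j]
        # Newton step: w <- w + (sum of s(1-s) phi phi^T)^{-1} grad
        step = solve3(hess, grad)
        w = [wi + si for wi, si in zip(w, step)]
    return w


def demo(n=5000, seed=0):
    """Compare the fitted discriminator with the true log ratio log q(z) - log p(z)
    for q = N(1, 0.5**2) (stand-in for the guide) and p = N(0, 1) (stand-in for the prior)."""
    rng = random.Random(seed)
    q_samples = [rng.gauss(1.0, 0.5) for _ in range(n)]
    p_samples = [rng.gauss(0.0, 1.0) for _ in range(n)]
    w = fit_discriminator(q_samples, p_samples)
    results = []
    for z in (0.0, 0.5, 1.0):
        estimate = w[0] + w[1] * z + w[2] * z * z
        truth = log_normal(z, 1.0, 0.5) - log_normal(z, 0.0, 1.0)
        results.append((z, estimate, truth))
    return results


if __name__ == "__main__":
    for z, estimate, truth in demo():
        print(f"z={z:.1f}  T(z)={estimate:.3f}  log q - log p={truth:.3f}")
```

Because the log ratio of two Gaussians is exactly quadratic in z, the quadratic discriminator can represent it; with a neural-network encoder no such guarantee exists, which is part of why these methods are hard to tune.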
Great, thanks @eb8680
this should be very helpful.
Hi sameerkhurana,
have you been able to get this to work?
Sorry, I didn't have time to work on it and won't be able to get to it anytime soon.
@eb8680 @sameerkhurana10 Can I take this issue? If so, am I right that all I need to implement is a new class inheriting from ELBO (say, AdversarialELBO) that introduces a classifier for every latent variable?
Sure, go for it! To make it easier to get started, I would recommend first
implementing a self-contained example taken directly from a single paper
rather than a general piece of machinery - maybe the first toy example in
the Adversarial Variational Bayes paper?
@eb8680 I have read the discussion here once again; am I right that there is actually no need for separate functionality, such as an AdversarialELBO class, but rather for an example of using adversarial methods in Pyro?
Looking at your example, it seems that computing a KL divergence with a classifier requires the user to "dig into the guts" of Pyro; by this I mean the user has to use at least the pyro.poutine module, which seems quite low-level. Maybe it would be convenient to provide a class that implements the methods expected_log_likelihood, entropy and cross_entropy separately, together with the methods get_prior_samples and get_posterior_samples. Such a class would also be convenient for KL annealing, which is a popular technique.
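For a guide with a tractable density, the proposed terms combine as KL(q || p) = cross_entropy(q, p) - entropy(q), and KL annealing simply down-weights that term early in training. A minimal plain-Python sketch, assuming a linear warm-up schedule; `kl_weight` and `annealed_elbo` are hypothetical helper names, not Pyro APIs. With an implicit guide like the encoder above, the entropy term has no closed form, which is why the adversarial loss estimates the KL with a discriminator instead.

```python
def kl_weight(step, warmup_steps):
    """Linear KL-annealing schedule: the KL weight ramps from 0 to 1 over warmup_steps."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)


def annealed_elbo(expected_log_likelihood, entropy, cross_entropy, beta):
    """ELBO with a down-weighted KL term.

    KL(q || p) = E_q[log q(z)] - E_q[log p(z)] = cross_entropy(q, p) - entropy(q),
    so the annealed objective is E_q[log p(x|z)] - beta * KL(q || p).
    """
    kl = cross_entropy - entropy
    return expected_log_likelihood - beta * kl


if __name__ == "__main__":
    for step in (0, 500, 1000, 2000):
        beta = kl_weight(step, warmup_steps=1000)
        print(step, beta, annealed_elbo(-100.0, entropy=5.0, cross_entropy=15.0, beta=beta))
```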
@varenick it would be easy enough to turn my example code above into a simple generic AdversarialELBO class. I suppose I'm biased, but it seems pretty readable to me, especially if paired with our custom SVI objectives tutorial :)
I'm not opposed to having an AdversarialELBO class, but starting with an example makes sense to me for two reasons: first, because it will help familiarize you with the APIs you'd use to write a generic version, and second, you'll get a better sense for how difficult these algorithms are to tune, especially if you haven't worked with them before. You'll find that if you dig into the gritty details of even the simplest nontrivial examples, like the MNIST experiments in the Adversarial Variational Bayes paper, there are always several layers of hacks ("Adaptive Contrast" in this case) required to get the optimization algorithm to converge that may not generalize usefully beyond those examples in practice. Starting with an example or two would help guide the design of a generic implementation that works more reliably.
Re: adding various ELBO term methods, that's an interesting idea but seems distinct from the discussion here. Feel free to open a separate issue to discuss further.
@eb8680: I've worked a fair bit on implementing this paper, using Pytorch, on a reasonably sized dataset (100,000 dims, 5000-50000 samples). It involves a GAN-like objective and ties in nicely with a causal inference in genomics problem. I've been meaning to port my code to Pyro but may need some help since I'm a beginner with PPLs. Mind if I have a go?
Also, this could be a nice project for GSoC 2020.
@deepaks4077 sure! I'd encourage you to first make sure your existing code is correct and produces the results you expect on your data. Once you've done that, it should be easy to try porting it to Pyro following the discussion in this issue and in our custom objectives tutorial. If you get stuck, please don't hesitate to ask questions or open a PR with incomplete or incorrect code so that we can help you get it finished.
@eb8680 : Great, I'll have a go at this soon. Does pyro have anything akin to ImplicitKLqp in Edward 1?
Edit: Never mind, I believe that is what we are trying to implement here.