Hi all,
I have been trying to write a custom loss function (essentially ELBO with an extra regularizer) to use with SVI. Which class should it inherit from and what methods should it include?
So far I've tried inheriting from nn.Module and implementing __init__() and forward() methods, but I cannot get it to work. What should the arguments to the loss function be? The model, the guide and the data?
Thank you in advance for any help!
@yhalk can you please specify more exactly what you're after? What kind of regularizer?
Thanks @martinjankowiak
I want to train a Bayesian network for classification and would like to try a weighted sum of two kinds of targets for each training sample: a soft (continuous) label and a hard (binary) label. The soft label will be used with the ELBO, and I'd like to add a loss (possibly cross-entropy) for the binary label. Is that possible?
I'm also trying to add custom losses in addition to ELBO. There are some workarounds right now but it's a bit confusing.
Another problem is that with svi.step() I can't (or don't know how to) update only specific parameters. For example, in GANs I need to update the discriminator and generator parameters separately. For now, the easiest and most intuitive approach for me is to keep using the standard PyTorch .backward() and optimizer.step() for the additional losses. Several questions about this:
- svi.step() already does a forward pass. Do I need to do it again, or is there a way to get some variables in model() and guide() so that I can calculate the loss directly from them and take a gradient step?
- … pyro.optim.Adam for SVI, and torch.optim.Adam for the other losses (to specify the parameters that will be updated). I feel like it makes more sense to use the same optimizer. Is there a way around this?

@jthsieh I will try to answer based on my limited knowledge. Basically, you can modify SVI to achieve your target.

- … loss_and_grads on the ELBO. Then call .backward() on your custom loss before any updating step.
- … SVI.optim to update parameters of a model/guide trace. You can use SVI.optim to update whatever parameters you want. You can get these param instances (to feed into the Pyro optimizer) from pyro.param(param_name).unconstrained().
- … net(x) and don't have access to intermediate computations unless you split your net into smaller modules. In Pyro, you have quick access to primitive nodes from a trace, so you can simplify your computation by using them.

@jthsieh the update currently computed by SVI.step is very simple: it just applies a single step of a PyTorch optimizer to all of the parameters that appear in the loss.
pyro.optim is itself just a set of programmatically generated wrappers for torch.optim optimizers that makes sure optimizer state/hyperparameters and parameters are matched correctly across forward passes, since parameters can potentially appear out of order or not at all depending on the values of random samples in the model.
If you want to do something less vanilla it's probably easier to write your own loss and/or optimization loop using pyro.poutine, Pyro's lower-level tools for manipulating models. If you can give examples with code we can probably help with that; in the meantime here's a sketch of an example of a simple custom loss function and optimization loop that satisfies all three of your requirements above:
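```python
import torch
import pyro
import pyro.optim
from pyro import poutine
from pyro.infer.util import zero_grads


def mc_elbo_with_l2(model_trace, guide_trace, discriminator, lam=0.1):
    # Single-sample Monte Carlo estimate of -ELBO = log q(z) - log p(x, z)
    logp = model_trace.log_prob_sum()
    logq = guide_trace.log_prob_sum()
    # L2 penalty on every pyro.param recorded in either trace
    penalty = 0.
    for tr in (model_trace, guide_trace):
        for node in tr.nodes.values():
            if node["type"] == "param":
                penalty = penalty + lam * torch.sum(torch.pow(node["value"], 2))
    # discriminator: any callable mapping the two traces to a scalar tensor
    discriminator_term = discriminator(model_trace, guide_trace)
    return logq - logp + penalty + discriminator_term


def aux_loss(model_trace, guide_trace, discriminator):
    # Something with discriminators. Read trace values via node["value"].detach()
    # so this loss trains only the discriminator and does not backprop into
    # the graph already freed by main_loss_val.backward().
    loss = ...
    return loss


optimizer = pyro.optim.Adam(...)  # optimizer args, e.g. {"lr": 1e-3}

for minibatch in dataset:
    # Run the guide, then replay the model against the guide's samples
    guide_trace = poutine.trace(guide).get_trace(minibatch)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(minibatch)

    # Main step: update only the model/guide params that appear in the traces
    main_loss_val = mc_elbo_with_l2(model_trace, guide_trace, discriminator, lam=0.1)
    main_loss_val.backward()
    main_params = set(node["value"].unconstrained()
                      for tr in (model_trace, guide_trace)
                      for node in tr.nodes.values()
                      if node["type"] == "param")
    optimizer(main_params)
    # discriminator_term also left gradients on the discriminator; clear both
    zero_grads(list(main_params) + list(discriminator.parameters()))
    ...  # reset state, plot stuff etc.

    # Auxiliary step: update only the discriminator
    aux_loss_val = aux_loss(model_trace, guide_trace, pyro.module("disc", discriminator))
    aux_loss_val.backward()
    aux_params = list(discriminator.parameters())
    optimizer(aux_params)
    zero_grads(aux_params)
    ...  # reset state, plot stuff etc.
```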
@jthsieh Please follow @eb8680's comment. My answer might miss some important aspects.
Thank you so much for your help, @fehiepsi and @eb8680!! I'm not familiar with pyro.poutine, but I'll definitely spend some time learning it.
What I'm doing is similar to a VAE, but I want to try adding a discriminator on top of the output. (By the way, I realized loss='ELBO' is deprecated, but I haven't updated Pyro.) My code is fairly simple, so I probably don't need the more advanced pyro.poutine for now. I'll try out @fehiepsi's answer, using loss_and_grads. The following is a simplified version of the code I have right now:
```python
import torch
import pyro
from pyro import optim
from pyro.infer import SVI, Trace_ELBO


class Model:
    def __init__(self):
        self.netG = Generator(...)
        self.netD = Discriminator(...)
        # Pyro optimizer and SVI (Trace_ELBO() replaces the deprecated loss='ELBO')
        self.optimizer = optim.Adam({'lr': 1e-3})
        self.svi = SVI(self.model, self.guide, self.optimizer, loss=Trace_ELBO())
        # Additional losses
        self.optimizer_g = torch.optim.Adam(self.netG.parameters(), lr=1e-3, betas=(0.5, 0.999))
        self.optimizer_d = torch.optim.Adam(self.netD.parameters(), lr=1e-3, betas=(0.5, 0.999))

    def model(self, input):
        # Register the generator's parameters with Pyro
        pyro.module('generator', self.netG)
        ...

    def guide(self, input):
        ...

    def train(self, minibatch):
        # Update parameters for one minibatch
        ...
        self.svi.step(minibatch)

        # Additional losses
        z = ...
        output = self.netG(z)

        # Update discriminator (detach so no gradient reaches the generator)
        self.optimizer_d.zero_grad()
        out_d = self.netD(output.detach())
        loss_d = self.criterion(out_d, ...)
        loss_d.backward()
        self.optimizer_d.step()

        # Update generator
        self.optimizer_g.zero_grad()
        out_g = self.netD(output)
        loss_g = self.criterion(out_g, ...)
        loss_g.backward()
        self.optimizer_g.step()
```
@martinjankowiak @jpchen can we close this now that #1592 has merged?
Yeah, I think it's no longer a release blocker; additional examples/fixes would be minor.