I saw the following code in the backward_G method of cycle_gan_model:
```python
# (lambda_A, lambda_B, self.loss_idt_A and self.loss_idt_B are defined earlier
# in the method; that part is not shown in this excerpt)
# GAN loss
# D_A(G_A(A))
self.fake_B = self.netG_A.forward(self.real_A)
pred_fake = self.netD_A.forward(self.fake_B)
self.loss_G_A = self.criterionGAN(pred_fake, True)
# D_B(G_B(B))
self.fake_A = self.netG_B.forward(self.real_B)
pred_fake = self.netD_B.forward(self.fake_A)
self.loss_G_B = self.criterionGAN(pred_fake, True)
# Forward cycle loss
self.rec_A = self.netG_B.forward(self.fake_B)
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss
self.rec_B = self.netG_A.forward(self.fake_A)
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()
```
The way I see it, G_A and G_B each have three forward passes (counting the identity-loss pass, which isn't shown above): two on real data and one on fake data.
In TensorFlow (I think), the backward pass is always computed w.r.t. the most recent input data. If that applied here, backpropagating loss_G would give wrong gradients, and one would instead have to run three backward passes, each immediately after the forward pass it involves.
I assume PyTorch somehow takes care of this. But how does the model know with respect to which input data it should compute the gradients?
That is the beauty of PyTorch: autograd handles this automatically.
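A minimal sketch of what that means for backward_G, assuming toy single-layer stand-ins for the four networks. The names G_A, G_B, D_A, D_B, the least-squares GAN loss, the 0.5 identity weight and the lambda values are all assumptions for illustration, not the repository's code. It checks that one backward() on the summed loss gives the same generator gradients as calling backward() separately on each loss term, even though each generator is called three times:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the CycleGAN networks: single linear layers
# on 4-dimensional "images" (the real model uses conv nets).
G_A, G_B = nn.Linear(4, 4), nn.Linear(4, 4)  # A -> B and B -> A
D_A, D_B = nn.Linear(4, 1), nn.Linear(4, 1)
mse, l1 = nn.MSELoss(), nn.L1Loss()  # assumed least-squares GAN loss, L1 cycle loss
lambda_A = lambda_B = 10.0           # assumed weights

real_A, real_B = torch.randn(8, 4), torch.randn(8, 4)

def generator_losses():
    """Fresh forward passes; G_A and G_B are each called three times."""
    fake_B = G_A(real_A)
    fake_A = G_B(real_B)
    pred_B, pred_A = D_A(fake_B), D_B(fake_A)
    loss_G_A = mse(pred_B, torch.ones_like(pred_B))
    loss_G_B = mse(pred_A, torch.ones_like(pred_A))
    loss_cycle_A = l1(G_B(fake_B), real_A) * lambda_A
    loss_cycle_B = l1(G_A(fake_A), real_B) * lambda_B
    # Identity losses: each generator also sees a real image from its output domain
    loss_idt_A = l1(G_A(real_B), real_B) * lambda_B * 0.5
    loss_idt_B = l1(G_B(real_A), real_A) * lambda_A * 0.5
    return [loss_G_A, loss_G_B, loss_cycle_A, loss_cycle_B, loss_idt_A, loss_idt_B]

def grads(module):
    return [p.grad.clone() for p in module.parameters()]

def zero_all():
    for m in (G_A, G_B, D_A, D_B):
        m.zero_grad(set_to_none=True)

# Strategy 1: one backward() on the summed loss, as in backward_G.
zero_all()
sum(generator_losses()).backward()
joint = grads(G_A) + grads(G_B)

# Strategy 2: backward() on each term separately; .grad accumulates across calls.
zero_all()
for term in generator_losses():
    term.backward(retain_graph=True)  # terms share parts of the graph
separate = grads(G_A) + grads(G_B)

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(joint, separate)))
```

The separate-backward loop needs retain_graph=True only because several terms share graph nodes (fake_B feeds both the GAN loss and the forward cycle loss). The summed version avoids that and walks the graph once.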
Is it really true that PyTorch autograd handles this?
This post seems to say otherwise: https://discuss.pytorch.org/t/how-to-use-the-backward-functions-for-multiple-losses/1826/5
I think this post helps clarify the issue.
But don't repeated forward passes through the same network overwrite the intermediate activations? I don't know whether autograd keeps references to earlier values of a variable or copies of them. If it keeps references, they could be overwritten by each new forward pass, making the gradients wrong.
@AAnoosheh A forward pass over the same network doesn't overwrite the saved variable values. Calling the same module twice creates two separate links in the dynamic graph, and the tensors saved for the backward pass are stored separately for each call, so they don't overwrite each other.
See my reply at https://discuss.pytorch.org/t/how-to-use-the-backward-functions-for-multiple-losses/1826/7?u=simonw
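To illustrate the point about saved values, here is a small sketch with a hypothetical one-weight network y = tanh(w * x) called on two different inputs. Tanh keeps its output for the backward pass, so if the second call overwrote what the first call saved, the gradient would not match the hand-computed value:

```python
import torch
import torch.nn as nn

# Hypothetical tiny "network": y = tanh(w * x) with a single scalar weight w.
net = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Tanh())

x1 = torch.tensor([[0.5]])
x2 = torch.tensor([[-2.0]])

y1 = net(x1)  # first call: its own set of graph nodes and saved tensors
y2 = net(x2)  # second call: a separate set of nodes, not a reuse of the first
print(y1.grad_fn is y2.grad_fn)  # False: two distinct links in the dynamic graph

(y1 + y2).sum().backward()

# Analytic gradient: d/dw tanh(w x) = (1 - tanh(w x)^2) * x, summed over both calls.
w = net[0].weight.detach()
expected = sum((1 - torch.tanh(w * x) ** 2) * x for x in (x1, x2))
print(torch.allclose(net[0].weight.grad, expected))  # True
```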
Ah, okay. So the OP's change in that discussion must have failed because the original version updated the discriminator on the real input each time before forwarding the fake input.
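A minimal sketch of why those two orders are not equivalent, assuming a toy linear discriminator, a least-squares loss and plain SGD (all hypothetical). Stepping on the real batch before forwarding the fake batch evaluates the fake loss at different weights than summing both losses and stepping once:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
real, fake = torch.randn(8, 4), torch.randn(8, 4)
mse = nn.MSELoss()

def d_loss(D, x, target):
    """Least-squares loss pushing D(x) toward `target` (1 = real, 0 = fake)."""
    pred = D(x)
    return mse(pred, torch.full_like(pred, target))

D1 = nn.Linear(4, 1)    # hypothetical toy discriminator
D2 = copy.deepcopy(D1)  # identical starting weights for the second schedule

# Schedule 1: update on the real batch, then forward the fake batch
# through the already-updated D and update again.
opt1 = torch.optim.SGD(D1.parameters(), lr=0.1)
for x, target in ((real, 1.0), (fake, 0.0)):
    opt1.zero_grad()
    d_loss(D1, x, target).backward()
    opt1.step()

# Schedule 2: both losses at the same weights, one backward() and one step.
opt2 = torch.optim.SGD(D2.parameters(), lr=0.1)
opt2.zero_grad()
(d_loss(D2, real, 1.0) + d_loss(D2, fake, 0.0)).backward()
opt2.step()

print(torch.allclose(D1.weight, D2.weight))  # False
```

This only shows that the two schedules end at different parameters; it says nothing about which one trains better.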
Thanks for the rapid support, guys!