Pytorch-cyclegan-and-pix2pix: Regarding the backward process for G and D

Created on 26 Sep 2018 · 7 Comments · Source: junyanz/pytorch-CycleGAN-and-pix2pix

Hi, thank you for your great work! @junyanz

I have a (dumb) question regarding how you do parameter optimization in the CycleGAN model:
Why do you write the backward functions for D_A and D_B separately (backward_D_A() and backward_D_B()), but put those of G_A and G_B in a single function, backward_G()? Is there a particular reason for this?

Would it be problematic if I wrote separate backward functions for G_A and G_B, as you do for D_A and D_B? I am thinking of writing something like 'backward_G_basic' to make the code more compact.

Thank you!

All 7 comments

I would like to add to the question above:
1) Could we then simply do something like this:

self.loss_D = self.loss_D_A + self.loss_D_B
self.loss_D.backward()

2) Why do we set requires_grad to False for the Ds (line below) if we use separate optimizers for the Ds and Gs anyway?
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/fdc7fcd1421ce386c57c6bdb9db09fcb0d22221a/models/cycle_gan_model.py#L140

Thanks in advance!

@maxdel

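For your second question, I believe setting 'requires_grad' to False for the Ds is done for speed, not correctness: during the generator step it spares autograd from computing gradients for the discriminator weights, which would be discarded anyway. Here is a link to a similar question that may be helpful (https://github.com/pytorch/examples/issues/116).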
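To make the speed argument concrete, here is a minimal sketch using toy linear layers in place of the real networks. The helper `set_requires_grad` below is written for this illustration, and the squared-error loss is only a placeholder for the GAN criterion. It shows that with the discriminator frozen, gradients still flow through it to the generator, while no gradients are stored for the discriminator's own weights.

```python
import torch
import torch.nn as nn

def set_requires_grad(nets, requires_grad):
    """Illustrative helper: toggle gradient tracking for all parameters of the given nets."""
    for net in nets:
        for p in net.parameters():
            p.requires_grad = requires_grad

torch.manual_seed(0)
netG = nn.Linear(4, 4)   # toy stand-in for a generator (assumption)
netD = nn.Linear(4, 1)   # toy stand-in for a discriminator (assumption)
real = torch.randn(8, 4)

# Generator step: freeze D so autograd skips gradients for D's weights.
set_requires_grad([netD], False)
loss_G = ((netD(netG(real)) - 1.0) ** 2).mean()  # placeholder "fool D" loss
loss_G.backward()

# G received gradients through D; D's own parameters did not get any.
g_has_grad = all(p.grad is not None for p in netG.parameters())
d_has_grad = any(p.grad is not None for p in netD.parameters())

# Unfreeze D again before its own update step.
set_requires_grad([netD], True)
```

Leaving D unfrozen would give the same generator gradients; the extra discriminator gradients would simply be computed and then cleared before the discriminator step.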

And for the first question, my guess is YES, but I am wondering whether there is any particular reason they did not write the code this way.
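A quick way to check this guess is to compare the two variants on toy networks. The sketch below assumes two independent linear "discriminators" and a placeholder squared-output loss, neither of which comes from the repository. Because `.backward()` accumulates into `.grad`, two separate calls and one call on the sum should leave the same gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netD_A = nn.Linear(4, 1)  # toy stand-ins for the two discriminators (assumption)
netD_B = nn.Linear(4, 1)
x_a, x_b = torch.randn(8, 4), torch.randn(8, 4)

def d_losses():
    # Placeholder losses; the real code uses criterionGAN on real and fake images.
    return (netD_A(x_a) ** 2).mean(), (netD_B(x_b) ** 2).mean()

def collect_grads():
    return [p.grad.clone() for net in (netD_A, netD_B) for p in net.parameters()]

def zero_grads():
    for net in (netD_A, netD_B):
        net.zero_grad()

# Variant 1: one backward call per discriminator loss.
zero_grads()
loss_D_A, loss_D_B = d_losses()
loss_D_A.backward()
loss_D_B.backward()
separate = collect_grads()

# Variant 2: a single backward call on the summed loss.
zero_grads()
loss_D_A, loss_D_B = d_losses()
(loss_D_A + loss_D_B).backward()
summed = collect_grads()

same = all(torch.allclose(a, b) for a, b in zip(separate, summed))
```

The two discriminator graphs share no intermediate tensors here, which is why the separate calls work without `retain_graph=True`.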

It's just some coding trick. The difference is that (1) we can update D_A and D_B independently and we can write a backward_D_basic and reuse it to make the code more compact. (2) we have to update G_A and G_B at the same time due to the cycle consistency loss.

On (2), it doesn't seem like the code needs to update G_A and G_B in the same function?

 # GAN loss D_A(G_A(A))
 self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
 # GAN loss D_B(G_B(B))
 self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
 # Forward cycle loss
 self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
 # Backward cycle loss
 self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

Where is the dependence in the cycle consistency loss?

The code updates them at the same time. See these lines.

self.fake_B = self.netG_A(self.real_A)
self.rec_A = self.netG_B(self.fake_B)
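The dependence in these two lines can be checked with a small sketch. The linear layers and the L1 loss below are stand-ins chosen for illustration, not the repository's networks. Backpropagating the forward cycle loss alone should populate gradients in both generators, because `rec_A` is computed by passing the output of one generator through the other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netG_A = nn.Linear(4, 4)  # toy stand-ins for the two generators (assumption)
netG_B = nn.Linear(4, 4)
real_A = torch.randn(8, 4)

# Same chaining as in the quoted lines: A -> fake B -> reconstructed A.
fake_B = netG_A(real_A)
rec_A = netG_B(fake_B)

# The forward cycle loss depends on the parameters of both generators.
loss_cycle_A = nn.functional.l1_loss(rec_A, real_A)
loss_cycle_A.backward()

grad_norm_G_A = netG_A.weight.grad.norm().item()
grad_norm_G_B = netG_B.weight.grad.norm().item()
```

Both gradient norms come out non-zero, so a single cycle term already contributes to the update of both G_A and G_B.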

There is definitely a dependence in the forward pass, which is what you linked to.

But I'm not sure how that would affect the order in which the backward pass is done.

I don't see how the current approach would produce a different result from doing it separately:

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

vs

func1
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

func2
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

The losses are just accumulated, so the order shouldn't matter, right? Unless I'm missing something about how PyTorch works.

self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()

I was able to figure it out. Indeed, I just didn't understand how PyTorch works.
The answer is that once you call loss.backward() on one loss, PyTorch frees the intermediate buffers of the graph it traversed (the generators themselves stay in memory). If a second loss shares any of those intermediate tensors, for example fake_B feeding both a GAN loss and a cycle loss, its backward() call fails unless the first call was made with retain_graph=True.
Thus, if we compute the losses in the same function and call .backward() once on their sum, we don't have to deal with retain_graph.
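The behaviour described above can be reproduced with a minimal sketch. The toy linear generators and the placeholder `loss_gan` term are assumptions made for illustration. The two losses share the intermediate tensor `fake_B`, so a second plain `.backward()` raises a `RuntimeError`, while `retain_graph=True` or a single call on the sum both work and give matching gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netG_A, netG_B = nn.Linear(4, 4), nn.Linear(4, 4)  # toy generators (assumption)
real_A = torch.randn(8, 4)

def build_losses():
    fake_B = netG_A(real_A)
    rec_A = netG_B(fake_B)
    loss_gan = fake_B.pow(2).mean()                     # placeholder for the GAN term
    loss_cycle = nn.functional.l1_loss(rec_A, real_A)   # also depends on fake_B
    return loss_gan, loss_cycle

def zero_grads():
    netG_A.zero_grad()
    netG_B.zero_grad()

# Attempt 1: two plain backward calls through a shared graph.
loss_gan, loss_cycle = build_losses()
loss_gan.backward()            # frees the buffers behind fake_B
try:
    loss_cycle.backward()      # needs those buffers again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Attempt 2: keep the graph alive for the second call.
zero_grads()
loss_gan, loss_cycle = build_losses()
loss_gan.backward(retain_graph=True)
loss_cycle.backward()
grad_retained = netG_A.weight.grad.clone()

# Attempt 3: sum the losses and backpropagate once.
zero_grads()
loss_gan, loss_cycle = build_losses()
(loss_gan + loss_cycle).backward()
grad_summed = netG_A.weight.grad.clone()
```

Summing first traverses the shared part of the graph only once, which is why it avoids both the error and the extra memory held by `retain_graph=True`.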
