I am trying to implement CycleGAN (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) with mxnet. I need two generator networks to translate styles between domains A and B, and the cycle loss is:
inputA -> GeneratorB -> fakeB -> GeneratorA -> A'
inputB -> GeneratorA -> fakeA -> GeneratorB -> B'
loss = L1loss(inputA, A') + L1loss(inputB, B')
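For reference, the cycle loss above can be sketched in plain NumPy rather than mxnet. This sketch assumes each generator is simply a callable that maps an array from one domain to the other. The helpers `l1_loss` and `cycle_loss` and the toy linear "generators" are hypothetical illustrations. They are not part of mxnet or of any implementation linked in this thread.

```python
import numpy as np

def l1_loss(x, y):
    # Mean absolute error: the L1 reconstruction term of the cycle loss
    return np.mean(np.abs(x - y))

def cycle_loss(input_a, input_b, gen_a, gen_b):
    # Naming follows the post: gen_b maps A -> B, gen_a maps B -> A
    fake_b = gen_b(input_a)
    rec_a = gen_a(fake_b)   # A'
    fake_a = gen_a(input_b)
    rec_b = gen_b(fake_a)   # B'
    return l1_loss(input_a, rec_a) + l1_loss(input_b, rec_b)

if __name__ == "__main__":
    a = np.array([0.0, 1.0, 2.0])
    b = np.array([3.0, 4.0, 5.0])
    # Hypothetical toy generators that exactly invert each other
    print(cycle_loss(a, b, lambda y: (y - 1) / 2, lambda x: 2 * x + 1))  # 0.0
    # An imperfect inverse leaves a nonzero reconstruction error
    print(cycle_loss(a, b, lambda y: y / 2, lambda x: 2 * x + 1))  # 1.5
```

Each term compares an input with its reconstruction after a round trip through both generators, so the loss is zero only when the two generators invert each other on the given samples.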
I tried to use the MakeLoss function, but I am not sure how to build the cycle into a symbol. Is there an easy way to implement this kind of cyclic network in mxnet?
Really appreciate any suggestions and advice.
I made 4 symbols and trained them one by one.
@WillSuen you can refer to my implementation:
https://github.com/Ldpe2G/DeepLearningForFun/tree/master/Mxnet-Scala/CycleGAN
My implementation takes the same approach as @Godricly's: I made 4 symbols (Generator A, Generator B, Discriminator A, Discriminator B) and trained them one by one. You can also refer to the Torch implementation:
https://github.com/junyanz/CycleGAN. My implementation was based on it.
Thanks a lot! I'll have a try.