I changed the input and output channel sizes of the original generators and discriminators.
Ga takes a 3-channel image as input and outputs a 1-channel image, and Gb takes a 1-channel image as input and outputs a 3-channel image. Training went really well, but when loading the model I got a size mismatch error:
Missing key(s) in state_dict: "model.10.conv_block.6.weight", "model.10.conv_block.6.bias", "model.11.conv_block.6.weight", "model.11.conv_block.6.bias", "model.12.conv_block.6.weight", "model.12.conv_block.6.bias", "model.13.conv_block.6.weight", "model.13.conv_block.6.bias", "model.14.conv_block.6.weight", "model.14.conv_block.6.bias", "model.15.conv_block.6.weight", "model.15.conv_block.6.bias", "model.16.conv_block.6.weight", "model.16.conv_block.6.bias", "model.17.conv_block.6.weight", "model.17.conv_block.6.bias", "model.18.conv_block.6.weight", "model.18.conv_block.6.bias".
Unexpected key(s) in state_dict: "model.10.conv_block.5.weight", "model.10.conv_block.5.bias", "model.11.conv_block.5.weight", "model.11.conv_block.5.bias", "model.12.conv_block.5.weight", "model.12.conv_block.5.bias", "model.13.conv_block.5.weight", "model.13.conv_block.5.bias", "model.14.conv_block.5.weight", "model.14.conv_block.5.bias", "model.15.conv_block.5.weight", "model.15.conv_block.5.bias", "model.16.conv_block.5.weight", "model.16.conv_block.5.bias", "model.17.conv_block.5.weight", "model.17.conv_block.5.bias", "model.18.conv_block.5.weight", "model.18.conv_block.5.bias".
size mismatch for model.26.weight: copying a param with shape torch.Size([1, 64, 7, 7]) from checkpoint, the shape in current model is torch.Size([3, 64, 7, 7]).
size mismatch for model.26.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
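When `load_state_dict` complains like this, it can help to compare the checkpoint against the freshly built model key by key. Below is a minimal sketch. It assumes you first reduce both state dicts to `{name: shape}` mappings, e.g. `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`. The helper `diff_state_dicts` is hypothetical and is not part of PyTorch or this repo. The toy shapes are modeled on the error above.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(checkpoint: Dict[str, Shape],
                     model: Dict[str, Shape]) -> Dict[str, List]:
    """Compare two {parameter name: shape} mappings, reporting what load_state_dict would."""
    missing = sorted(k for k in model if k not in checkpoint)      # model wants it, file lacks it
    unexpected = sorted(k for k in checkpoint if k not in model)   # file has it, model lacks it
    mismatched = sorted(
        (k, checkpoint[k], model[k])
        for k in checkpoint.keys() & model.keys()
        if checkpoint[k] != model[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Toy values: one residual-block conv and the final conv from the error above.
ckpt = {
    "model.10.conv_block.5.weight": (256, 256, 3, 3),
    "model.26.weight": (1, 64, 7, 7),
}
built = {
    "model.10.conv_block.6.weight": (256, 256, 3, 3),
    "model.26.weight": (3, 64, 7, 7),
}
report = diff_state_dicts(ckpt, built)
print(report["missing"])     # ['model.10.conv_block.6.weight']
print(report["unexpected"])  # ['model.10.conv_block.5.weight']
print(report["mismatched"])  # [('model.26.weight', (1, 64, 7, 7), (3, 64, 7, 7))]
```

Two kinds of problem show up here. Renamed keys point to a structural difference, such as an extra layer. Shape mismatches on the same key point to different channel counts.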
Any idea why I get this error in the testing phase?
Make sure that you use the same network and the same --norm flag for both training and test.
I am using --norm instance for both the training and testing phases. It should be the same network: after training, I rename epoch_net_G_A.pth to latest_net_G.pth and run test.py with --preprocess none. As the error above shows, the checkpoint has a torch.Size([1, 64, 7, 7]) parameter. That is expected, because I trained it that way. However, the model expects torch.Size([3, 64, 7, 7]), so my guess is that the model built in the testing phase has 3 output channels.
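A `Conv2d` weight is stored as `(out_channels, in_channels, kernel_h, kernel_w)`, so the checkpoint itself records how many channels the trained generator consumes and produces. Below is a minimal sketch. It assumes the 9-block `ResnetGenerator` layout implied by the key names in the error, with the first 7x7 conv at `model.1` and the last 7x7 conv at `model.26`. The helper `infer_nc` and the `model.1.weight` shape are illustrative assumptions.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def infer_nc(shapes: Dict[str, Shape],
             first_conv: str = "model.1.weight",
             last_conv: str = "model.26.weight") -> Tuple[int, int]:
    """Return (input_nc, output_nc) read from conv weight shapes.

    Conv2d weights are laid out as (out_channels, in_channels, kH, kW):
    dim 1 of the first conv is the number of image channels going in,
    and dim 0 of the last conv is the number of image channels coming out.
    """
    input_nc = shapes[first_conv][1]
    output_nc = shapes[last_conv][0]
    return input_nc, output_nc


# Assumed shapes for a 3 -> 1 channel generator with ngf=64, as trained in this issue.
checkpoint_shapes = {
    "model.1.weight": (64, 3, 7, 7),
    "model.26.weight": (1, 64, 7, 7),
}
print(infer_nc(checkpoint_shapes))  # (3, 1)
```

With a real file, build `shapes` from the loaded state dict. The two numbers returned are the values the test-time model must be built with.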
Here are the changes I made:
# GENERATOR A
generator_a_channels_input = 3
generator_a_channels_output = 1
# GENERATOR B
generator_b_channels_input = 1
generator_b_channels_output = 3
self.netG_A = networks.define_G(generator_a_channels_input, generator_a_channels_output, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
self.netG_B = networks.define_G(generator_b_channels_input, generator_b_channels_output, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
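One way to keep training and testing consistent is to read the channel counts from the options instead of hardcoding them. The options are what you pass again at test time, so both phases then share one source. Below is a minimal sketch. It assumes an `opt` object with `input_nc`/`output_nc` attributes, as in the parser lines in this thread. The helper `generator_channels` is hypothetical.

```python
from argparse import Namespace
from typing import Dict, Tuple


def generator_channels(opt: Namespace) -> Dict[str, Tuple[int, int]]:
    """Derive (in, out) channel pairs for both CycleGAN generators from the options.

    G_A maps domain A -> B and G_B maps B -> A, so G_B simply swaps the pair.
    """
    return {
        "G_A": (opt.input_nc, opt.output_nc),
        "G_B": (opt.output_nc, opt.input_nc),
    }


opt = Namespace(input_nc=3, output_nc=1)
channels = generator_channels(opt)
print(channels["G_A"])  # (3, 1)
print(channels["G_B"])  # (1, 3)
```

Inside the model, `*channels["G_A"]` could then replace the first two hardcoded arguments of `networks.define_G(...)`.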
Does this code affect how the model is rebuilt at test time?
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
*Keep in mind that I commented out the assert that checks in_channels == out_channels.
You need to specify the --input_nc and --output_nc during test time.
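Every option that changes the generator's layers has to be repeated at test time, not only the channel counts. Below is a minimal sketch of a pre-flight check. It uses a reduced stand-in parser, not the repo's own. The `input_nc`/`output_nc` defaults come from this thread. The `ngf` and `netG` defaults, the `train_opt` values, and the helper `architecture_mismatches` are illustrative assumptions.

```python
import argparse

# Options that change the generator's layers, and therefore its state_dict keys and shapes.
ARCH_KEYS = ("input_nc", "output_nc", "ngf", "netG", "norm", "no_dropout")


def architecture_mismatches(train_opt: dict, test_opt: argparse.Namespace) -> dict:
    """Return {option: (train_value, test_value)} for every architecture option that differs."""
    test_vals = vars(test_opt)
    return {k: (train_opt[k], test_vals[k]) for k in ARCH_KEYS if train_opt[k] != test_vals[k]}


# Reduced stand-in parser (assumed defaults for ngf and netG).
parser = argparse.ArgumentParser()
parser.add_argument("--input_nc", type=int, default=3)
parser.add_argument("--output_nc", type=int, default=1)
parser.add_argument("--ngf", type=int, default=64)
parser.add_argument("--netG", type=str, default="resnet_9blocks")
parser.add_argument("--norm", type=str, default="instance")
parser.add_argument("--no_dropout", action="store_true")

# Hypothetical record of how the checkpoint was trained.
train_opt = {"input_nc": 3, "output_nc": 1, "ngf": 64,
             "netG": "resnet_9blocks", "norm": "instance", "no_dropout": True}

print(architecture_mismatches(train_opt, parser.parse_args([])))
# {'no_dropout': (True, False)}
print(architecture_mismatches(train_opt, parser.parse_args(["--no_dropout"])))
# {}
```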
Setting --input_nc and --output_nc removes the size mismatch error:
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
But the missing keys error persists:
RuntimeError: Error(s) in loading state_dict for ResnetGenerator:
Missing key(s) in state_dict: "model.10.conv_block.6.weight", "model.10.conv_block.6.bias", "model.11.conv_block.6.weight", "model.11.conv_block.6.bias", "model.12.conv_block.6.weight", "model.12.conv_block.6.bias", "model.13.conv_block.6.weight", "model.13.conv_block.6.bias", "model.14.conv_block.6.weight", "model.14.conv_block.6.bias", "model.15.conv_block.6.weight", "model.15.conv_block.6.bias", "model.16.conv_block.6.weight", "model.16.conv_block.6.bias", "model.17.conv_block.6.weight", "model.17.conv_block.6.bias", "model.18.conv_block.6.weight", "model.18.conv_block.6.bias".
Unexpected key(s) in state_dict: "model.10.conv_block.5.weight", "model.10.conv_block.5.bias", "model.11.conv_block.5.weight", "model.11.conv_block.5.bias", "model.12.conv_block.5.weight", "model.12.conv_block.5.bias", "model.13.conv_block.5.weight", "model.13.conv_block.5.bias", "model.14.conv_block.5.weight", "model.14.conv_block.5.bias", "model.15.conv_block.5.weight", "model.15.conv_block.5.bias", "model.16.conv_block.5.weight", "model.16.conv_block.5.bias", "model.17.conv_block.5.weight", "model.17.conv_block.5.bias", "model.18.conv_block.5.weight", "model.18.conv_block.5.bias".
The problem was that --no_dropout was missing from the command. I used the test.py file; when I ran it from the terminal instead of PyCharm, it worked. Thank you!
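Why does one flag rename keys instead of changing shapes? `nn.Sequential` names its parameters by position. As I understand this repo's residual block, enabling dropout inserts an `nn.Dropout` between the two convolutions, which moves the second conv from index 5 to index 6. The checkpoint's `conv_block.5` keys therefore point to a generator trained without dropout. The test-time model, built without `--no_dropout`, expected `conv_block.6`. Below is a pure-Python sketch of that indexing. It assumes reflection padding and instance norm with `affine=False`, which has no parameters; that is why only conv weights appear in the error. The helpers `resnet_block_layers` and `conv_param_keys` are hypothetical mirrors of the layout, not the repo's code.

```python
from typing import List


def resnet_block_layers(use_dropout: bool) -> List[str]:
    """Layer types of one residual block, in nn.Sequential order (reflect padding assumed)."""
    layers = ["ReflectionPad2d", "Conv2d", "InstanceNorm2d", "ReLU"]
    if use_dropout:
        layers.append("Dropout")  # sits between the two convolutions
    layers += ["ReflectionPad2d", "Conv2d", "InstanceNorm2d"]
    return layers


def conv_param_keys(block_idx: int, use_dropout: bool) -> List[str]:
    """state_dict keys of the block's convs; padding, ReLU, dropout and affine=False norm hold none."""
    keys = []
    for i, name in enumerate(resnet_block_layers(use_dropout)):
        if name == "Conv2d":
            keys += [f"model.{block_idx}.conv_block.{i}.weight",
                     f"model.{block_idx}.conv_block.{i}.bias"]
    return keys


print(conv_param_keys(10, use_dropout=False))
# ['model.10.conv_block.1.weight', 'model.10.conv_block.1.bias',
#  'model.10.conv_block.5.weight', 'model.10.conv_block.5.bias']
print(conv_param_keys(10, use_dropout=True))
# ['model.10.conv_block.1.weight', 'model.10.conv_block.1.bias',
#  'model.10.conv_block.6.weight', 'model.10.conv_block.6.bias']
```

This is why a dropout mismatch between training and testing surfaces as missing/unexpected keys rather than a shape error. The tensors are identical; only their positions in the `Sequential` differ.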