Mmdetection: How to finetune from pretrained models trained on coco data with different number of classes?

Created on 1 Jun 2019 · 11 Comments · Source: open-mmlab/mmdetection

Is there a config option to load pretrained COCO models for fine-tuning? The last layers may have a different number of classes, so those weights should not be loaded.
If I just use faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth, then I get:

While copying the parameter named bbox_head.fc_cls.weight, whose dimensions in the model are torch.Size([21, 1024]) and whose dimensions in the checkpoint are torch.Size([81, 1024]).

Or, how should I modify the pretrained weights for my model?

Thank you!

All 11 comments

Thank you so much @spytensor, I've solved the problem with this:

import torch

# Load the COCO-pretrained checkpoint.
pretrained_weights = torch.load('faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')

# 20 VOC classes + 1 background class.
num_class = 21

# Truncate the classification/regression head tensors in place so their
# shapes match the new number of classes.
pretrained_weights['state_dict']['bbox_head.fc_cls.weight'].resize_(num_class, 1024)
pretrained_weights['state_dict']['bbox_head.fc_cls.bias'].resize_(num_class)
pretrained_weights['state_dict']['bbox_head.fc_reg.weight'].resize_(num_class * 4, 1024)
pretrained_weights['state_dict']['bbox_head.fc_reg.bias'].resize_(num_class * 4)

torch.save(pretrained_weights, 'faster_rcnn_r50_fpn_1x_%d.pth' % num_class)
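Note that the in-place `resize_()` simply truncates the 81-class COCO tensors, so the surviving rows correspond to arbitrary COCO classes rather than meaningful initial values for the new dataset. A sketch of an alternative (an assumption on my part, not a method from this thread): delete the mismatched head keys entirely, so the new head is freshly initialized when the checkpoint is loaded for fine-tuning. The dummy checkpoint below stands in for the real `torch.load(...)` result.

```python
import torch

# Stand-in for the real checkpoint; in practice you would use
# torch.load('faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth').
ckpt = {'state_dict': {
    'bbox_head.fc_cls.weight': torch.zeros(81, 1024),
    'bbox_head.fc_cls.bias': torch.zeros(81),
    'bbox_head.fc_reg.weight': torch.zeros(81 * 4, 1024),
    'bbox_head.fc_reg.bias': torch.zeros(81 * 4),
    'backbone.conv1.weight': torch.zeros(64, 3, 7, 7),
}}

# Delete the class-dependent head weights instead of resizing them;
# the loader then re-initializes those layers and only warns about
# the missing keys.
for key in ['bbox_head.fc_cls.weight', 'bbox_head.fc_cls.bias',
            'bbox_head.fc_reg.weight', 'bbox_head.fc_reg.bias']:
    ckpt['state_dict'].pop(key, None)

# torch.save(ckpt, 'faster_rcnn_r50_fpn_1x_no_head.pth')
```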

@cowry5, @hellock Hi, I also want to fine-tune the pretrained COCO models on Pascal VOC 2007. I followed your method to resize the bbox_head of the pretrained weights, and I converted the Pascal VOC dataset into a CustomDataset as described in 'Use my own datasets'. But when I train the model, the training process exits without any traceback. Could you give me some advice on fine-tuning the pretrained model? Thanks!

My config file is faster_rcnn_x101_64x4d_fpn_1x.py, with the image and annotation paths changed to the converted Pascal VOC data. I am using Ubuntu 16.04.5 with a single TITAN Xp GPU.

The problem is shown as follows:

(py36-torch) Jayson:~/workspace/mmdetection$ CUDA_VISIBLE_DEVICES=6 python tools/train.py configs/faster_rcnn_x101_64x4d_fpn_1x.py --resume_from checkpoints/faster_rcnn_x101_64x4d_fpn_1x_21.pth
2019-06-05 21:35:46,285 - INFO - Distributed training: False
2019-06-05 21:35:48,199 - INFO - load model from: open-mmlab://resnext101_64x4d
2019-06-05 21:35:48,538 - WARNING - missing keys in source state_dict: layer3.4.bn2.num_batches_tracked, layer3.15.bn2.num_batches_tracked, layer3.20.bn2.num_batches_tracked, layer3.2.bn3.num_batches_tracked, layer3.3.bn2.num_batches_tracked, layer3.20.bn3.num_batches_tracked, layer4.2.bn1.num_batches_tracked, layer3.5.bn3.num_batches_tracked, layer3.9.bn2.num_batches_tracked, layer3.7.bn3.num_batches_tracked, layer3.17.bn3.num_batches_tracked, layer3.19.bn3.num_batches_tracked, layer3.4.bn3.num_batches_tracked, layer3.16.bn2.num_batches_tracked, layer4.0.bn1.num_batches_tracked, layer2.3.bn2.num_batches_tracked, layer3.9.bn3.num_batches_tracked, layer2.1.bn3.num_batches_tracked, layer1.2.bn3.num_batches_tracked, layer2.3.bn3.num_batches_tracked, layer3.5.bn1.num_batches_tracked, layer3.22.bn2.num_batches_tracked, layer3.12.bn3.num_batches_tracked, layer3.7.bn1.num_batches_tracked, layer3.0.bn1.num_batches_tracked, layer3.11.bn2.num_batches_tracked, layer3.22.bn3.num_batches_tracked, layer1.2.bn1.num_batches_tracked, layer3.20.bn1.num_batches_tracked, layer3.19.bn1.num_batches_tracked, layer3.10.bn2.num_batches_tracked, layer2.3.bn1.num_batches_tracked, layer3.2.bn1.num_batches_tracked, layer3.11.bn3.num_batches_tracked, layer3.18.bn1.num_batches_tracked, layer3.3.bn1.num_batches_tracked, layer3.10.bn1.num_batches_tracked, layer1.0.bn2.num_batches_tracked, layer3.15.bn1.num_batches_tracked, layer2.1.bn1.num_batches_tracked, layer3.3.bn3.num_batches_tracked, layer3.17.bn2.num_batches_tracked, layer2.2.bn2.num_batches_tracked, layer2.2.bn3.num_batches_tracked, layer4.0.downsample.1.num_batches_tracked, layer3.6.bn1.num_batches_tracked, layer2.0.bn2.num_batches_tracked, layer3.12.bn1.num_batches_tracked, layer3.11.bn1.num_batches_tracked, layer4.2.bn3.num_batches_tracked, layer3.16.bn1.num_batches_tracked, layer3.13.bn1.num_batches_tracked, layer3.6.bn2.num_batches_tracked, layer3.1.bn2.num_batches_tracked, layer3.5.bn2.num_batches_tracked, 
layer3.10.bn3.num_batches_tracked, layer3.9.bn1.num_batches_tracked, layer4.1.bn3.num_batches_tracked, layer2.2.bn1.num_batches_tracked, layer3.7.bn2.num_batches_tracked, layer4.0.bn2.num_batches_tracked, layer3.2.bn2.num_batches_tracked, layer3.21.bn3.num_batches_tracked, layer3.4.bn1.num_batches_tracked, layer3.22.bn1.num_batches_tracked, layer4.2.bn2.num_batches_tracked, layer3.21.bn2.num_batches_tracked, layer3.0.bn2.num_batches_tracked, layer1.1.bn3.num_batches_tracked, layer2.1.bn2.num_batches_tracked, layer1.0.bn1.num_batches_tracked, layer2.0.bn3.num_batches_tracked, layer3.12.bn2.num_batches_tracked, layer1.0.bn3.num_batches_tracked, layer3.17.bn1.num_batches_tracked, layer3.8.bn2.num_batches_tracked, layer3.6.bn3.num_batches_tracked, layer3.8.bn1.num_batches_tracked, layer1.2.bn2.num_batches_tracked, layer3.16.bn3.num_batches_tracked, layer3.18.bn2.num_batches_tracked, layer3.14.bn1.num_batches_tracked, layer3.8.bn3.num_batches_tracked, layer4.1.bn1.num_batches_tracked, layer2.0.downsample.1.num_batches_tracked, layer3.1.bn1.num_batches_tracked, layer1.1.bn1.num_batches_tracked, layer3.18.bn3.num_batches_tracked, layer3.0.bn3.num_batches_tracked, layer3.21.bn1.num_batches_tracked, layer3.0.downsample.1.num_batches_tracked, layer2.0.bn1.num_batches_tracked, layer4.0.bn3.num_batches_tracked, layer1.0.downsample.1.num_batches_tracked, layer3.13.bn2.num_batches_tracked, layer3.14.bn3.num_batches_tracked, layer3.13.bn3.num_batches_tracked, layer1.1.bn2.num_batches_tracked, layer3.14.bn2.num_batches_tracked, bn1.num_batches_tracked, layer3.19.bn2.num_batches_tracked, layer3.1.bn3.num_batches_tracked, layer3.15.bn3.num_batches_tracked, layer4.1.bn2.num_batches_tracked

2019-06-05 21:35:55,316 - INFO - load checkpoint from checkpoints/faster_rcnn_x101_64x4d_fpn_1x_21.pth
2019-06-05 21:35:55,584 - INFO - resumed epoch 12, iter 87960
2019-06-05 21:35:55,585 - INFO - Start running, host: Jayson, work_dir: /home/Jayson/workspace/mmdetection/work_dirs/faster_rcnn_x101_64x4d_fpn_1x
2019-06-05 21:35:55,585 - INFO - workflow: [('train', 1)], max: 12 epochs
(py36-torch) Jayson:~/workspace/mmdetection$

@cowry5, @hellock I have tried training the model directly without the pretrained checkpoint, and it works well. But when I start training from the pretrained checkpoint, the above issue occurs. Can you help me solve this problem?

@Alawaka Did you modify total_epochs in configs/faster_rcnn_x101_64x4d_fpn_1x.py?
If not, --resume_from resumes the detector at epoch 12, which already equals total_epochs, so training ends immediately.

load_from and resume_from have different behaviors. If you just want to load a checkpoint for fine-tuning, use load_from.
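In mmdetection these are two top-level fields of the Python config file. A minimal sketch (the checkpoint paths are placeholders):

```python
# In the config file (e.g. configs/faster_rcnn_x101_64x4d_fpn_1x.py):

# load_from only initializes the model weights for fine-tuning;
# training starts from epoch 0 with a fresh optimizer.
load_from = 'checkpoints/faster_rcnn_x101_64x4d_fpn_1x_21.pth'
resume_from = None

# resume_from restores weights, optimizer state, and the epoch counter
# to continue an interrupted run; if the saved epoch already equals
# total_epochs, training exits immediately (the behavior seen above).
# resume_from = 'work_dirs/faster_rcnn_x101_64x4d_fpn_1x/latest.pth'
```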

@hellock It seems that this project lacks documentation describing these parameters.

I will add the instructions to GETTING_STARTED.


Can you please share your experience with fine-tuning like this? What lr schedule did you use? How many epochs?

@cowry5 Hey, this resizing may work in the case of Faster R-CNN, but will it also work for other networks? What I'm asking is: is there a generic solution to this?

Thanks
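One generic approach, sketched here as an assumption rather than an official mmdetection utility: compare the checkpoint against the target model's state_dict and keep only the entries whose shapes match. This works for any architecture, since it does not hard-code the head's key names.

```python
import torch

def filter_mismatched(checkpoint_sd, model_sd):
    """Keep only checkpoint entries whose key exists in the model with
    the same shape; everything else is left to re-initialization."""
    kept, dropped = {}, []
    for key, tensor in checkpoint_sd.items():
        if key in model_sd and model_sd[key].shape == tensor.shape:
            kept[key] = tensor
        else:
            dropped.append(key)
    return kept, dropped

# Dummy example: an 81-class COCO head vs. a 21-class VOC head.
checkpoint_sd = {'bbox_head.fc_cls.weight': torch.zeros(81, 1024),
                 'backbone.conv1.weight': torch.zeros(64, 3, 7, 7)}
model_sd = {'bbox_head.fc_cls.weight': torch.zeros(21, 1024),
            'backbone.conv1.weight': torch.zeros(64, 3, 7, 7)}

kept, dropped = filter_mismatched(checkpoint_sd, model_sd)
```

The surviving `kept` dict can then be loaded with `model.load_state_dict(kept, strict=False)`, leaving the dropped layers randomly initialized.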

There needs to be a clear guide on fine-tuning the models in your library. I also believe that the config file should be enough to take care of this class-mismatch problem while fine-tuning.

I kept searching the documentation, then checked the issues list to see if someone had reported it.

It would be kind of you to create a README for other users as well.

I will add the instructions to GETTING_STARTED.

Any update on fine-tuning models? For example, pretrained COCO models on a COCO-format dataset. Can just resizing the mismatched keys when loading the model work? @hellock
