Maskrcnn-benchmark: instructions for transfer learning

Created on 27 Nov 2018 · 8 Comments · Source: facebookresearch/maskrcnn-benchmark

โ“ Questions and Help

Hello!

It would be great to have a set of instructions for getting going with transfer learning:
-- how to set the base model for transfer learning
-- how to specify how many layers to freeze
-- details on how to onboard your own classes.

Thank you!

Labels: awaiting response, needs discussion

All 8 comments

Hi,

Thanks for opening the issue.

While I believe this could be a good thing to have, best results are actually dataset-dependent and require different modifications; see for example the comment in https://github.com/facebookresearch/maskrcnn-benchmark/issues/15#issuecomment-433168443

What do you think?

Thank you for your response Francisco!
Agreed, generally speaking different datasets will require different modifications.
It would still be nice to outline in the README the "basic" option (retraining only the classification layers) and the "more advanced" options that require open-heart surgery on the model.

I agree, a basic section in the README would be a great contribution!

@fmassa I think it would be better if you could provide a simple script to trim the pre-trained model weights (for example, trimming the class-specific layers). We could then modify this script for more advanced needs.

@wangg12 Such a script would be dataset-dependent. There are a few scripts already available in the Detectron github repo, one could start from there.
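
For reference, a rough sketch of what such a trimming step might look like. The checkpoint file names are placeholders, and the class-specific key patterns ("cls_score", "bbox_pred", "mask_fcn_logits") match maskrcnn-benchmark's predictor layers but should be checked against your actual checkpoint:

```python
# Sketch: drop class-specific predictor weights from a checkpoint so the
# rest can be reused with a different number of classes. File names and
# key patterns are assumptions; verify them against your checkpoint.
import torch

checkpoint = torch.load("e2e_mask_rcnn_R_50_FPN_1x.pth", map_location="cpu")
# maskrcnn-benchmark checkpoints store weights under "model"; converted
# Detectron weights may be a plain state dict.
state_dict = checkpoint.get("model", checkpoint)

class_specific = ("cls_score", "bbox_pred", "mask_fcn_logits")
trimmed = {k: v for k, v in state_dict.items()
           if not any(name in k for name in class_specific)}

torch.save({"model": trimmed}, "trimmed_base_model.pth")
```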

For transfer learning, your dataset might have a different number of classes, and all the changes revolve around this:

Step 1: You can change _C.MODEL.ROI_BOX_HEAD.NUM_CLASSES by calling the cfg.merge_from_list method, according to how many classes your dataset has.
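
For example, a minimal sketch assuming a dataset with 10 object classes (+1 for background) and one of the repository's example config files; adjust both to your setup:

```python
# Override the number of classes for a custom dataset.
# 11 = 10 object classes + 1 background (count is an assumption).
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/e2e_mask_rcnn_R_50_FPN_1x.yaml")
cfg.merge_from_list(["MODEL.ROI_BOX_HEAD.NUM_CLASSES", 11])
```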

Step 2: Create a customized dataset class that generates a tuple (img, bbox, label) each time. img is generated by Image.open(...).convert('RGB'); bbox is a 2d numpy array and label is a 1d array. Then you can create another class on top of that one. For this inheritance step, this code is helpful: https://github.com/facebookresearch/maskrcnn-benchmark/blob/f25c6cff92d32d92abe8965d68401004e90c8bee/maskrcnn_benchmark/data/datasets/coco.py#L35
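
A minimal sketch of such a dataset class; the annotation format and attribute names here are assumptions for illustration:

```python
import numpy as np
from PIL import Image
import torch.utils.data

class MyRawDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, annotations):
        self.image_paths = image_paths   # list of image file paths (assumed)
        self.annotations = annotations   # per-image dicts with 'boxes' and 'labels' (assumed)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        ann = self.annotations[idx]
        bbox = np.asarray(ann['boxes'], dtype=np.float32)   # shape (N, 4)
        label = np.asarray(ann['labels'], dtype=np.int64)   # shape (N,)
        return img, bbox, label
```

A second class inheriting from this one would then wrap bbox and label into the BoxList target that the training loop expects, similar to what the linked coco.py does with its COCO annotations.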

Step 3: When loading the pretrained model, you can exclude the roi head layers as follows (the gist is that the weights for each layer are stored in a dict).
pretrained_model = torch.load(checkpoint_path)['state']
model_dict = self.model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.items() if k in model_dict and "roi_heads" not in k}
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)  # load the merged weights back into the model

BTW, when encoding the labels of your dataset, the labels should be indexed starting from zero. Otherwise, when computing the cross-entropy loss on a CPU, PyTorch complains with "Assertion `cur_target >= 0 && cur_target < n_classes". When computing on a GPU, the error message is not informative. I got stuck on this for an hour.
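
For example, one way to remap arbitrary category ids to contiguous zero-based labels (the ids below are made up):

```python
# Sketch: map dataset category ids to contiguous zero-based labels.
category_ids = [3, 17, 42]   # ids as they appear in your annotations (made up)
label_map = {cat_id: i for i, cat_id in enumerate(sorted(category_ids))}
# label_map == {3: 0, 17: 1, 42: 2}
labels = [label_map[c] for c in [17, 3, 42]]   # -> [1, 0, 2]
```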

Lastly, I must say this package has a very good modular design! Thanks for your excellent work and kind sharing!

@wmmxk thanks for the step-by-step! Do you think you could send a PR improving the finetuning section in the README?

@fmassa I have sent a PR.
