Pytorch-lightning: Allow load_from_checkpoint with ignore size mismatch

Created on 16 Nov 2020 · 11 comments · Source: PyTorchLightning/pytorch-lightning

I want to load a checkpoint to fine-tune a model on another dataset, so I have to change the fc layer to fit the new num_classes. Currently, load_from_checkpoint throws a size-mismatch error.

Could you add an option to ignore size-mismatched layers when loading from a checkpoint?

enhancement help wanted

All 11 comments

strict=False only ignores params with a different name, not params with the same name but a different size. For example, the fc layer in the checkpoint is 10 x 5 and my new model changes it to 10 x 3; load_from_checkpoint with strict=False still throws a size-mismatch error.
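For context, a minimal sketch of the usual manual workaround (MyModel and the checkpoint path are placeholders): drop the keys whose shapes no longer match before calling load_state_dict with strict=False.

    import torch

    model = MyModel(num_classes=3)  # hypothetical model with the resized fc layer

    checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
    state_dict = checkpoint["state_dict"]
    model_state_dict = model.state_dict()

    # Keep only parameters whose name and shape both match the new model.
    filtered = {k: v for k, v in state_dict.items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # strict=False tolerates the keys that were dropped (e.g. the resized fc layer).
    model.load_state_dict(filtered, strict=False)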

Then is it important that the last layer is called fc, or could the new layer be called something else, e.g. fc2?

Or could you set the layer after loading from the checkpoint?

    # load_from_checkpoint is a classmethod that returns an instance of the old architecture
    model = Model.load_from_checkpoint(...)
    model.fc = nn.Linear(10, 3)  # then swap in the new fc layer

I init my model in its __init__ method, and the size of fc changes according to the dataset. But load_from_checkpoint is called from main.py, so right now I have to implement my own checkpoint-loading function to load the state dict. Could pytorch-lightning support this in load_from_checkpoint by adding an option such as skip_mismatch=True?
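For illustration, the requested API could look like the sketch below; skip_mismatch is only the proposed name here, not an existing load_from_checkpoint argument.

    # Hypothetical: skip_mismatch is the proposed option, it does not exist today.
    model = MyLitModel.load_from_checkpoint(
        "path/to/checkpoint.ckpt",
        num_classes=3,       # forwarded to __init__, so fc is built with the new size
        skip_mismatch=True,  # proposed: skip parameters whose shapes differ
    )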

You can update your checkpoint in on_load_checkpoint.

Also, there seems to be a similar proposal in the pytorch repo:
https://github.com/pytorch/pytorch/pull/39144
https://github.com/pytorch/pytorch/issues/40859

@xiadingZ

I'm hitting the same error as you.
Can you share your own checkpoint-loading function?

This is my code:

    def load_finetune_checkpoint(self, path):
        # Load only the raw weights from the Lightning checkpoint.
        m = torch.load(path)['state_dict']
        model_dict = self.state_dict()
        for k in m.keys():
            # Skip the layers whose shape changed for the new dataset.
            if 'fc_vidout' in k or 'fc_total' in k:
                continue

            if k in model_dict:
                model_dict[k] = m[k].clone().to(model_dict[k].device)

        self.load_state_dict(model_dict)

It only restores the state dict, not the hyperparameters, and the parameter names to skip are hard-coded; it is just a quick fix.
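If you also need the saved hyperparameters, they live in the same checkpoint file; a minimal sketch, assuming your Lightning version stores them under the "hyper_parameters" key (the key name has changed across versions):

    checkpoint = torch.load(path, map_location="cpu")
    # Assumption: recent Lightning versions use "hyper_parameters"; older ones may differ.
    hparams = checkpoint.get("hyper_parameters", {})
    print(hparams)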

Thanks xiadingZ

It works for me!
I hope Lightning supports something like _skip_mismatch=True_, as you mentioned above.

Thanks,

Thank you @rohitgr7!
Overriding on_load_checkpoint() in my pl.LightningModule works like a charm!

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # `logger` is assumed to be a module-level logging.getLogger(__name__).
        state_dict = checkpoint["state_dict"]
        model_state_dict = self.state_dict()
        is_changed = False
        for k in state_dict:
            if k in model_state_dict:
                if state_dict[k].shape != model_state_dict[k].shape:
                    logger.info(f"Skip loading parameter: {k}, "
                                f"required shape: {model_state_dict[k].shape}, "
                                f"loaded shape: {state_dict[k].shape}")
                    # Keep the freshly initialized weights for mismatched layers.
                    state_dict[k] = model_state_dict[k]
                    is_changed = True
            else:
                logger.info(f"Dropping parameter {k}")
                is_changed = True

        if is_changed:
            # The saved optimizer state no longer matches the parameters, so drop it.
            checkpoint.pop("optimizer_states", None)

great!!

