I want to load a checkpoint to finetune a model on another dataset, so I have to change the fc layer to fit the new num_classes. load_from_checkpoint now throws a size-mismatch error.
Could you add an option to ignore size-mismatched layers when loading from a checkpoint?
Could you try setting strict=False in load_from_checkpoint?
https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=load_from_checkpoint
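Something like this (a minimal sketch; LitModel and the checkpoint path are placeholders for your own module and file):

```python
# strict=False tells Lightning to ignore missing/unexpected keys when restoring the weights
model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt", strict=False)
```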
strict=False only ignores params with a different name, not params with the same name but a different size.
For example, the fc layer in the checkpoint is 10 x 5 and my new model changes its size to 10 x 3; load_from_checkpoint with strict=False still throws an error saying the sizes mismatch.
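A minimal plain-PyTorch reproduction (the layer sizes just mirror the 10 x 5 vs. 10 x 3 example above):

```python
import torch.nn as nn

old_fc = nn.Linear(5, 10)  # fc in the checkpoint: weight shape is [10, 5]
new_fc = nn.Linear(3, 10)  # fc in the new model: weight shape is [10, 3]

# strict=False only tolerates missing/unexpected keys; "weight" exists in both
# modules but with different shapes, so this still raises a size-mismatch RuntimeError
new_fc.load_state_dict(old_fc.state_dict(), strict=False)
```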
Then is it important that the last layer is called fc, or could the new layer be called something else, e.g. fc2?
Or could you replace the layer after loading the model from the checkpoint:
model = Model.load_from_checkpoint(...)  # restore the old architecture (load_from_checkpoint is a classmethod that returns the model)
model.fc = nn.Linear(10, 3)              # then swap in the new fc layer
I init my model in the init method, and the size of fc changes according to the dataset. But load_from_checkpoint is called from main.py, so for now I have to implement my own checkpoint-loading function to load the state dict. Could pytorch-lightning support this in load_from_checkpoint by adding an option such as skip_mismatch=True?
You can update your checkpoint in on_load_checkpoint.
Also, there seems to be a similar proposal in the pytorch repo:
https://github.com/pytorch/pytorch/pull/39144
https://github.com/pytorch/pytorch/issues/40859
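In the meantime, here is a minimal sketch of one way to skip shape-mismatched parameters with plain PyTorch before calling load_state_dict (filter_mismatched, model, and ckpt_path are placeholder names, not an existing API):

```python
import torch

def filter_mismatched(model, state_dict):
    """Keep only checkpoint entries whose shape matches the current model."""
    model_state = model.state_dict()
    return {
        k: v for k, v in state_dict.items()
        if k in model_state and v.shape == model_state[k].shape
    }

# ckpt_path and model are placeholders for your checkpoint file and LightningModule
state_dict = torch.load(ckpt_path)["state_dict"]
model.load_state_dict(filter_mismatched(model, state_dict), strict=False)
```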
To xiadingZ
I am hitting the same error as you.
Can you share the code of your checkpoint-loading function?
This is my code:
def load_finetune_checkpoint(self, path):
    # restore only the weights from the checkpoint file
    m = torch.load(path)['state_dict']
    model_dict = self.state_dict()
    for k in m.keys():
        # skip the classifier heads whose size depends on the dataset
        if 'fc_vidout' in k or 'fc_total' in k:
            continue
        if k in model_dict:
            pname = k
            pval = m[k]
            model_dict[pname] = pval.clone().to(model_dict[pname].device)
    self.load_state_dict(model_dict)
It only restores the state dict, not the hyperparameters, and I hard-code the parameter names as a quick fix.
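For reference, a usage sketch (the class name, constructor argument, and checkpoint path are placeholders):

```python
model = MyModel(num_classes=3)               # new model with the resized fc heads
model.load_finetune_checkpoint("old.ckpt")   # copies every matching weight except the skipped heads
```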
Thanks xiadingZ
It works for me!
I hope Lightning supports something like _skip_mismatch=True_ as you mentioned above.
Thanks,
Thank you @rohitgr7 !
Overloading on_load_checkpoint() in my pl.Module works like a charm!
def on_load_checkpoint(self, checkpoint: dict) -> None:
    # assumes a module-level logger, e.g. logger = logging.getLogger(__name__)
    state_dict = checkpoint["state_dict"]
    model_state_dict = self.state_dict()
    is_changed = False
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                logger.info(f"Skip loading parameter: {k}, "
                            f"required shape: {model_state_dict[k].shape}, "
                            f"loaded shape: {state_dict[k].shape}")
                # keep the freshly initialized weights for the mismatched layer
                state_dict[k] = model_state_dict[k]
                is_changed = True
        else:
            logger.info(f"Dropping parameter {k}")
            is_changed = True

    if is_changed:
        # the saved optimizer state no longer matches the parameters, so discard it
        checkpoint.pop("optimizer_states", None)
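For context, a usage sketch: load_from_checkpoint calls this hook before loading the state dict, so no extra wiring is needed (the class name and checkpoint path below are placeholders):

```python
# on_load_checkpoint patches the incoming state dict before it is applied,
# so the resized layers keep their freshly initialized weights
model = MyLitModule.load_from_checkpoint("pretrained.ckpt")
```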
great!!