If I have a checkpoint that uses a different key to store the model weights, how do I load it, or convert it into the Lightning format?
E.g.: checkpoint["my_funky_model_state_dict_key"]
I've searched the docs, the Python API, and the GitHub issues.
Do you mind sharing an example?
Otherwise, I would say: load it as you do now, edit the keys to match PL, and save it again...
Try this. See if it works.
import torch

new_model = new_lightning_model()
new_weights = new_model.state_dict()

# load the old checkpoint and grab its (key, tensor) pairs in order
old_weights = list(torch.load(old_checkpoint)['state_dict'].items())

# copy the old tensors into the new state_dict by position
# (this assumes both models define their parameters in the same order)
for i, k in enumerate(new_weights.keys()):
    new_weights[k] = old_weights[i][1]

new_model.load_state_dict(new_weights)
# then save it to get the checkpoint in lightning format.
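For that final save step, a minimal sketch is to store the weights under the 'state_dict' key that Lightning checkpoints use (note this only covers the weights; a real Lightning checkpoint also carries trainer metadata such as the epoch and optimizer states):
# minimal sketch: only the weights, under the key Lightning checkpoints use
torch.save({'state_dict': new_model.state_dict()}, 'converted.ckpt')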
This is a frequent problem when using a pl_module to wrap around an existing module.
E.g.:
When loading the pretrained weights, the state_dict keys start with "bert.", but when loading our own PL-trained checkpoint, the keys start with "my_model.bert.".
And when we try to fine-tune a downstream task, we might need to load either one, so we have to write extra code for the different weight layouts.
Could we turn the quoted code into a state_dict cleaning function that we can use like the following: if "bert." is found but "my_model.bert." is not, replace "bert." with "my_model.bert.".
cleaning = clean_state_keys(["my_model.bert.", "bert."], ["replacement_key", "to_be_replaced"], ...)
model.load_state_dict(cleaning(state_dict), strict=False)
The above code should then work for weights saved under either "bert." or "my_model.bert.", since in one of the two situations it simply replaces nothing.
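For concreteness, here is a rough sketch of what such a helper could look like (clean_state_keys is hypothetical, not an existing Lightning API; each argument is a (replacement_key, to_be_replaced) pair):
def clean_state_keys(*pairs):
    """Return a function that rewrites state_dict keys.

    Each pair is (replacement_key, to_be_replaced): a key is rewritten only
    when it contains to_be_replaced but not replacement_key already.
    """
    def _clean(state_dict):
        cleaned = {}
        for key, value in state_dict.items():
            new_key = key
            for replacement, old in pairs:
                if old in new_key and replacement not in new_key:
                    new_key = new_key.replace(old, replacement)
            cleaned[new_key] = value
        return cleaned
    return _clean

cleaning = clean_state_keys(("my_model.bert.", "bert."))
model.load_state_dict(cleaning(state_dict), strict=False)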
I can make a pull request out of it if this is considered helpful
@Borda Here's a sample of how I handle checkpoints in my current pytorch code:
Saving Checkpoint:
torch.save(
    {
        'model_state_dict': model_params,
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'epoch': epoch,
        'total_iter_num': total_iter_num,
        'epoch_loss': epoch_loss,
        'config': config_yaml
    }, filename)
Loading Checkpoint:
checkpoint = torch.load(config.train.checkpoint.path, map_location='cpu')

if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
elif 'state_dict' in checkpoint:
    # Checkpoint from other author (from github)
    checkpoint['state_dict'].pop('decoder.last_conv.8.weight')
    checkpoint['state_dict'].pop('decoder.last_conv.8.bias')
    model.load_state_dict(checkpoint['state_dict'], strict=False)
else:
    # Checkpoint contains only model state dict, it's not stored in a dict
    model.load_state_dict(checkpoint)

if config.train.lrScheduler.name == 'StepLR':
    lr_scheduler.last_epoch = checkpoint['epoch']
if 'lr_scheduler_state_dict' in checkpoint:
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if 'optimizer_state_dict' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
Thanks, @rohitgr7, but I was hoping for a method that did not involve converting the checkpoint format.
I assumed there would be a way to override the default dict keys, i.e. to tell Lightning that the model state dict can be found under the key "my_state_dict", for example.
@Shreeyak I think you can override these two functions in Trainer (by creating a SubTrainer subclass) to get what you want:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/trainer/training_io.py#L289-L297
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/trainer/training_io.py#L410-L416
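As a rough sketch of the subclassing pattern (the hook name restore and its (checkpoint_path, on_gpu) signature are assumptions taken from the linked training_io.py; verify them against the version you are running):
import torch
from pytorch_lightning import Trainer

class SubTrainer(Trainer):
    # NOTE: `restore` is assumed from the linked source; check the exact
    # name and signature at the commit/version you are using.
    def restore(self, checkpoint_path, on_gpu):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        converted_path = checkpoint_path
        if 'my_funky_model_state_dict_key' in checkpoint:
            # remap the custom key to the 'state_dict' key Lightning expects,
            # write it to a sibling file, and let the default logic take over
            checkpoint['state_dict'] = checkpoint.pop('my_funky_model_state_dict_key')
            converted_path = checkpoint_path + '.converted'
            torch.save(checkpoint, converted_path)
        super().restore(converted_path, on_gpu)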
Also, @Borda is this a bug here? on_load_checkpoint is called after state_dict is restored.
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/core/saving.py#L172-L183
Also, I see that in some places the model is loaded directly with load_state_dict, without going through PL's load_from_checkpoint, which means on_load_checkpoint will not be called there. For example, here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/trainer/trainer.py#L1338-L1339