Pytorch-lightning: How to load checkpoint from external source with different keys in dict

Created on 2 Aug 2020  ·  7 comments  ·  Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

If I have a checkpoint that uses a different key to store the model weights, how do I load it, short of converting it into the Lightning format?

Eg: checkpoint["my_funky_model_state_dict_key"]

Code
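
For illustration, loading such an external checkpoint with plain PyTorch might look like this (the file path is a placeholder):

import torch

# Load the external checkpoint and pull the weights out of its custom key.
checkpoint = torch.load("external_checkpoint.pth", map_location="cpu")
state_dict = checkpoint["my_funky_model_state_dict_key"]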

What have you tried?

Searched the docs, python API and github issues.

What's your environment?

  • OS: Linux
  • Packaging: conda
  • Version [e.g. 0.5.2.1]
Labels: question, won't fix

Most helpful comment

Try this. See if it works.

import torch

new_model = new_lightning_model()
new_weights = new_model.state_dict()
old_weights = list(torch.load(old_checkpoint)['state_dict'].items())

# Copy the old weights into the new state dict by position; this assumes
# both state dicts list the same parameters in the same order.
for i, k in enumerate(new_weights.keys()):
    new_weights[k] = old_weights[i][1]

new_model.load_state_dict(new_weights)
# then save it to get the checkpoint in lightning format.

All 7 comments

Do you mind sharing an example?
Otherwise, I would say load it as you do now, edit the keys to match what PL expects, and save it again...
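
A minimal sketch of that rename-and-resave approach (the key names, prefix, and file paths here are just placeholders):

import torch

# Load the external checkpoint and grab the weights from its custom key.
old_ckpt = torch.load("external_checkpoint.pth", map_location="cpu")
old_state_dict = old_ckpt["my_funky_model_state_dict_key"]

# Rename the parameter keys to match the LightningModule's attribute names,
# e.g. "encoder.weight" -> "my_model.encoder.weight".
new_state_dict = {"my_model." + k: v for k, v in old_state_dict.items()}

# Save the weights under the "state_dict" key that Lightning checkpoints use.
torch.save({"state_dict": new_state_dict}, "converted_checkpoint.ckpt")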

Try this. See if it works.

import torch

new_model = new_lightning_model()
new_weights = new_model.state_dict()
old_weights = list(torch.load(old_checkpoint)['state_dict'].items())

# Copy the old weights into the new state dict by position; this assumes
# both state dicts list the same parameters in the same order.
for i, k in enumerate(new_weights.keys()):
    new_weights[k] = old_weights[i][1]

new_model.load_state_dict(new_weights)
# then save it to get the checkpoint in lightning format.
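
Note that this copies weights purely by position, so it assumes the old and new state dicts contain exactly the same parameters, in the same order and with the same shapes; if the two models differ at all, remapping by key name is the safer option.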


This is a frequently occurring problem when using a pl_module to wrap an existing module.

E.g.:

When loading the pretrained weights, the state_dict keys always start with "bert.", but when loading our own PL-trained checkpoint, the keys always start with "my_model.bert.".

And when we try to fine-tune a downstream task, we might want to load either one, so we have to write extra code for the different key layouts.

Can we turn the code above into a state_dict cleaning function that we can use like the following: if "bert." is found but "my_model.bert." is not, replace "bert." with "my_model.bert.".

cleaning = clean_state_keys(["my_model.bert.", "bert."], ["replacement_key", "to_be_replaced"], ...)
model.load_state_dict(cleaning(state_dict), strict=False)

And the above code should work for weights saved under either "bert." or "my_model.bert.", since in one of the two cases it will simply replace nothing.
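
For illustration, a rough sketch of what such a helper could look like (clean_state_keys is hypothetical and not an existing Lightning API; this version takes a single prefix pair instead of lists):

def clean_state_keys(replacement_prefix, old_prefix):
    """Return a function that rewrites state_dict keys starting with
    old_prefix so they start with replacement_prefix instead, unless
    they already carry the replacement prefix."""
    def cleaning(state_dict):
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith(old_prefix) and not key.startswith(replacement_prefix):
                key = replacement_prefix + key[len(old_prefix):]
            cleaned[key] = value
        return cleaned
    return cleaning

# Usage, matching the proposal above:
cleaning = clean_state_keys("my_model.bert.", "bert.")
model.load_state_dict(cleaning(state_dict), strict=False)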

I can make a pull request out of it if this is considered helpful

@Borda Here's a sample of how I handle checkpoints in my current pytorch code:

Saving Checkpoint:

torch.save(
    {
        'model_state_dict': model_params,
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'epoch': epoch,
        'total_iter_num': total_iter_num,
        'epoch_loss': epoch_loss,
        'config': config_yaml
    }, filename)

Loading Checkpoint:

checkpoint = torch.load(config.train.checkpoint.path, map_location='cpu')
if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
elif 'state_dict' in checkpoint:
    # Checkpoint from other author (from github)
    checkpoint['state_dict'].pop('decoder.last_conv.8.weight')
    checkpoint['state_dict'].pop('decoder.last_conv.8.bias')
    model.load_state_dict(checkpoint['state_dict'], strict=False)
else:
    # Checkpoint contains only the model state dict, not wrapped in a dict
    model.load_state_dict(checkpoint)

if config.train.lrScheduler.name == 'StepLR':
    lr_scheduler.last_epoch = checkpoint['epoch']
if 'lr_scheduler_state_dict' in checkpoint:
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if 'optimizer_state_dict' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

Thanks, @rohitgr7, but I was hoping for a method that did not involve converting the checkpoint format.

I assumed there would be a way to override the default dict keys, i.e. to tell Lightning that the model state dict can be found in the key "my_state_dict", for example.
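
For what it's worth, assuming the on_load_checkpoint hook runs before the weights are restored (which the question just below calls into doubt), a rough sketch of remapping the key in that hook might look like this (the network and key name are placeholders):

import pytorch_lightning as pl
import torch.nn as nn

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.my_model = nn.Linear(10, 2)  # placeholder network

    def on_load_checkpoint(self, checkpoint):
        # Move the weights from the external checkpoint's custom key into
        # the "state_dict" key that Lightning restores from.
        if "my_funky_model_state_dict_key" in checkpoint:
            checkpoint["state_dict"] = checkpoint.pop("my_funky_model_state_dict_key")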

Also, @Borda, is this a bug here? on_load_checkpoint is called after the state_dict is restored.
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/core/saving.py#L172-L183

Also, I see that in some places the model is loaded directly with load_state_dict, without going through PL's load_from_checkpoint function, which means on_load_checkpoint will not be called there. For example, here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed8a01afb0f9a605933563712e447a6a751f3af2/pytorch_lightning/trainer/trainer.py#L1338-L1339

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
