Pytorch-lightning: Adding a Metric to LightningModule prevents loading of a checkpoint / weights

Created on 26 Oct 2020  ·  5 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug


Adding a Metric like Accuracy prevents the loading of a .ckpt due to missing keys:

RuntimeError: Error(s) in loading state_dict for BoringModel2:
    Missing key(s) in state_dict: "pl_accuracy.correct", "pl_accuracy.total". 

Please reproduce using the BoringModel and post here


https://colab.research.google.com/drive/1km0SE2TVRuif6R4uF8vihqUl-FshXu_u?usp=sharing

To Reproduce

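# Import path as of pytorch-lightning 1.0.x; BoringModel is defined in the linked Colab notebook.
from pytorch_lightning.metrics import Accuracy

class BoringModel2(BoringModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NEW: the only change is the added metric
        self.pl_accuracy = Accuracy()

# Raises the RuntimeError above: the checkpoint has no pl_accuracy.* keys
model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt")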

Expected behavior


It should be possible to add new metrics to a model without changing its layers (e.g. in transfer learning settings) and still load the weights from a checkpoint of the model saved without those metrics.

Actual behavior

RuntimeError                              Traceback (most recent call last)

<ipython-input-15-50d6f204e428> in <module>()
----> 1 model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt")  # , strict=False

2 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047 

RuntimeError: Error(s) in loading state_dict for BoringModel2:
    Missing key(s) in state_dict: "pl_accuracy.correct", "pl_accuracy.total". 

Environment

Colab:

  • CUDA:
    • GPU: Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0+cu101
    • pytorch-lightning: 1.0.3
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture: 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

Labels: Checkpoint, Metrics, bug / fix, help wanted

All 5 comments

Adding strict=False to:

model2 = BoringModel2.load_from_checkpoint(checkpoint_path="example.ckpt", strict=False)

does allow the weights to load (thank you @ananthsub), but it increases the risk that genuinely wrong or missing keys in your .ckpt files go unnoticed, which strict=True would have caught.
This is not something you want in bigger projects, where you want to be sure that the keys match.
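
One way to keep most of the safety of strict=True while still loading non-strictly is to check the key mismatch yourself and tolerate only the keys you expect to be new. The sketch below is a minimal, framework-free illustration: check_key_mismatch and allowed_missing_prefixes are hypothetical names, and the key lists are stand-ins for the model's state_dict keys and the checkpoint's state_dict keys.

```python
def check_key_mismatch(model_keys, checkpoint_keys, allowed_missing_prefixes=()):
    """Return (missing, unexpected); raise if a mismatch is not explicitly allowed.

    missing    -> keys the model expects but the checkpoint lacks
    unexpected -> keys in the checkpoint that the model does not have
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)

    # Only missing keys under an allowed prefix (e.g. a newly added metric) are tolerated.
    allowed = tuple(allowed_missing_prefixes)
    not_allowed = [k for k in missing if not k.startswith(allowed)]
    if not_allowed or unexpected:
        raise RuntimeError(
            f"Key mismatch: missing={not_allowed}, unexpected={unexpected}"
        )
    return missing, unexpected


# Stand-ins for the real key sets (illustrative values)
old_checkpoint_keys = ["layer.weight", "layer.bias"]
new_model_keys = ["layer.weight", "layer.bias", "pl_accuracy.correct", "pl_accuracy.total"]

missing, unexpected = check_key_mismatch(
    new_model_keys, old_checkpoint_keys, allowed_missing_prefixes=("pl_accuracy.",)
)
print(missing)
```

With a real model, the same two lists are available from the object that load_state_dict(..., strict=False) returns (its missing_keys and unexpected_keys fields, visible in the traceback above), so a check like this could run right after a non-strict load.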


Since the metric Accuracy uses default values, I have no idea what "pl_accuracy.correct" or "pl_accuracy.total" should be initialized with, and even if I knew, needing to provide these values to load_from_checkpoint() is probably not a good approach.

Using BoringModel as a submodule is also not desirable, as you cannot simply do:

class BoringModelTransfer(LightningModule):

    def __init__(self):
        super().__init__()
        self.boring = BoringModel()

    def forward(self, x):
        return self.boring(x)

You would need to copy all the other functions, such as training_step(), as well. Lightning should take engineering work away, not add to it, which is what this approach would do.

Also, hparams would be in self.boring.hparams instead of self.hparams, which creates new issues.

The best approach might be to modify the old checkpoints so that they are compatible with a common format.
Assuming the keys epoch and total_iter_num don't exist, here is sample code that adds them to the checkpoint (with a default value of 0):

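import torch

# path_checkpoint / path_new_checkpoint: paths of the old and the patched .ckpt file
checkpoint = torch.load(path_checkpoint, map_location='cpu')
# add the missing top-level keys with a default value of 0
checkpoint['epoch'] = 0
checkpoint['total_iter_num'] = 0

torch.save(checkpoint, path_new_checkpoint)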

Metrics are simple nn.Modules and therefore also have their own state_dict. For the Accuracy metric, the correct and total tensors hold the number of correctly labeled data points seen and the total number of data points seen. Both are 0 at initialization. Keeping them in the state_dict allows, for example, the metric to be computed on some fraction of the dataset, saved, loaded, and then computed on the remaining part while still giving the right result.
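
To make the save-and-resume point concrete, here is a toy, pure-Python accumulator with the same two counters. It is only an illustration of the idea, not the library's Accuracy class; RunningAccuracy and the sample data are made up for the example.

```python
class RunningAccuracy:
    """Toy accumulator mirroring the two counters described above (not the library class)."""

    def __init__(self):
        self.correct = 0  # correctly labeled data points seen so far
        self.total = 0    # all data points seen so far

    def update(self, preds, targets):
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total

    def state_dict(self):
        return {"correct": self.correct, "total": self.total}

    def load_state_dict(self, state):
        self.correct = state["correct"]
        self.total = state["total"]


preds = [1, 0, 1, 1, 0, 1]
targets = [1, 0, 0, 1, 1, 1]

# Evaluate the first half, then "save" the state.
first = RunningAccuracy()
first.update(preds[:3], targets[:3])
saved = first.state_dict()

# "Load" into a fresh instance and finish on the second half.
resumed = RunningAccuracy()
resumed.load_state_dict(saved)
resumed.update(preds[3:], targets[3:])

# Reference: a single pass over everything.
full = RunningAccuracy()
full.update(preds, targets)

print(resumed.compute() == full.compute())
```

Because the counters travel with the saved state, the resumed instance ends up with the same result as a single uninterrupted pass.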
Therefore, Metrics should be treated like any other nn.Module, which reduces the options to either loading with strict=False or adding the missing state_dict keys to the checkpoint (the default values for any metric can be accessed through metric._defaults).
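
The second option, adding the missing keys to the checkpoint, could look roughly like the sketch below. It assumes the checkpoint is a dict that holds the weights under a "state_dict" entry, and add_missing_state is a hypothetical helper. Plain integers stand in for the zero tensors that would come from metric._defaults in practice, and the nested lists stand in for weight tensors.

```python
def add_missing_state(checkpoint, prefix, defaults):
    """Return a copy of `checkpoint` whose state_dict also holds prefix + name for each default.

    Keys that already exist are left untouched, so saved metric state is never overwritten.
    """
    patched = dict(checkpoint)
    state_dict = dict(checkpoint["state_dict"])
    for name, value in defaults.items():
        state_dict.setdefault(prefix + name, value)
    patched["state_dict"] = state_dict
    return patched


# Stand-in for the dict returned by torch.load on the old checkpoint
old_checkpoint = {"epoch": 3, "state_dict": {"layer.weight": [[0.1]], "layer.bias": [0.0]}}

# Stand-in for the metric's default state; real values would be tensors
accuracy_defaults = {"correct": 0, "total": 0}

new_checkpoint = add_missing_state(old_checkpoint, "pl_accuracy.", accuracy_defaults)
print(sorted(new_checkpoint["state_dict"]))
```

The patched dict would then be written back with torch.save, as in the epoch / total_iter_num snippet earlier in the thread, after which a strict load should find the metric keys it expects.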

As @SkafteNicki said, this comes down more to a limitation of how PyTorch loads a state_dict (load_state_dict, as in the traceback above): if you add new nn.Modules as attributes to a model, you will not be able to load a checkpoint from an older version without specifying strict=False.

If you do not want this behavior when implementing your own metric, you can simply pass persistent=False when you call add_state (this only works in PyTorch 1.7+).
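
A minimal sketch of the same idea in plain PyTorch, using the persistent argument of register_buffer. Treating this as the mechanism behind add_state(persistent=False) is an assumption here, and CounterModule and Model are hypothetical names, not Lightning classes.

```python
import torch
from torch import nn


class CounterModule(nn.Module):
    """Stand-in for a metric: two counters kept as buffers (not a Lightning Metric)."""

    def __init__(self, persistent):
        super().__init__()
        # Non-persistent buffers stay on the module but are left out of state_dict()
        self.register_buffer("correct", torch.tensor(0), persistent=persistent)
        self.register_buffer("total", torch.tensor(0), persistent=persistent)


class Model(nn.Module):
    def __init__(self, with_counter, persistent=True):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        if with_counter:
            self.pl_accuracy = CounterModule(persistent)


# "Checkpoint" saved before the counters existed
old_state = Model(with_counter=False).state_dict()

# Persistent counters: strict loading fails with missing keys, as in the report above
try:
    Model(with_counter=True, persistent=True).load_state_dict(old_state, strict=True)
    strict_load_failed = False
except RuntimeError:
    strict_load_failed = True

# Non-persistent counters: the old weights load even with strict=True
new_model = Model(with_counter=True, persistent=False)
result = new_model.load_state_dict(old_state, strict=True)
print(strict_load_failed, sorted(new_model.state_dict().keys()))
```

The trade-off is that non-persistent state is never saved, so the save-and-resume use case for metric state no longer applies to it.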

Do people think every metric should have a persistent flag that would apply this behavior to all of its state variables? That would be possible and would solve your issue.
