Pytorch-lightning: RuntimeError: Error(s) in loading state_dict when adding/updating metrics to a trained model.

Created on 13 Nov 2020  ·  10 Comments  ·  Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

For context, I trained a lot of models for many weeks, tracking the loss and accuracy for train, validation, and test steps.
Now I want to evaluate more metrics on the test data set; more specifically, I added recall, confusion matrix, and precision metrics (from the ptl.metrics module) to the test_step and test_epoch_end methods of the lightning module.

Also, I replaced my custom accuracy with the class-based Accuracy implemented in the ptl.metrics package.
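
For context, this is a minimal sketch of roughly how the metrics are wired into the module; the class name, the num_classes value, and the forward pass are placeholders, not my actual code:

import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):  # placeholder name, not my real module
    def __init__(self):
        super().__init__()
        # class-based metrics from ptl.metrics; num_classes=10 is just a placeholder
        self.test_acc = pl.metrics.Accuracy()
        self.test_precision = pl.metrics.Precision(num_classes=10)
        self.test_recall = pl.metrics.Recall(num_classes=10)
        self.test_confmat = pl.metrics.ConfusionMatrix(num_classes=10)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)  # forward pass omitted in this sketch
        # calling a metric updates its internal state for the running epoch
        self.test_acc(y_pred, y)
        self.test_precision(y_pred, y)
        self.test_recall(y_pred, y)
        self.test_confmat(y_pred, y)

    def test_epoch_end(self, outputs):
        # compute() aggregates the accumulated state over the whole test set
        self.log('test_acc', self.test_acc.compute())
        self.log('test_precision', self.test_precision.compute())
        self.log('test_recall', self.test_recall.compute())
        print(self.test_confmat.compute())  # the confusion matrix is a tensor, so print or save it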

When I try to run the test cycle and obtain these metrics for the trained model on the test set, I get this error while loading the checkpoint:

Traceback (most recent call last):
  File "model_manager.py", line 283, in <module>
    helper.test()
  File "model_manager.py", line 119, in test
    self.trainer.test(self.module)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 748, in test
    results = self.__test_given_model(model, test_dataloaders)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 813, in __test_given_model
    results = self.fit(model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 459, in fit
    results = self.accelerator_backend.train()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 61, in train
    self.trainer.train_loop.setup_training(model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 174, in setup_training
    self.trainer.checkpoint_connector.restore_weights(model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 75, in restore_weights
    self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 107, in restore
    self.restore_model_state(model, checkpoint)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 128, in restore_model_state
    model.load_state_dict(checkpoint['state_dict'])
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1044, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for L_WavenetLSTMClassifier:
        Unexpected key(s) in state_dict: "train_acc.correct", "train_acc.total", "val_acc.correct", "val_acc.total", "test_acc.correct", "test_acc.total".

What is your question?

In my case, retraining the models is not an option because it takes many weeks. So I wonder whether there is a way to load the already-trained model anyway and obtain the updated test metrics by running a test cycle.

Actually, I only care about loading the model's parameters to run the test cycle. I don't understand why loading the rest is so important; those old metric values don't seem vital to me.

What have you tried?

I read that this exception is generated by torch's model.load_state_dict method and can be avoided with the strict=False parameter.

In my case I load the trained model through the resume_from_checkpoint parameter of the pytorch-lightning Trainer class, so I have no obvious way to pass strict=False when loading.
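
For the record, here is a sketch of two workarounds I could imagine, assuming the usual Trainer setup; the checkpoint path and constructor arguments are placeholders, and I haven't verified either against my exact setup:

import torch
import pytorch_lightning as pl
from lightning_modules import L_WavenetLSTMClassifier  # my module, shown here for illustration

ckpt_path = 'epoch=999.ckpt'  # placeholder path

# Option 1: skip resume_from_checkpoint and load the module directly with strict=False,
# then run only the test loop on the loaded model.
model = L_WavenetLSTMClassifier.load_from_checkpoint(ckpt_path, strict=False)
trainer = pl.Trainer(gpus=1)
trainer.test(model)

# Option 2: load the raw checkpoint and call load_state_dict yourself with strict=False.
model = L_WavenetLSTMClassifier()  # constructor arguments omitted in this sketch
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'], strict=False)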

What's your environment?

  • OS: Win
  • Version: Latest master branch (November 13th, 2020)
Metrics question

All 10 comments

Reading further about state_dicts, I noted that, based on the PyTorch docs, a state_dict should store the model's parameters and hyper-parameters. I'm not sure why there is metrics-related data stored in this dict.

https://github.com/PyTorchLightning/pytorch-lightning/blob/baa8558cc0e6d2a3e24f2669e6a59ffdb8138737/pytorch_lightning/metrics/metric.py#L88-L90
This might be the reason: since persistent is True by default, the metric state is saved to the state_dict.
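
As a rough illustration with plain PyTorch buffers (not the Metric API itself, which wraps this mechanism, and assuming a reasonably recent PyTorch that supports the persistent flag): only buffers registered as persistent show up in state_dict().

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # a persistent buffer is written to state_dict(), a non-persistent one is not
        self.register_buffer('correct', torch.tensor(0), persistent=True)
        self.register_buffer('scratch', torch.tensor(0), persistent=False)

demo = Demo()
print(demo.state_dict())  # OrderedDict([('correct', tensor(0))])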

Well, thanks @rohitgr7, I'll try changing this parameter and report back.
I hope there is something I can do to the modules to load these checkpoints 😖.

maybe try after removing these keys from the checkpoint?

After removing each metric entry from the state_dict (with del x['state_dict']['test_acc.total'], for example), I saved the checkpoint back with torch.save(x, 'epoch=999.ckpt'), and now I get this weird exception:

Traceback (most recent call last):
  File "model_manager.py", line 283, in <module>
    helper.test()
  File "model_manager.py", line 119, in test
    self.trainer.test(self.module)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 748, in test
    results = self.__test_given_model(model, test_dataloaders)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 813, in __test_given_model
    results = self.fit(model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 459, in fit
    results = self.accelerator_backend.train()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 64, in train
    results = self.train_or_test()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 64, in train_or_test
    results = self.trainer.run_test()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 655, in run_test
    eval_loop_results, _ = self.run_evaluation(test_mode=True)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in run_evaluation
    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 173, in evaluation_step
    output = self.trainer.accelerator_backend.test_step(args)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 104, in test_step
    output = self.__test_step(args)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 112, in __test_step
    output = self.trainer.model.test_step(*args)
  File "/data/voyanedel/code/aidio/lightning_modules.py", line 388, in test_step
    self.metrics['test_acc2'](y_pred, y_target)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 155, in forward
    self.update(*args, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 201, in wrapped_func
    return update(*args, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 98, in update
    self.correct += torch.sum(preds == target)
RuntimeError: Trying to pass too many CPU scalars to CUDA kernel!
Exception ignored in: <function tqdm.__del__ at 0x7f142d3174c0>
Traceback (most recent call last):
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/tqdm/std.py", line 1087, in __del__
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/tqdm/std.py", line 1294, in close
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/tqdm/std.py", line 1472, in display
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/tqdm/std.py", line 1090, in __repr__
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/tqdm/std.py", line 1434, in format_dict
TypeError: cannot unpack non-iterable NoneType object

At least the main issue seems fixed; now I'm dealing with this one.
I'll keep you informed.

Surprisingly, I had to change line 99 of /data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py (in the update method) from:
self.correct += torch.sum(preds == target)
to:
self.correct = self.correct + torch.sum(preds == target)
as described here: https://discuss.pytorch.org/t/trying-to-pass-too-many-cpu-scalars-to-cuda-kernel/87757/5

And now the test loop is running, at least without the new metrics.
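
My reading of why that one-line change helps (this may be version-dependent): the metric's correct counter defaults to a zero-dim CPU tensor, and a zero-dim CPU tensor can be combined out-of-place with a CUDA tensor (it is treated like a Python scalar and the result lands on the GPU), whereas adding a CUDA tensor in-place into the CPU tensor is rejected. A small check, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    correct = torch.tensor(0)                    # zero-dim CPU tensor, like the metric's default state
    batch_hits = torch.tensor(3, device='cuda')  # stands in for torch.sum(preds == target)

    correct = correct + batch_hits               # out-of-place add works; result lives on the GPU
    print(correct.device)                        # cuda:0

    cpu_counter = torch.tensor(0)
    try:
        cpu_counter += batch_hits                # in-place add into the CPU tensor raises
    except RuntimeError as err:
        print(err)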

Thank you for your help.

PS: I have other issues that prevent me from evaluating Recall, Precision, F1Score, and Confusion Matrix.
All metric evaluations in the test_step method throw the same exception:

Traceback (most recent call last):
  File "model_manager.py", line 283, in <module>
    helper.test()
  File "model_manager.py", line 119, in test
    self.trainer.test(self.module)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 718, in test
    results = self.__test_given_model(model, test_dataloaders)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 783, in __test_given_model
    results = self.fit(model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 444, in fit
    results = self.accelerator_backend.train()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 138, in train
    results = self.ddp_train(process_idx=self.task_idx, model=model)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 272, in ddp_train
    results = self.train_or_test()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 66, in train_or_test
    results = self.trainer.run_test()
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 627, in run_test
    eval_loop_results, _ = self.run_evaluation(test_mode=True)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 578, in run_evaluation
    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 169, in evaluation_step
    output = self.trainer.accelerator_backend.test_step(args)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 156, in test_step
    output = self.training_step(args)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 148, in training_step
    output = self.trainer.model(*args)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/overrides/data_parallel.py", line 179, in forward
    output = self.module.test_step(*inputs[0], **kwargs[0])
  File "/data/voyanedel/code/aidio/lightning_modules.py", line 390, in test_step
    self.metrics['test_precision'](y_pred, y_target)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 155, in forward
    self.update(*args, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 201, in wrapped_func
    return update(*args, **kwargs)
  File "/data/anaconda3/envs/aidio2/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/precision_recall.py", line 132, in update
    self.true_positives += torch.sum(preds * target, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Exception ignored in: <function tqdm.__del__ at 0x7f9432e204c0>

I'll keep debugging and report any news. In the worst case, I can log the y_pred and y_target tensors completely and compute the metrics over them with other tools.
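
For what it's worth, my current suspicion (not verified yet) is that the "two devices" error comes from keeping the metric objects in a plain Python dict (the self.metrics[...] in the traceback): a plain dict is not registered as a child module, so the metric states stay on the CPU while the predictions live on the GPU. A sketch of the change I plan to try, with placeholder names:

import pytorch_lightning as pl
from torch import nn

class LitClassifier(pl.LightningModule):  # placeholder name
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers the metrics as submodules, so moving the model to the GPU
        # also moves their internal state tensors; a plain dict would not be moved
        self.metrics = nn.ModuleDict({
            'test_acc2': pl.metrics.Accuracy(),
            'test_precision': pl.metrics.Precision(num_classes=10),  # num_classes is a placeholder
        })

    def test_step(self, batch, batch_idx):
        x, y_target = batch
        y_pred = self(x)  # forward pass omitted in this sketch
        self.metrics['test_acc2'](y_pred, y_target)
        self.metrics['test_precision'](y_pred, y_target)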

This is basically a duplicate (at least the initial problem) of https://github.com/PyTorchLightning/pytorch-lightning/issues/4361.
It was solved in v1.0.6. The compromise was the addition of a .persistent(mode) method that lets you switch on and off whether a metric's state is added to the state_dict, after initializing the metric:

import pytorch_lightning as pl

metric = pl.metrics.Accuracy()
print(metric.state_dict())  # prints OrderedDict([('correct', tensor(0)), ('total', tensor(0))])
metric.persistent(False)
print(metric.state_dict())  # prints OrderedDict()

Thanks for the help!
Also thank you for listening to us C:

Anyway here is the script I used to fix the state_dict. Maybe it'd be useful for someone:

"""
Remove metric records on state_dict.

"""

import argparse
from pathlib import Path

import torch

state_dict_keys_to_remove = ['test_acc.total', 'test_acc.correct',
                             'val_acc.total', 'val_acc.correct',
                             'train_acc.total', 'train_acc.correct', ]


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--src_path', help='', )
    parser.add_argument('--dest_path', help='', )
    args = parser.parse_args()
    src_path = Path(args.src_path)
    dest_path = Path(args.dest_path)
    ckpt = torch.load(src_path)
    print('info: state_dict keys: {}'.format(ckpt['state_dict'].keys()))
    for k in state_dict_keys_to_remove:
        del ckpt['state_dict'][k]
    torch.save(ckpt, dest_path)


if __name__ == '__main__':
    main()
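
For reference, assuming the script is saved as remove_metric_keys.py (a file name I just made up), it would be run as: python remove_metric_keys.py --src_path epoch=999.ckpt --dest_path epoch=999_fixed.ckpt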

@Vichoko no problem. Personally, I was in favor of having metric states be part of the state dict, but when enough people raise the concern I am of course willing to change my opinion :]

@SkafteNicki That's very laudable of you ^^
I really appreciate it.

I still wonder in which cases it would be useful to store the metric states in the state_dict.
