Pytorch-lightning: Mismatch between docstring and code regarding when `on_load_checkpoint` hook is called

Created on 8 Oct 2020 · 3 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

The docstring of the on_load_checkpoint hook says that it is called before trying to load_state_dict:
https://github.com/PyTorchLightning/pytorch-lightning/blob/cea5f1f53876399dfaa0d37accdc527af7ca39af/pytorch_lightning/core/saving.py#L203-L206

However, in LightningModule.load_from_checkpoint, it is called after load_state_dict:
https://github.com/PyTorchLightning/pytorch-lightning/blob/cea5f1f53876399dfaa0d37accdc527af7ca39af/pytorch_lightning/core/saving.py#L195-L199
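For reference, the order at the linked lines is roughly the following (a simplified paraphrase of the `load_from_checkpoint` / `_load_model_state` path, not the exact source):

```python
# Simplified paraphrase of the linked code path; names and details are approximate.
model = cls(**cls_kwargs)

# current behaviour: the weights are restored first ...
model.load_state_dict(checkpoint["state_dict"])

# ... and on_load_checkpoint only runs afterwards, contradicting its docstring
model.on_load_checkpoint(checkpoint)
```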

Additional context

Related discussion on Slack: https://pytorch-lightning.slack.com/archives/CQXV8BRH9/p1602168345184000

I think the docstring is correct and the call to on_load_checkpoint should be moved right before load_state_dict to give the model a chance to call setup.
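Concretely, the proposed change would just swap the two calls (a sketch under the same simplifications as above):

```python
# Proposed order (sketch): let the model adjust its modules first ...
model.on_load_checkpoint(checkpoint)

# ... so the restored state_dict matches whatever the hook rebuilt
model.load_state_dict(checkpoint["state_dict"])
```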

Labels: bug / fix · documentation · help wanted

All 3 comments

@hbredin mind sending a PR to fix it...

I can do that. Should I fix the docstring or the code?
I'd go with the code.

I need this code change as well! I'm doing transfer learning and I want to support both loading the original model with the original weights, and modifying it for a new task.
on_load_checkpoint would allow me to redo the modifications I've done for transfer learning, so the state_dict of the transferred model can be properly restored.

At present I need to add non-network code to the model to handle this logic, which is ugly and prone to bugs.

Having on_load_checkpoint called before load_state_dict would let the same model redo the modifications made for transfer learning before the transferred weights are restored.
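For illustration, with the hook running before load_state_dict, the transfer-learning case could be handled entirely inside the module; the model and layer names below are hypothetical:

```python
import torch.nn as nn
import pytorch_lightning as pl


class TransferModel(pl.LightningModule):
    """Hypothetical module whose classification head gets replaced for a new task."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = nn.Sequential(nn.Linear(128, 64), nn.ReLU())
        self.head = nn.Linear(64, num_classes)

    def on_load_checkpoint(self, checkpoint):
        # Runs before load_state_dict once the fix is in: rebuild the
        # transfer-learning head so the incoming state_dict fits the module tree.
        head_weight = checkpoint["state_dict"]["head.weight"]
        self.head = nn.Linear(64, head_weight.shape[0])
```

With that in place, TransferModel.load_from_checkpoint(path) would restore the transferred weights without any extra glue code outside the model.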
