Pytorch-lightning: Custom Checkpoint callback for multiple models

Created on 13 Aug 2020  ·  3 Comments  ·  Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

I am looking to write my own checkpointing callback for a list of models I initialize in __init__().

Code

Let's say I created 10 timeseries models and 1 image model. Each model inherits LightningModule, so LITFusionExp holds 11 models.
When I save the checkpoint I can only see cnn_model's weights, not the ts_models'.
However, I can see that the trainer does update my ts_models during training.
The problem is that when I reload the checkpoint, all the ts_models are just randomly initialized. How do I save the ts_models too?
Thanks for the help

class LITFusionExp(LightningModule):
    def __init__(self, hparams):
        super().__init__()

        # plain Python list of 10 timeseries models
        self.ts_models = [Conv1dmultivariate(input_channels=10).cuda() for _ in range(10)]

        self.cnn_model = LITConvAEexp(hparams)

trainer.fit(LITFusionExp(hparams))
trainer.save_checkpoint('mypath.ckpt')
###
my_ckpt = torch.load('mypath.ckpt')
# my_ckpt['state_dict'] has only keys with respect to the CNN model

What's your environment?

  • OS: Linux
  • Packaging [e.g. conda]
  • Version [e.g. 0.8.5]
question

Most helpful comment

A list isn't a meaningful collection of nn.Modules in PyTorch. I think you're looking for torch.nn.ModuleList.

Wrap your list of modules with nn.ModuleList and I think your problem will be solved.

Going to close this because it's a PyTorch problem, not a Lightning problem. Hope this helps at least 😄
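A minimal sketch of why this happens, using a stand-in `Block` module in place of the question's `Conv1dmultivariate`: submodules stored in a plain Python list are invisible to `state_dict()` (and therefore to Lightning's checkpointing), while the same modules wrapped in `nn.ModuleList` are registered and their parameters show up.

```python
import torch.nn as nn

class Block(nn.Module):
    """Stand-in for one of the timeseries models."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(10, 10, kernel_size=3)

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # plain list: PyTorch does NOT register these submodules
        self.ts_models = [Block() for _ in range(10)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each Block as a submodule
        self.ts_models = nn.ModuleList(Block() for _ in range(10))

# The plain list contributes nothing to the state dict,
# while ModuleList exposes weight + bias for each of the 10 blocks.
print(len(WithPlainList().state_dict()))   # 0
print(len(WithModuleList().state_dict()))  # 20
```

The same applies to dicts of modules (`nn.ModuleDict`) and to lists of parameters (`nn.ParameterList`); anything not registered is also missed by `model.parameters()`, so the optimizer would never see it either.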

All 3 comments

Hi! Thanks for your contribution, great first issue!

I have created a notebook here to show my concerns.
ckpt['state_dict'].keys() contains only cnn model's keys.
https://colab.research.google.com/drive/11rV-PY1CaDUyiYpgBddlU1megMIBYIQX?usp=sharing


