I am looking to write my own checkpointing callback for a list of models I initialize in __init__().
Let's say I create 10 time-series models and 1 image model. Each model inherits from LightningModule.
So LITFusionExp has 11 models.
When I save the checkpoint, I can only see cnn_model's weights, not the ts_models.
However, I can see that the trainer does update my ts_models during training.
The problem is that when I reload the checkpoint, all the ts_models are just randomly initialized. How do I save the ts_models too?
Thanks for the help!
class LITFusionExp(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.ts_models = [Conv1dmultivariate(input_channels=10).cuda() for _ in range(10)]
        self.cnn_model = LITConvAEexp(hparams)

trainer.fit(LITFusionExp(hparams))
trainer.save_checkpoint('mypath.ckpt')
###
my_ckpt = torch.load('mypath.ckpt')
# my_ckpt['state_dict'] has only keys with respect to the CNN model
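A quick way to see the same thing before training is to inspect the module's own state_dict (a hypothetical check against the class above):

model = LITFusionExp(hparams)
print(model.state_dict().keys())
# only 'cnn_model.*' keys appear; a plain Python list of modules is never
# registered as submodules, so its parameters are not part of the state_dict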
I have created a notebook here to show my concerns: https://colab.research.google.com/drive/11rV-PY1CaDUyiYpgBddlU1megMIBYIQX?usp=sharing
ckpt['state_dict'].keys() contains only the CNN model's keys.
A list isn't a meaningful collection of nn.Modules in PyTorch. I think you're looking for torch.nn.ModuleList.
Wrap your list of modules with nn.ModuleList and I think your problem will be solved.
Going to close this because it's a PyTorch problem, not a Lightning problem. Hope this helps at least 😄
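For illustration, a minimal sketch of that change, reusing the reporter's class names (Conv1dmultivariate and LITConvAEexp come from the snippet above, so this isn't a tested drop-in):

from torch import nn
from pytorch_lightning import LightningModule

class LITFusionExp(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # nn.ModuleList registers each time-series model as a child module, so its
        # parameters end up in state_dict() and therefore in the Lightning checkpoint.
        # No explicit .cuda() is needed either; Lightning moves registered submodules
        # to the right device.
        self.ts_models = nn.ModuleList(
            [Conv1dmultivariate(input_channels=10) for _ in range(10)]
        )
        self.cnn_model = LITConvAEexp(hparams)

After this change, ckpt['state_dict'] should contain keys like 'ts_models.0.*' through 'ts_models.9.*' alongside the 'cnn_model.*' keys.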