LightningModules can augment the checkpoint dict with custom information relevant to the module. The motivation for this issue is to provide the same functionality for DataModules, since they too may have custom information they need to read from or write to the checkpoint dict. Importantly, only one checkpoint is saved during training. This checkpoint comprises the default state generated by the trainer, the LightningModule's custom data (optional), and the DataModule's custom data (optional):

Ckpt = Trainer state + LightningModule custom data (optional) + DataModule custom data (optional)
cc @nateraw
Custom data we may want to include in the checkpoint:

These are all pieces of state that could live in advanced data modules. Importantly, this state isn't static: it can consist of arguments passed into the datamodule, or state generated during training.
`on_{save/load}_checkpoint` functions for the DataModule:

```python
if self.trainer.datamodule is not None:
    self.trainer.datamodule.on_save_checkpoint(checkpoint)
```

```python
if self.trainer.datamodule is not None:
    self.trainer.datamodule.on_load_checkpoint(checkpoint)
```
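To make the round trip concrete, here is a hedged sketch of what a user-defined datamodule overriding these proposed hooks might look like. A plain class stands in for `pl.LightningDataModule` to keep the example self-contained, and the `vocab` state and `datamodule_vocab` checkpoint key are illustrative, not part of any real API:

```python
from typing import Any, Dict


class MyDataModule:
    """Stand-in for pl.LightningDataModule; hook signatures mirror the
    LightningModule ones proposed in this issue."""

    def __init__(self) -> None:
        # State that is not static: it gets built during setup()
        self.vocab: Dict[str, int] = {}

    def setup(self) -> None:
        # e.g. a vocabulary generated from the training data
        self.vocab = {"<pad>": 0, "hello": 1, "world": 2}

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Write datamodule state into the shared checkpoint dict
        checkpoint["datamodule_vocab"] = self.vocab

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Restore datamodule state from the checkpoint dict
        self.vocab = checkpoint["datamodule_vocab"]


# Round-trip: save state from one instance, load it into a fresh one.
dm = MyDataModule()
dm.setup()
ckpt: Dict[str, Any] = {}
dm.on_save_checkpoint(ckpt)

restored = MyDataModule()
restored.on_load_checkpoint(ckpt)
```

Because the trainer passes the same `checkpoint` dict to both the module and the datamodule, each component only touches its own keys.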
Sharing code:
We can refactor the existing LightningModule hooks here into a separate CheckpointHook mixin:
Then both the LightningModule and DataModule can inherit from this mixin (similar to what was done for the DataHook), meaning both the LightningModule and DataModule will share the function declarations.
By default, the data module will have empty stubs for these function definitions (as is the current implementation for the LightningModule).
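The mixin idea above could be sketched roughly as follows; class names here are illustrative stand-ins for the real Lightning classes, not the final API:

```python
from typing import Any, Dict


class CheckpointHook:
    """Proposed mixin: both hooks are empty stubs by default, so users
    only override the ones they need."""

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Override to add custom state to the checkpoint dict."""

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Override to restore custom state from the checkpoint dict."""


# Both classes inherit the same function declarations from the mixin,
# similar to what was done for the DataHook.
class LightningModule(CheckpointHook):  # stand-in for the real class
    pass


class LightningDataModule(CheckpointHook):  # stand-in for the real class
    pass
```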
Currently, users are forced to load state from their checkpoint by including their datamodule inside their lightning module. The lightning module ends up looking like this:
```python
from typing import Any, Dict

import pytorch_lightning as pl


class MyLightningModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()  # missing in the original snippet
        self.datamodule = MyLightningDataModule(args)
        ...

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["stuff"] = get_state_from_datamodule(self.datamodule)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.datamodule.stuff = checkpoint["stuff"]
```
But this forces the datamodule to be instantiated within the lightning module, so the datamodule is no longer an independent component. According to the docs, this is not a recommended usage of data modules: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#using-a-datamodule
https://forums.pytorchlightning.ai/t/do-datamodules-have-on-checkpoint-load-save-hooks/180
Assigned @ananthsub as he mentioned he wouldn't mind taking a stab at implementing this. Assigned myself to keep an eye on it.
@PyTorchLightning/core-contributors lets support Ananth as he submits contribution for this. Please leave your thoughts on the proposed solution here if you have any 😄 .