When we save models with a checkpoint callback, we can only load them by having the original LightningModule that was used to create the model and the checkpoint. Is there some extension of save_checkpoint that saves everything I would need to reload the module and restore the checkpoint, all as part of one save function?
What is best practice around this?
I have created a modular LightningModule whose structure changes depending on the hyperparameters, which makes it harder to simply reload and use the module. I would also need the hyperparameter file to make sure my module is loaded in the same way.
Can you give an example of your model, or of the reloading behaviour you'd find useful?
I would need to have the hyperparam file as well to make sure that my module is loaded in the same way.
Why would you need that file? PL saves all hyperparameters (with a few exceptions) to the checkpoint. As long as the source code matches, you can load the model with Model.load_from_checkpoint(...).
If you need to store extra data, just override the on_save_checkpoint and on_load_checkpoint hooks. docs
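Roughly, both hooks receive the checkpoint dict, so you can stash extra state in it on save and read it back on load. Here is a toy stand-in (not Lightning itself; MyModule and its vocab field are invented for illustration) showing the mechanism:

```python
class MyModule:
    """Toy stand-in for a LightningModule: the real on_save_checkpoint /
    on_load_checkpoint hooks have these signatures and receive the
    checkpoint dict, so extra state can travel with the checkpoint."""

    def __init__(self):
        # Extra state that would not live in state_dict (invented example).
        self.vocab = {"hello": 0, "world": 1}

    def on_save_checkpoint(self, checkpoint):
        checkpoint["vocab"] = self.vocab

    def on_load_checkpoint(self, checkpoint):
        self.vocab = checkpoint["vocab"]


# Roughly what the Trainer does around these hooks:
model = MyModule()
checkpoint = {"state_dict": {}}       # Lightning builds this dict...
model.on_save_checkpoint(checkpoint)  # ...then lets your hook add to it

restored = MyModule()
restored.vocab = {}  # pretend it starts empty
restored.on_load_checkpoint(checkpoint)
print(restored.vocab)
```

In real code the dict also carries the weights and is serialized to disk; the hooks just give you a place to piggyback extra entries.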
This is the __init__ part of my model. It reads the hparams and then, depending on the config I have set, chooses either a model from the torchvision library or a model I have defined locally. The hparams can also change the activation function used, so I need to construct the module with the same params before I can call load_from_checkpoint. It would be handy if there were a save function that stored everything I need, so I could just load instead of first declaring the class with the right hparams.
import torch.nn as nn
from pytorch_lightning import LightningModule
from torchvision import models
# `local_models`, `Swish`, and `Mish` are defined elsewhere in the project.

class LightningModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # Collect the names of every callable model constructor in
        # torchvision.models and in the local model package.
        model_names = sorted(
            name for name in models.__dict__
            if name.islower() and not name.startswith("__")
            and callable(models.__dict__[name]))
        local_model_names = sorted(
            name for name in local_models.__dict__
            if name.islower() and not name.startswith("__")
            and callable(local_models.__dict__[name]))
        valid_models = model_names + local_model_names

        self.hparams = hparams

        # Pick the activation function from the hyperparameters.
        if self.hparams.act_func == 'swish':
            self.act_funct = Swish()
        elif self.hparams.act_func == 'mish':
            self.act_funct = Mish()
        elif self.hparams.act_func == 'relu':
            self.act_funct = nn.ReLU(inplace=True)

        # Instantiate the model.
        print("=> creating new model '{}'".format(self.hparams.model))
        if self.hparams.model in model_names:
            cv_model = models.__dict__[self.hparams.model](
                pretrained=False,
                num_classes=self.hparams.num_classes)
        elif self.hparams.model in local_model_names:
            cv_model = local_models.__dict__[self.hparams.model](
                pretrained=False,
                activation=self.act_funct,
                num_classes=self.hparams.num_classes)
        if self.hparams.model == 'inception_v3':
            cv_model.aux_logits = False
        self.model = cv_model
It would be handy if there were a save function that stored everything I need, so I could just load instead of first declaring the class with the right hparams.
That's exactly what Lightning is doing for you.
See https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html#lightningmodule-hyperparameters
If you save your hyperparameters this way, then you can load the model with Model.load_from_checkpoint(...).
You don't have to init it first.
Of course the model definition cannot change between saving and loading.
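The mechanism can be sketched in plain Python (TinyModule and its fields are invented for illustration; the real API is Lightning's save_hyperparameters plus the load_from_checkpoint classmethod):

```python
import os
import pickle
import tempfile

class TinyModule:
    """Toy illustration (not Lightning itself) of why no pre-built
    instance is needed: the checkpoint carries the hparams, and a
    classmethod feeds them straight back into __init__."""

    def __init__(self, hparams):
        # Analogous to calling self.save_hyperparameters() in Lightning.
        self.hparams = dict(hparams)

    def save_checkpoint(self, path):
        # Lightning stores the hparams in the checkpoint next to the weights.
        with open(path, "wb") as f:
            pickle.dump({"hparams": self.hparams, "state_dict": {}}, f)

    @classmethod
    def load_from_checkpoint(cls, path):
        # Called on the class, not an instance: it reconstructs the
        # instance from the hparams found inside the checkpoint.
        with open(path, "rb") as f:
            ckpt = pickle.load(f)
        return cls(ckpt["hparams"])


path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")
TinyModule({"model": "resnet18", "act_func": "relu"}).save_checkpoint(path)
restored = TinyModule.load_from_checkpoint(path)
print(restored.hparams["model"])
```

This is why the source code must match between saving and loading: the classmethod re-runs your __init__ with the saved hparams, so any structural change breaks the round trip.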
Ah okay, so when I reload the checkpoint it will populate the hparams as well?
Yes, that's what Lightning is trying to simplify for you: it will simply call your __init__ with the hparams saved in the checkpoint. You should use the latest PL version for this, since the feature is relatively new and has had some bug fixes recently. Try it and see :)
What is the correct way to initialise it then? The module definition is:

class LightningModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()

so I need hparams to initialise the model; model = LightningModel() won't work. Don't I need the model instance before I can call load_from_checkpoint?
No, read the docs carefully: load_from_checkpoint is a class method.
model = LightningModel.load_from_checkpoint(...) # see the class here, not the instance
print(model.hparams) # there it is
Ah gotcha, cool. It seems to work in my quick test, all good!