load_from_checkpoint: TypeError: __init__() missing 1 required positional argument
I have read the related issues, but the difference in my case is that my LightningModule inherits from another self-defined LightningModule.
How can I solve this problem, or what is the best practice for this use case?
To reproduce the error:
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from argparse import Namespace


class _LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


class LitModel(_LitModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--classes', type=int, default=10)
parser.add_argument('--checkpoint', type=str, default=None)
hparams = parser.parse_args()

mnist_train = MNIST(os.getcwd(), train=True, download=True,
                    transform=transforms.ToTensor())
mnist_train = DataLoader(mnist_train, num_workers=1)
mnist_val = MNIST(os.getcwd(), train=False, download=False,
                  transform=transforms.ToTensor())
mnist_val = DataLoader(mnist_val, num_workers=1)

# A bit weird here. I just want to show `load_from_checkpoint` will fail.
if hparams.checkpoint is None:
    model = LitModel(hparams)
else:
    model = LitModel.load_from_checkpoint(hparams.checkpoint)

trainer = Trainer(max_epochs=2, limit_train_batches=2,
                  limit_val_batches=2, progress_bar_refresh_rate=0)
trainer.fit(model, mnist_train, mnist_val)
The first command trains and saves a checkpoint; the second command then fails when loading it:

$ python3 main.py
$ python3 main.py --checkpoint lightning_logs/version_0/checkpoints/epoch\=1.ckpt

Traceback (most recent call last):
  File "main.py", line 64, in <module>
    model = LitModel.load_from_checkpoint(hparams.checkpoint)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 138, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 174, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "main.py", line 46, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'
Did you try calling self.save_hyperparameters() in _LitModel?
It looks like the hparams were not saved to the checkpoint.
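For example, roughly like this in the example above (a sketch only, to show where the call would go; save_hyperparameters() records the __init__ arguments so they end up in the checkpoint):

class _LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        # record the init arguments so they are written into the checkpoint
        # and can be restored later by load_from_checkpoint
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)
    # ... rest of the module unchanged ...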
@awaelchli
Hi, the result is the same.
It works if I use _LitModel directly instead of LitModel, so I think it's something related to the inheritance.
https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html
Anything assigned to self.hparams will also be saved automatically.
@siahuat0727 I can confirm this is a bug. I fixed it and reduced your example to a minimal test case, so it won't break in the future. Thanks for providing an easy-to-reproduce script!
Great job. Thanks!!
The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?
(also track_grad_norm doesn't work in 0.9.0, so I have to switch to 1.0.5...)
Bump
bump for version 1.0.6 as well
same problem here on 1.0.4
Apparently the problem is that checkpoint['hparams_name'] is empty. Maybe the issue is in how the module is saved when it is subclassed?
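A quick way to check is to load the checkpoint dict with torch.load and inspect the saved keys (a sketch; the path is just the one from the example above):

import torch

ckpt = torch.load('lightning_logs/version_0/checkpoints/epoch=1.ckpt',
                  map_location='cpu')
print(list(ckpt.keys()))              # e.g. epoch, state_dict, hparams_name, ...
print(ckpt.get('hparams_name'))       # empty when the hparams were not recorded
print(ckpt.get('hyper_parameters'))   # the saved hyperparameters, if any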
What solved it for me: instead of passing the hparams explicitly, pass them as kwargs and call save_hyperparameters(). So in your class use:

class my_pl_module(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
Still a bug though, because the hparams method is not yet deprecated.
@stathius yes, the "old" hparams method is not yet deprecated, but it has conceptual flaws in terms of typing that cannot be fixed with a simple bugfix. The solution we came up with here is to decouple the two concerns.
And the code you posted does exactly that; this is the recommended way today.
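For completeness, a minimal sketch of that recommended pattern (hypothetical class name and checkpoint path; explicit keyword arguments plus save_hyperparameters()):

import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self, classes: int = 10, lr: float = 1e-3):
        super().__init__()
        # explicit init arguments are captured into self.hparams and the checkpoint
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)
    # ... forward/training_step as usual ...

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# later, the saved arguments are restored automatically:
model = LitClassifier.load_from_checkpoint('path/to/epoch=1.ckpt')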