Pytorch-lightning: load_from_checkpoint: TypeError: __init__() missing 1 required positional argument

Created on 11 Aug 2020 · 12 comments · Source: PyTorchLightning/pytorch-lightning

โ“ Questions and Help

What is your question?

load_from_checkpoint: TypeError: __init__() missing 1 required positional argument

I have read the related issues, but the difference in my case is that my LightningModule inherits from a self-defined LightningModule.

How can I solve this problem, or what best practice would better suit my needs?

Code

To reproduce the error:

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from argparse import Namespace

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class LitModel(_LitModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--classes', type=int, default=10)
parser.add_argument('--checkpoint', type=str, default=None)
hparams = parser.parse_args()

mnist_train = MNIST(os.getcwd(), train=True, download=True,
                    transform=transforms.ToTensor())
mnist_train = DataLoader(mnist_train, num_workers=1)
mnist_val = MNIST(os.getcwd(), train=False, download=False,
                  transform=transforms.ToTensor())
mnist_val = DataLoader(mnist_val, num_workers=1)

# A bit contrived, but it shows that `load_from_checkpoint` will fail.
if hparams.checkpoint is None:
    model = LitModel(hparams)
else:
    model = LitModel.load_from_checkpoint(hparams.checkpoint)

trainer = Trainer(max_epochs=2, limit_train_batches=2,
                  limit_val_batches=2, progress_bar_refresh_rate=0)
trainer.fit(model, mnist_train, mnist_val)

Error msg

Traceback (most recent call last):
  File "main.py", line 64, in <module>
    model = LitModel.load_from_checkpoint(hparams.checkpoint)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 138, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 174, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "main.py", line 46, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'

How to run to get the error

$ python3 main.py 
$ python3 main.py --checkpoint lightning_logs/version_0/checkpoints/epoch\=1.ckpt

What's your environment?

  • OS: Linux
  • Packaging: pip
  • Version: 0.9.0rc12
Labels: bug / fix, question

Most helpful comment

The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?

(also track_grad_norm doesn't work in 0.9.0, so I have to switch to 1.0.5...)

All 12 comments

Did you try calling self.save_hyperparameters() in _LitModel?
It looks like hparams were not saved to the checkpoint.
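
A minimal sketch of that suggestion applied to the script above (assuming a Lightning version where save_hyperparameters accepts a Namespace):

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        # Record the init arguments in the checkpoint so that
        # load_from_checkpoint() can rebuild the model later.
        self.save_hyperparameters(hparams)
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)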

@awaelchli
Hi, the result is the same.
It works if I use _LitModel directly instead of LitModel, so I think it's something about the inheritance.

https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html

Anything assigned to self.hparams will also be saved automatically.

@siahuat0727 I can confirm this is a bug. I fixed it and reduced your example to a minimal test case, so it won't break in the future. Thanks for providing an easy-to-reproduce script!

Great job. Thanks!!

The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?

(also track_grad_norm doesn't work in 0.9.0, so I have to switch to 1.0.5...)

Bump

bump for version 1.0.6 as well

same problem here on 1.0.4

Apparently the problem is that checkpoint['hparams_name'] is empty. Maybe the problem is in how the module is saved when it is inherited?
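
You can inspect what actually landed in the checkpoint with plain torch.load (a quick sketch; 'hyper_parameters' is the key Lightning writes the saved init arguments under, and the path is only an example):

import torch

ckpt = torch.load('lightning_logs/version_0/checkpoints/epoch=1.ckpt',
                  map_location='cpu')
# If 'hparams_name' / 'hyper_parameters' are missing or empty,
# load_from_checkpoint has nothing to reconstruct __init__ with.
print(ckpt.get('hparams_name'), ckpt.get('hyper_parameters'))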

What solved it for me: instead of passing hparams, you can pass them as kwargs. So in your class, use:

class my_pl_module(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Saves all init args/kwargs into the checkpoint and exposes
        # them through self.hparams.
        self.save_hyperparameters()

It's still a bug, though, because the hparams method is not yet deprecated.

@stathius Yes, the "old" hparams method is not yet deprecated, but it has conceptual flaws in terms of typing that cannot be fixed with a simple bugfix. The solution we came up with is to decouple two things:

  1. saving hyperparameters into the checkpoint, and
  2. making hyperparameters accessible through a convenient self.hparams "namespace".

The code you posted does exactly that, and it is the recommended way today.
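
Putting the two pieces together, the recommended pattern looks roughly like this (a sketch following the docs; the argument names are illustrative):

class LitModel(pl.LightningModule):
    def __init__(self, classes: int = 10, lr: float = 0.001):
        super().__init__()
        # (1) saves classes/lr into the checkpoint and
        # (2) exposes them as self.hparams.classes / self.hparams.lr.
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# No arguments need to be repeated when loading; they are read back
# from the checkpoint:
# model = LitModel.load_from_checkpoint('path/to/checkpoint.ckpt')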

