PyTorch Lightning: "__init__() got multiple values for argument" when calling load_from_checkpoint

Created on 8 Jul 2020 · 20 comments · Source: PyTorchLightning/pytorch-lightning

What is your question?

I followed the "Saving and loading weights" docs but ran into a problem.
When I load the model from a checkpoint, I get the error __init__() got multiple values for argument 'net'.

Code

Here is the code that reproduces the problem (also shared as a Colab notebook):

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, h_dim):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, h_dim)
        self.l2 = torch.nn.Linear(h_dim, 10)

    def forward(self, x):
        x = torch.relu(self.l1(x.view(x.size(0), -1)))
        x = torch.relu(self.l2(x.view(x.size(0), -1)))
        return x

class MNISTModel(pl.LightningModule):

    def __init__(self, net, h_dim):
        super(MNISTModel, self).__init__()
        # self.save_hyperparameters()  # disabled: hyperparameters are not saved to the checkpoint
        if net == 'cnn':
            self.net = CNN(h_dim)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
mnist_model = MNISTModel('cnn', h_dim=200)
trainer = pl.Trainer(gpus=1, progress_bar_refresh_rate=20, max_epochs=3)    
trainer.fit(mnist_model, train_loader)  
trainer.save_checkpoint('mnist.ckpt')

When I try to load the model from mnist.ckpt, I get an error:

model = MNISTModel.load_from_checkpoint('mnist.ckpt', net='cnn', h_dim=200)
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
    167         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    168 
--> 169         model = cls._load_model_state(checkpoint, *args, **kwargs)
    170         return model
    171 

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
    203         if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
    204             cls_args, cls_kwargs = [], {}
--> 205         model = cls(*cls_args, **cls_kwargs)
    206         # load the state_dict on the model automatically
    207         model.load_state_dict(checkpoint['state_dict'])

TypeError: __init__() got multiple values for argument 'net'

What's your environment?

pytorch-lightning: 0.8.4
Colab environment with GPU

All 20 comments

Hi! Thanks for your contribution, great first issue!

I have checked it on different versions:

  • 0.7.6: Ok
  • 0.8.0: Ok
  • 0.8.1 (and higher): TypeError

Could this behaviour have been introduced in 0.8.1?

Same issue here! Any workaround? I tried installing 0.8.0, but pip says it can't find that version.

The problem is here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/a34609ef0ef3c5c6fac5dd30520ba45dbe0c9159/pytorch_lightning/core/saving.py#L196
Since args_name is None, model_args remains a dict and is passed to the constructor as a positional argument. Changing that line to cls_kwargs.update(**model_args) works.
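To make the mechanism concrete, here is a minimal pure-Python simulation. It is not Lightning's actual code: `load_buggy`, `load_fixed` and `MNISTModelStub` are hypothetical names, and the stub is a torch-free stand-in with the same `__init__(self, net, h_dim)` signature as MNISTModel. It assumes the simplified flow visible in the traceback above: the user's kwargs are merged into the hyperparameter dict (line 167), and that dict is then either prepended positionally or, per the suggested change, merged into the keyword arguments.

```python
class MNISTModelStub:
    """Hypothetical torch-free stand-in with the same signature as MNISTModel."""

    def __init__(self, net, h_dim):
        self.net = net
        self.h_dim = h_dim


def load_buggy(cls, saved_hparams, *cls_args, **cls_kwargs):
    # Assumed simplification of the 0.8.x flow: user kwargs are merged
    # into the hyperparameter dict read from the checkpoint...
    model_args = dict(saved_hparams)
    model_args.update(cls_kwargs)
    # ...and, because no hparams name (args_name) was saved, the whole
    # dict is prepended as a positional argument.
    cls_args = (model_args,) + cls_args
    return cls(*cls_args, **cls_kwargs)


def load_fixed(cls, saved_hparams, *cls_args, **cls_kwargs):
    model_args = dict(saved_hparams)
    model_args.update(cls_kwargs)
    # Suggested change: merge the dict into the keyword arguments instead.
    cls_kwargs.update(model_args)
    return cls(*cls_args, **cls_kwargs)


# No save_hyperparameters() call, so the checkpoint holds no hparams.
try:
    load_buggy(MNISTModelStub, {}, net="cnn", h_dim=200)
except TypeError as err:
    print("buggy:", err)  # ... got multiple values for argument 'net'

model = load_fixed(MNISTModelStub, {}, net="cnn", h_dim=200)
print("fixed:", model.net, model.h_dim)  # fixed: cnn 200
```

In the buggy path the dict is bound positionally to `net`, and `net` is also passed by keyword, which is exactly the reported TypeError.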

@rohitgr7 can we work around this somehow?

@rohitgr7 try this: Lightning.load_from_checkpoint(path, model=model, kwargs=dict(param_name=param_value))

@rohitgr7 yes, it's just a workaround at the moment

I hit the same bug described in this issue.

pytorch_lightning.core.saving.ModelIO.load_from_checkpoint() accepts extra arguments that are forwarded to the pl.LightningModule constructor.
If the module's __init__ takes multiple arguments, the method doesn't work correctly.

Snippet without passing constructor arguments to the pl.LightningModule:

import pytorch_lightning as pl

class sample(pl.LightningModule):

    def __init__(self, one, two, three):
        super(sample, self).__init__()

    def forward(self, input, **kwargs):
        return 0

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return 0

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        return 0


if __name__ == "__main__":
    model = sample.load_from_checkpoint(checkpoint_path=path)  # please set `path` properly
    print(model)

Then I got the following error:

Traceback (most recent call last):
  File ".\min.py", line 20, in <module>
    model = sample.load_from_checkpoint(checkpoint_path="...")
  File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 169, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 205, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
TypeError: __init__() missing 2 required positional arguments: 'two' and 'three'

Snippet passing constructor arguments to the pl.LightningModule:

import pytorch_lightning as pl

class sample(pl.LightningModule):

    def __init__(self, one, two, three):
        super(sample, self).__init__()

    def forward(self, input, **kwargs):
        return 0

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return 0

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        return 0


if __name__ == "__main__":
    model = sample.load_from_checkpoint(checkpoint_path=path, one=1, two=2, three=3) # please set `path` properly
    print(model)

Then I got the following error:

Traceback (most recent call last):
  File ".\min.py", line 20, in <module>
    model = sample.load_from_checkpoint(checkpoint_path="...",
  File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 169, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 205, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
TypeError: __init__() got multiple values for argument 'one'

Environment

  • PyTorchLightning Version: 0.8.5
  • PyTorch Version: 1.5.1
  • OS: Windows
  • How you installed PyTorch: pip
  • Python version: 3.8.3
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: RTX 2080 Ti

Could you fix it?
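Both tracebacks in this comment can be reproduced without Lightning. The sketch below is a hypothetical pure-Python simulation, not Lightning's source: `Sample` is a torch-free stand-in for the `sample` module, and `load_buggy` assumes the simplified flow discussed in this thread (kwargs merged into the possibly empty hparams dict, which is then prepended positionally). With no kwargs, the empty dict is bound to `one`, leaving `two` and `three` missing; with kwargs, `one` is filled both positionally and by keyword.

```python
class Sample:
    """Hypothetical torch-free stand-in for the `sample` module above."""

    def __init__(self, one, two, three):
        self.values = (one, two, three)


def load_buggy(cls, saved_hparams, *cls_args, **cls_kwargs):
    # Assumed simplification: kwargs are merged into the hparams dict,
    # and the dict is then prepended as a positional argument.
    model_args = dict(saved_hparams)
    model_args.update(cls_kwargs)
    cls_args = (model_args,) + cls_args
    return cls(*cls_args, **cls_kwargs)


# First call: no kwargs -> the empty dict fills `one`; `two` and `three` are missing.
# Second call: kwargs -> `one` gets the dict positionally and 1 by keyword.
for kwargs in ({}, {"one": 1, "two": 2, "three": 3}):
    try:
        load_buggy(Sample, {}, **kwargs)
    except TypeError as err:
        print(err)
```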

Oh sorry, I missed the comments.
Thank you! I'll try it!

@Borda It seems there may be an issue here. In the example above, MNISTModel is trained without self.save_hyperparameters() in __init__. When the user then wants to load the model weights from a checkpoint, the call MNISTModel.load_from_checkpoint('mnist.ckpt', net='cnn', h_dim=200) fails. As suggested by @rohitgr7, the issue seems to be in _load_model_state. In this example, at line 196 of saving.py, cls_args equals ({'net': 'cnn', 'h_dim': 200},) while cls_kwargs also equals {'net': 'cnn', 'h_dim': 200}. Later in the function, when model = cls(*cls_args, **cls_kwargs) is called, the positional dict is bound to net while net is also passed as a keyword, so __init__ receives multiple values for net. Hence the error!

The important point here is that one should be able to load model weights from a checkpoint even if the model was trained without self.save_hyperparameters(). As a fix for this issue, I propose removing the following line in _load_model_state:

cls_args = (model_args,) + cls_args

as it only seems to duplicate the class arguments. I implemented this fix locally and all tests pass ✅!

Shall I make a PR? 🚀

cc @davidseroussi @invisprints
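A minimal sketch of what the proposed removal implies, again as a hypothetical pure-Python simulation rather than Lightning's code (`load_without_prepend` and `Sample` are illustrative names). With the prepend gone, the constructor receives exactly what the caller passes, so loading works as long as every `__init__` argument is supplied explicitly; under this simplified model, a checkpoint saved without hyperparameters still cannot fill them in on its own.

```python
class Sample:
    """Hypothetical stand-in with a multi-argument __init__."""

    def __init__(self, one, two, three):
        self.values = (one, two, three)


def load_without_prepend(cls, *cls_args, **cls_kwargs):
    # Simulates _load_model_state after deleting
    # `cls_args = (model_args,) + cls_args`: no duplicated arguments.
    return cls(*cls_args, **cls_kwargs)


model = load_without_prepend(Sample, one=1, two=2, three=3)
print(model.values)  # (1, 2, 3)

# Without saved hyperparameters and without explicit kwargs,
# nothing fills the arguments, so this still raises a TypeError.
try:
    load_without_prepend(Sample)
except TypeError as err:
    print(err)
```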

@jbschiratti please ping me in the PR if you make one :)

@GazizovMarat This workaround is unclear to me: what does "Lightning" refer to in Lightning.load_from_checkpoint(...)? Could you please elaborate on how to make this work? Thanks

I tried multiple suggestions found in several issues here; none worked for me.

My LightningModule is as follows:

class VariationalAutoEncoder(pl.LightningModule):
    def __init__(self, encoder: nn.Module, decoder: nn.Module, get_batch_fn: Callable = lambda x: x) -> None:
        ...

Loading from checkpoint:

  • version 1
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH), kwargs=dict(encoder=Encoder(), decoder=Decoder()))

---> TypeError: __init__() missing 1 required positional argument: 'decoder'
  • version 2
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH), encoder=Encoder(), decoder=Decoder())

---> TypeError: __init__() got multiple values for argument 'encoder'
  • version 3
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH))

---> TypeError: __init__() missing 1 required positional argument: 'decoder'

@matthaeusheer I believe this was fixed; would you mind sharing which version of PL you are using?

Sure, I use pytorch-lightning==0.8.5

My current workaround is to load the weights via PyTorch directly:

# `model` must already be constructed with the same __init__ arguments used for training
CHECKPOINT_PATH = DATA_DIR_PATH / 'lightning_logs/train_vae/version_32/checkpoints/epoch=2.ckpt'
checkpoint = torch.load(CHECKPOINT_PATH)
model.load_state_dict(checkpoint['state_dict'])

which works fine for me, but of course I'd like to see the fix within lightning :-)

@matthaeusheer I see; would you mind trying the current master or the latest 0.9rc12?

@Borda The issue is still there: the approaches mentioned by @matthaeusheer don't work with the version from master and produce the same TypeErrors.

@invisprints just tested that #2776 fixes your issue...

Glad to hear that. But I see it hasn't been merged into master yet.
