I followed the docs on SAVING AND LOADING WEIGHTS but ran into a problem.
When I load the model from a checkpoint, I get the error __init__() got multiple values for argument 'net'.
Here is the Colab version of the problem code:
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, h_dim):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, h_dim)
        self.l2 = torch.nn.Linear(h_dim, 10)

    def forward(self, x):
        x = torch.relu(self.l1(x.view(x.size(0), -1)))
        x = torch.relu(self.l2(x.view(x.size(0), -1)))
        return x


class MNISTModel(pl.LightningModule):
    def __init__(self, net, h_dim):
        super(MNISTModel, self).__init__()
        # self.save_hyperparameters()
        if net == 'cnn':
            self.net = CNN(h_dim)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
mnist_model = MNISTModel('cnn', h_dim=200)

trainer = pl.Trainer(gpus=1, progress_bar_refresh_rate=20, max_epochs=3)
trainer.fit(mnist_model, train_loader)
trainer.save_checkpoint('mnist.ckpt')
When I try to load the model from mnist.ckpt, I get an error:
model = MNISTModel.load_from_checkpoint('mnist.ckpt', net='cnn', h_dim=200)
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
167 checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
168
--> 169 model = cls._load_model_state(checkpoint, *args, **kwargs)
170 return model
171
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
203 if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
204 cls_args, cls_kwargs = [], {}
--> 205 model = cls(*cls_args, **cls_kwargs)
206 # load the state_dict on the model automatically
207 model.load_state_dict(checkpoint['state_dict'])
TypeError: __init__() got multiple values for argument 'net'
pytorch-lightning: 0.8.4
Colab environment with GPU
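Until this is resolved, one way to sidestep load_from_checkpoint for the example above is to rebuild the module and load only the weights with plain PyTorch (a minimal sketch; the same idea comes up later in this thread):

model = MNISTModel('cnn', h_dim=200)              # reconstruct with the same constructor arguments
checkpoint = torch.load('mnist.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])   # Lightning stores the weights under 'state_dict'
model.eval()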
Hi! Thanks for your contribution, great first issue!
I have checked it on different versions.
Could this behaviour have been introduced in 0.8.1?
Same issue here! Any workaround? I tried installing 0.8.0, but pip says it can't find that version.
The problem is here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/a34609ef0ef3c5c6fac5dd30520ba45dbe0c9159/pytorch_lightning/core/saving.py#L196
Since args_name is None, model_args will remain a dict. Changing it to cls_kwargs.update(**model_args) works.
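To make the collision concrete, here is a minimal, self-contained illustration; it only mirrors what happens at the cls(*cls_args, **cls_kwargs) call and is not the library code itself:

def fake_init(net, h_dim):  # stands in for MNISTModel.__init__
    return net, h_dim

model_args = {'net': 'cnn', 'h_dim': 200}                # hparams restored from the checkpoint
cls_args, cls_kwargs = (), {'net': 'cnn', 'h_dim': 200}  # what the user passed at load time

# The hparams dict is prepended as a positional argument, binds to `net`, and then
# collides with the keyword argument of the same name.
cls_args = (model_args,) + cls_args
fake_init(*cls_args, **cls_kwargs)
# TypeError: fake_init() got multiple values for argument 'net'

Merging the restored hparams into cls_kwargs instead avoids the extra positional argument, so the constructor only receives keyword arguments.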
@rohitgr7 can we work around this somehow?
@rohitgr7 try this: Lightning.load_from_checkpoint(path, model=model, kwargs=dict(param_name=param_value))
@GazizovMarat this should work, but it should also work simply by passing the args or kwargs as normal params.
https://github.com/PyTorchLightning/pytorch-lightning/blob/7b4db3045dcc9e6bb0b66e409b25bb2c7fa378f0/pytorch_lightning/core/saving.py#L65
https://github.com/PyTorchLightning/pytorch-lightning/blob/7b4db3045dcc9e6bb0b66e409b25bb2c7fa378f0/pytorch_lightning/core/saving.py#L103
@rohitgr7 yes, it's just a workaround at the moment
I hit the same bug described in this issue.
pytorch_lightning.core.saving.ModelIO.load_from_checkpoint() takes args that are used to initialize the pl.LightningModule.
If the module's __init__ has multiple arguments, the method doesn't work correctly.
import pytorch_lightning as pl


class sample(pl.LightningModule):
    def __init__(self, one, two, three):
        super(sample, self).__init__()

    def forward(self, input, **kwargs):
        return 0

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return 0

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        return 0


if __name__ == "__main__":
    model = sample.load_from_checkpoint(checkpoint_path=path)  # please set `path` properly
    print(model)
Then I get the following error:
Traceback (most recent call last):
File ".\min.py", line 20, in <module>
model = sample.load_from_checkpoint(checkpoint_path="...")
File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 169, in load_from_checkpoint
model = cls._load_model_state(checkpoint, *args, **kwargs)
File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 205, in _load_model_state
model = cls(*cls_args, **cls_kwargs)
TypeError: __init__() missing 2 required positional arguments: 'two' and 'three'
import pytorch_lightning as pl


class sample(pl.LightningModule):
    def __init__(self, one, two, three):
        super(sample, self).__init__()

    def forward(self, input, **kwargs):
        return 0

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return 0

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        return 0


if __name__ == "__main__":
    model = sample.load_from_checkpoint(checkpoint_path=path, one=1, two=2, three=3)  # please set `path` properly
    print(model)
Then I get the following error:
Traceback (most recent call last):
File ".\min.py", line 20, in <module>
model = sample.load_from_checkpoint(checkpoint_path="...",
File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 169, in load_from_checkpoint
model = cls._load_model_state(checkpoint, *args, **kwargs)
File "...\lib\site-packages\pytorch_lightning\core\saving.py", line 205, in _load_model_state
model = cls(*cls_args, **cls_kwargs)
TypeError: __init__() got multiple values for argument 'one'
Could you fix it?
Current workaround suggested here: https://github.com/PyTorchLightning/pytorch-lightning/issues/2550#issuecomment-659568926
Oh sorry, I missed the comments.
Thank you! I'll try it!
@Borda It seems there may be an issue here. In the example above, the MNISTModel is trained without self.save_hyperparameters() in __init__. When the user wants to load model weights from a checkpoint, the line MNISTModel.load_from_checkpoint('mnist.ckpt', net='cnn', h_dim=200) fails. As suggested by @rohitgr7, the issue seems to be in _load_model_state. Indeed, in this example, at line 196, cls_args equals ({'net': 'cnn', 'h_dim': 200},) whereas cls_kwargs also equals {'net': 'cnn', 'h_dim': 200}. Later in this function, when model = cls(*cls_args, **cls_kwargs) is called, the class receives multiple values for the net and h_dim parameters. Hence, the error!
The important point here is that one should be able to load model weights from a checkpoint even if the model was trained without self.save_hyperparameters(). As a fix for this issue, I propose to remove the following line in _load_model_state
cls_args = (model_args,) + cls_args
as it seems to only duplicate the class arguments. I implemented this fix locally and all tests are passing ✅!
Shall I make a PR? 🚀
cc @davidseroussi @invisprints
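With that line removed, the reproduction above should behave as expected (a usage sketch, assuming the patched _load_model_state):

# After the fix, passing the constructor arguments explicitly works again,
# even though the model was trained without self.save_hyperparameters():
model = MNISTModel.load_from_checkpoint('mnist.ckpt', net='cnn', h_dim=200)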
@jbschiratti please ping me in the PR if you make one :)
@GazizovMarat This workaround is unclear to me. What is "Lightning" referring to in Lightning.load_from_checkpoint(...)? Can you please elaborate on how to make this work? Thanks
I tried multiple suggestions found in several issues here; none worked for me.
My LightningModule is as follows:
class VariationalAutoEncoder(pl.LightningModule):
    def __init__(self, encoder: nn.Module, decoder: nn.Module, get_batch_fn: Callable = lambda x: x) -> None:
        ...
Loading from checkpoint:
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH), kwargs=dict(encoder=Encoder(), decoder=Decoder()))
---> TypeError: __init__() missing 1 required positional argument: 'decoder'
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH), encoder=Encoder(), decoder=Decoder())
---> TypeError: __init__() got multiple values for argument 'encoder'
VariationalAutoEncoder.load_from_checkpoint(str(CHECKPOINT_PATH))
--> TypeError: __init__() missing 1 required positional argument: 'decoder'
@matthaeusheer I believe this was fixed, mind sharing what version of PL you are using?
Sure, I use pytorch-lightning==0.8.5
My current workaround is to load the model via PyTorch directly:
CHECKPOINT_PATH = DATA_DIR_PATH / 'lightning_logs/train_vae/version_32/checkpoints/epoch=2.ckpt'

# Build the module manually with its constructor arguments, then load only the weights.
model = VariationalAutoEncoder(encoder=Encoder(), decoder=Decoder())
checkpoint = torch.load(CHECKPOINT_PATH)
model.load_state_dict(checkpoint['state_dict'])
which works fine for me, but of course I'd like to see the fix within lightning :-)
@matthaeusheer I see, mind trying actual master or the latest 0.9rc12?
@Borda The issue is still there; the ways mentioned by @matthaeusheer don't work with the version from master, producing the same TypeErrors.
@invisprints just tested that #2776 fixes your issue...
Glad to hear that. But I see it hasn't been merged into master yet.