The issue happens when `.save_hyperparameters()` is not used.
Steps to reproduce the behavior: run the following code sample.
```python
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule


class LitModel(LightningModule):
    # Note: save_hyperparameters() is intentionally NOT called in __init__.
    def __init__(self, some_stuff):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader


model = LitModel(42)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint('checkpoint.pl')

# Raises even though some_stuff is passed explicitly as a keyword argument:
model = LitModel.load_from_checkpoint('checkpoint.pl', some_stuff=42)
# TypeError: __init__() got multiple values for argument 'some_stuff'
```
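The `TypeError` itself is ordinary Python: it is raised whenever one parameter receives both a positional and a keyword value. The sketch below reproduces it without Lightning, assuming (this is not confirmed anywhere in the thread) that the loader forwards some saved hyperparameter container positionally when `save_hyperparameters()` was never called. `LitModelStub` and `load_positionally` are hypothetical names used only for illustration.

```python
class LitModelStub:
    """Hypothetical stand-in for LitModel: same constructor signature, no Lightning needed."""

    def __init__(self, some_stuff):
        self.some_stuff = some_stuff


def load_positionally(cls, saved_hparams, **kwargs):
    """Hypothetical loader that passes the saved container as the first positional argument."""
    # saved_hparams binds to `some_stuff` positionally, and kwargs supplies `some_stuff` again.
    return cls(saved_hparams, **kwargs)


message = None
try:
    load_positionally(LitModelStub, {}, some_stuff=42)
except TypeError as err:
    message = str(err)

print(message)  # message ends with: got multiple values for argument 'some_stuff'
```

The exact prefix of the message depends on the Python version (newer versions include the class name), but the collision is the same one reported above.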
Expected behavior: the model loads successfully.
Installation method (conda, pip, source): pip

I noticed that loading starts working if you call `save_hyperparameters()`. It also works correctly if `some_stuff` is not passed to `save_hyperparameters`, e.g. if you change `__init__` to this:
```python
def __init__(self, some_stuff, lr=42):
    super().__init__()
    self.l1 = torch.nn.Linear(28 * 28, 10)
    self.lr = lr
    # Only 'lr' is recorded as a hyperparameter; 'some_stuff' is not.
    self.save_hyperparameters('lr')
```
Hi! Thanks for your contribution! Great first issue!
OK, it seems a bit tricky to interpret method arguments correctly while keeping backward compatibility...
Maybe let's wait until the next major version to fix it, but document the behaviour.
It's not a bug if it is documented, right? 😆
And in my (reasonably complex) case it was resolved when I added `.save_hyperparameters()` at the end of `__init__`. No other code changes were needed.
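On the question of how to interpret method arguments: one hypothetical approach is to match saved values to `__init__` parameters by name, so that a caller-supplied `some_stuff=42` can never collide with a positional argument. This is an illustrative sketch, not Lightning's implementation; `load_by_name` and `LitModelStub` are made-up names.

```python
import inspect


class LitModelStub:
    """Hypothetical stand-in for the workaround LitModel: only the constructor matters here."""

    def __init__(self, some_stuff, lr=42):
        self.some_stuff = some_stuff
        self.lr = lr


def load_by_name(cls, saved_hparams, **overrides):
    """Hypothetical loader: map saved hyperparameters onto __init__ parameters by name."""
    params = inspect.signature(cls.__init__).parameters
    # Keep only saved values whose key matches a named constructor parameter.
    kwargs = {k: v for k, v in saved_hparams.items() if k in params}
    # Keyword arguments supplied by the caller take precedence over saved values.
    kwargs.update(overrides)
    return cls(**kwargs)


# Only 'lr' was recorded (as with save_hyperparameters('lr')); 'some_stuff' comes from the caller.
model = load_by_name(LitModelStub, {"lr": 0.01}, some_stuff=42)
print(model.some_stuff, model.lr)  # prints "42 0.01"
```

Because everything is forwarded by keyword, an empty or partial set of saved values still loads, and explicit overrides win instead of raising.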
Similar #2550?
Yes, exactly the same issue
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
@Borda please check if issue persists on master