Pytorch-lightning: load_from_checkpoint raises TypeError when **kwargs is provided

Created on 15 Jul 2020  ·  7 comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug

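The issue occurs when .save_hyperparameters() is not called in __init__: passing a constructor argument to load_from_checkpoint as a keyword then raises a TypeError.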

To Reproduce

Steps to reproduce the behavior: run the code sample below.

Code sample

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule

class LitModel(LightningModule):

    def __init__(self, some_stuff):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader


model = LitModel(42)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model)

trainer.save_checkpoint('checkpoint.pl')
model = LitModel.load_from_checkpoint('checkpoint.pl', some_stuff=42)

# TypeError: __init__() got multiple values for argument 'some_stuff'
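One plausible reading of this error, which nobody in the thread confirms: if the loader passes something positionally as the first constructor argument (for example an empty dict standing in for hyperparameters that were never saved) while the caller also passes some_stuff by keyword, Python binds some_stuff twice. The plain-Python sketch below reproduces that failure without Lightning; `LitModelStub` and `naive_load` are hypothetical names, not Lightning APIs.

```python
class LitModelStub:
    """Hypothetical stand-in with the same __init__ signature as LitModel."""

    def __init__(self, some_stuff):
        self.some_stuff = some_stuff


def naive_load(cls, saved_hparams, *args, **kwargs):
    # Assumption (not confirmed in this thread): the loader prepends whatever
    # hyperparameter object it found in the checkpoint, here an empty dict,
    # as the first positional argument.
    return cls(saved_hparams, *args, **kwargs)


def reproduce():
    try:
        naive_load(LitModelStub, {}, some_stuff=42)
    except TypeError as err:
        return str(err)
    return None


message = reproduce()
# The message ends with "got multiple values for argument 'some_stuff'";
# the prefix (with or without the class name) depends on the Python version.
print(message)
```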

Expected behavior

The model loads from the checkpoint without errors.

Environment

  • PyTorch Version (e.g., 1.0): 1.5.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.6
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • Any other relevant information:

Additional context

I noticed that loading starts working once you call save_hyperparameters(). It also works if save_hyperparameters() is called with a subset of arguments that excludes some_stuff, e.g. if you change __init__ to:

    def __init__(self, some_stuff, lr=42):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.lr = lr
        self.save_hyperparameters('lr')
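To illustrate why recording the hyperparameters could make the keyword override harmless, here is a heavily simplified, hypothetical model of the save/load round trip. `MiniModule`, `dump_checkpoint`, `LitModelSketch` and the `"hparams"` checkpoint key are invented for this sketch and are not Lightning's actual implementation. The assumption is that recorded hyperparameters are replayed as keyword arguments, so a keyword passed at load time replaces the saved value instead of colliding with a positional one.

```python
import inspect
from typing import Any, Dict


class MiniModule:
    """Hypothetical, heavily simplified stand-in for a LightningModule."""

    def save_hyperparameters(self) -> None:
        # Sketch only: read the named arguments of the subclass __init__
        # that called this method from its stack frame.
        frame = inspect.currentframe().f_back
        try:
            names = [
                p.name
                for p in inspect.signature(type(self)).parameters.values()
                if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
            ]
            self.hparams = {name: frame.f_locals[name] for name in names}
        finally:
            del frame  # avoid keeping the frame alive via a reference cycle

    def dump_checkpoint(self) -> Dict[str, Any]:
        return {"hparams": dict(getattr(self, "hparams", {}))}

    @classmethod
    def load_from_checkpoint(cls, checkpoint: Dict[str, Any], **overrides: Any):
        # Saved hyperparameters are replayed by keyword and explicit overrides
        # win, so no argument is ever supplied twice.
        kwargs = {**checkpoint.get("hparams", {}), **overrides}
        return cls(**kwargs)


class LitModelSketch(MiniModule):
    def __init__(self, some_stuff, lr=1e-3):
        self.some_stuff = some_stuff
        self.lr = lr
        self.save_hyperparameters()


model = LitModelSketch(42)
ckpt = model.dump_checkpoint()
restored = LitModelSketch.load_from_checkpoint(ckpt, some_stuff=7)
print(restored.some_stuff, restored.lr)  # 7 0.001
```

The key difference from the failing case is that nothing is passed positionally, so the override and the saved value are merged in a dict before the constructor is called.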
Labels: bug / fix, help wanted

All 7 comments

Hi! Thanks for your contribution, great first issue!

OK, it seems a bit tricky to interpret the method arguments correctly and still keep backward compatibility...
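A hypothetical sketch of one way a loader could interpret the arguments more defensively: keep the legacy behaviour (saved hyperparameters passed as the first positional argument) only when that call would actually bind to the constructor, which can be checked up front with `inspect.Signature.bind`. `resolve_init_args`, `LitModelStub` and `LegacyModel` are illustrative names, not part of Lightning, and this is not the fix the maintainers adopted.

```python
import inspect
from typing import Any, Dict, Tuple


def resolve_init_args(cls, saved_hparams: Dict[str, Any],
                      args: Tuple[Any, ...], kwargs: Dict[str, Any]):
    """Hypothetical helper: decide whether to prepend saved hparams."""
    sig = inspect.signature(cls)  # signature of cls.__init__ without self
    legacy_args = (saved_hparams,) + tuple(args)
    try:
        # Raises TypeError, e.g. "multiple values for argument 'some_stuff'",
        # if the legacy call would not bind.
        sig.bind(*legacy_args, **kwargs)
        return legacy_args, kwargs
    except TypeError:
        # Fall back to exactly what the caller passed.
        return tuple(args), kwargs


class LitModelStub:
    def __init__(self, some_stuff):
        self.some_stuff = some_stuff


class LegacyModel:
    def __init__(self, hparams):
        self.hparams = hparams


args, kwargs = resolve_init_args(LitModelStub, {}, (), {"some_stuff": 42})
new_style = LitModelStub(*args, **kwargs)

args, kwargs = resolve_init_args(LegacyModel, {"lr": 0.1}, (), {})
legacy = LegacyModel(*args, **kwargs)

print(new_style.some_stuff, legacy.hparams)  # 42 {'lr': 0.1}
```

Binding alone cannot remove every ambiguity: for LitModelStub loaded without any keyword, the legacy call still binds and silently passes the saved dict as some_stuff, which is the kind of back-compatibility trade-off this comment refers to.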

Maybe let's wait until the next major version to fix it, but document the behaviour.
It's not a bug if it is documented, right? 😆

And in my (reasonably complex) case it was resolved when I added .save_hyperparameters() at the end of __init__. No other code changes were needed.

Similar #2550?

Yes, exactly the same issue

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@Borda, please check whether the issue persists on master.
