PyTorch Lightning: Relax hparams in model saving/loading

Created on 21 Feb 2020 · 8 Comments · Source: PyTorchLightning/pytorch-lightning

I've managed to train a model using trainer.fit(model) and have the .ckpt file. Now I'm trying to load the .ckpt file so that I can run inference on a single image:

model = CoolSystem()
to_infer = torch.load('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.load_from_checkpoint(to_infer) # ------------- error is thrown at this line

However, upon loading the .ckpt file, the following error is thrown:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

Am I doing something wrong when using PyTorch Lightning for inference?

For reference, this is my system:

import pytorch_lightning as pl

import os
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()

        # self.hparams = hparams
        self.data_dir = '/content/hymenoptera_data'

        self.model = torchvision.models.resnet18(pretrained=True) # final layer is of size [bs, 1000]
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, 2) # change final layer to be of size [bs, 2]

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        # Observe that all parameters are being optimized
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

        # Decay LR by a factor of 0.1 every 7 epochs
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        return [optimizer], [exp_lr_scheduler]

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED

        transform = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ])

        train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

        return train_loader

    @pl.data_loader
    def val_dataloader(self):
        transform = transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ])

        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4) # no need to shuffle validation data

        return val_loader

And I'm training it this way:

model = CoolSystem() 

import os

checkpoint_callback = pl.callbacks.ModelCheckpoint(
          filepath=os.path.join(os.getcwd(), 'checkpoints'),
          verbose=True,
          monitor='val_loss', 
          mode='min', 
          prefix='try',
          save_top_k=-1,
          period=1 # interval (in epochs) between checks of the monitored val_loss; with save_top_k=-1, every checkpoint is kept
      )

trainer = pl.Trainer(
      max_epochs=2,
      checkpoint_callback=checkpoint_callback)  

trainer.fit(model)
Label: question

All 8 comments

Hey, thanks for your contribution! Great first issue!

I have not tested it, but I think it should be
model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
The method takes a path string and calls torch.load internally, which is why passing an already-loaded dict raises the seek error.
See docs:
https://pytorch-lightning.readthedocs.io/en/0.6.0/pytorch_lightning.core.html#pytorch_lightning.core.LightningModule.load_from_checkpoint
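
Untested, but a minimal sketch of that usage (note that load_from_checkpoint is a classmethod, so call it on the class; image_tensor below is a hypothetical, already-preprocessed input):

model = CoolSystem.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.eval()  # switch layers like dropout/batchnorm to inference mode

with torch.no_grad():
    # image_tensor is a hypothetical, already-normalized [3, 224, 224] tensor
    logits = model(image_tensor.unsqueeze(0))  # add batch dim -> [1, 3, 224, 224]
    pred = logits.argmax(dim=1)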

After trying model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt'), the following error is now thrown:

OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters stored in self.hparams?

I built CoolSystem() without self.hparams, as per the example Colab notebook (https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)

Any advice on this?

I guess it should be added to that example. The GAN example has it.
Add the hparams to the __init__, train it, and then try to load again.
Looks like it is always needed, even if you don't pass any hparams in.
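
Untested sketch of what adding hparams could look like (argparse.Namespace is just one possible container, mirroring the GAN example; the values below are hypothetical):

from argparse import Namespace

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()
        self.hparams = hparams  # checkpoint loading currently expects this attribute
        # ... rest of __init__ unchanged ...

hparams = Namespace(lr=0.001, momentum=0.9, batch_size=32)
model = CoolSystem(hparams)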

Got it! Will take note to always add hparams to __init__ then

@awaelchli mind submitting a PR to fix?

I think the point was for hparams to be optional? Or should we make it more flexible? @neggert

I can look at it.
To make it optional, I guess we could simply change the loading behaviour depending on whether the user has defined hparams or not.
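Roughly like this, as a hypothetical sketch of the loading side (not the actual implementation):

from argparse import Namespace

# inside a classmethod, where cls is the LightningModule subclass
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
if 'hparams' in checkpoint:
    model = cls(Namespace(**checkpoint['hparams']))
else:
    model = cls()  # model was defined without hparams
model.load_state_dict(checkpoint['state_dict'])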

I will hold back until #849 is finalized because it affects the ModelCheckpoint callback.

