I've managed to train a model using trainer.fit(model) and have the .ckpt file. Now I'm trying to load the .ckpt file so that I can do inference on a single image:
model = CoolSystem()
to_infer = torch.load('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.load_from_checkpoint(to_infer) # ------------- error is thrown at this line
However, upon loading the .ckpt file, the following error is thrown:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
Am I doing something wrong when using PyTorch Lightning for inference?
For reference, this is my system:
import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # self.hparams = hparams
        self.data_dir = '/content/hymenoptera_data'
        self.model = torchvision.models.resnet18(pretrained=True)  # final layer is of size [bs, 1000]
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, 2)  # change final layer to be of size [bs, 2]

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        # Observe that all parameters are being optimized
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        # Decay LR by a factor of 0.1 every 7 epochs
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return [optimizer], [exp_lr_scheduler]

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
        return train_loader

    @pl.data_loader
    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=4)
        return val_loader
And I'm training it this way:
model = CoolSystem()

import os
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), 'checkpoints'),
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix='try',
    save_top_k=-1,
    period=1  # check val_loss every n epochs and save a checkpoint if it improves on the previous period
)

trainer = pl.Trainer(
    max_epochs=2,
    checkpoint_callback=checkpoint_callback)

trainer.fit(model)
Have not tested it, but I think it should be
model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
(the method takes the checkpoint path as a string, not an already-loaded dict; passing the dict makes load_from_checkpoint call torch.load on it again, which is what raises the 'seek' error).
See docs:
https://pytorch-lightning.readthedocs.io/en/0.6.0/pytorch_lightning.core.html#pytorch_lightning.core.LightningModule.load_from_checkpoint
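For single-image inference, something along these lines should then work (untested sketch that reuses the val transform from your module; the image path is just a placeholder):

import torch
import torchvision.transforms as transforms
from PIL import Image

# load_from_checkpoint expects the checkpoint *path*, not an already-loaded dict
model = CoolSystem.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt')
model.eval()

# same preprocessing as in val_dataloader
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

img = Image.open('some_image.jpg').convert('RGB')  # placeholder path
x = transform(img).unsqueeze(0)  # add a batch dimension -> [1, 3, 224, 224]
with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1)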
After trying model.load_from_checkpoint('checkpoints/try_ckpt_epoch_1_v0.ckpt'), the following error is now thrown:
OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters storedin self.hparams?
I built CoolSystem() without self.hparams, as per the example Colab notebook (https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)
Any advice on this?
I guess it should be added to that example. The GAN example has it.
Add the hparams to the __init__, train it, and then try to load again.
Looks like it is always needed, even if you don't pass any hparams in.
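Roughly like this (just a sketch; the hparams names and values below are only examples):

from argparse import Namespace
import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):
    def __init__(self, hparams):
        super(CoolSystem, self).__init__()
        self.hparams = hparams  # gets saved into the checkpoint, so loading can restore the model
        # ... rest of __init__ as before ...

hparams = Namespace(learning_rate=0.001, batch_size=32)  # example values only
model = CoolSystem(hparams)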
Got it! Will take note to always add hparams to __init__ then
@awaelchli mind submitting a PR to fix?
I think the point was for hparams to be optional? Or should we make it more flexible? @neggert
I can look at it.
To make it optional, I guess we could simply change the loading behaviour depending on whether the user has defined hparams or not.
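Something along these lines, roughly (just a sketch of the idea, not the actual loading code; it assumes the checkpoint dict carries 'state_dict' and, optionally, 'hparams', and load_model is only an illustrative name):

from argparse import Namespace
import torch

def load_model(cls, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if 'hparams' in checkpoint:
        # model was trained with hparams -> rebuild it with them
        model = cls(Namespace(**checkpoint['hparams']))
    else:
        # no hparams stored -> construct with the default __init__
        model = cls()
    model.load_state_dict(checkpoint['state_dict'])
    return model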
I will hold back until #849 is finalized because it affects the ModelCheckpoint callback.