`torch.save(trainer.model, "model.pth")` throws an error in PyTorch Lightning 0.10
```python
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
torch.save(trainer.model, "model.pth")
```
With PyTorch Lightning 0.9 installed, run the script above: the model is saved successfully.
Upgrade to 0.10 with `pip install pytorch_lightning==0.10` and run the same script: the error shown below is raised.
Expected behavior: the model should be saved to `model.pth`.
Stack Trace
```
Traceback (most recent call last):
  File "encoder.py", line 43, in <module>
    torch.save(trainer.model, "model.pth")
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 364, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/ubuntu/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/serialization.py", line 466, in _save
    pickler.dump(obj)
TypeError: 'NoneType' object is not callable
```
Does saving the `state_dict` work instead?

```python
torch.save(trainer.model.state_dict(), "model.pth")
# or
torch.save(trainer.get_model().state_dict(), "model.pth")
```
Yes, saving the `state_dict` works as expected in both versions (0.9 and 0.10). However, we are using a library which dumps the entire model using `torch.save`. This worked in 0.9, but with the latest release we get the error shown in the stack trace section.
> However, we are using a library which dumps the entire model using torch.save

Sorry, I didn't get this.

Also, the code and the stack trace don't match here. The script saves:

```python
torch.save(trainer.model, "model.pth")
```

but the traceback points at a different call:

```
Traceback (most recent call last):
  File "/tmp/test.py", line 43, in <module>
    torch.save(autoencoder, "autoencoder.pth") <---------- here
  File "/home/ubuntu/anaconda3/lib/python3.8/site-packages/torch/serialization.py", line 364, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/ubuntu/anaconda3/lib/python3.8/site-packages/torch/serialization.py", line 466, in _save
    pickler.dump(obj)
TypeError: 'NoneType' object is not callable
```
The idea is to log the entire model to MLflow using `mlflow.pytorch.log_model` (https://www.mlflow.org/docs/latest/python_api/mlflow.pytorch.html). The library dumps the entire model into MLflow using `torch.save`, and this fails in versions 0.10 and 1.0.
Updated the stack trace.
I found two attributes that are causing pickle to throw an error:
- `trainer.model.module._results`
- `trainer.model.module.trainer.accelerator_backend.interactive_ddp_procs`
Deleting these attributes allows the model to be saved. The first one has to be deleted for a single process/GPU; both have to be deleted for multi-GPU/multi-process runs (instead of `interactive_ddp_procs`, you'll need to delete `mp_queue`).
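Attributes like these can be located by bisecting the object graph with pickle itself. Below is a minimal, standard-library-only sketch: `find_unpicklable` tries to pickle an object, descends into the dict entries or instance attributes of anything that fails, and reports the deepest failing paths; `detached_attributes` removes named attributes only while saving and restores them afterwards, instead of deleting them for good. Both helpers are hypothetical (not part of Lightning or PyTorch), and the demo uses a `threading.Lock` as a stand-in for an unpicklable member.

```python
import contextlib
import pickle
import threading
import types
from typing import Any, Dict, Iterator, List, Tuple


def _children(value: Any) -> List[Tuple[str, Any]]:
    """Return (path suffix, child) pairs: dict entries, else instance attributes."""
    if isinstance(value, dict):
        return [(f"[{key!r}]", child) for key, child in value.items()]
    try:
        attrs = vars(value)
    except TypeError:  # no __dict__ (ints, locks, slotted objects, ...)
        return []
    return [(f".{name}", child) for name, child in attrs.items()]


def find_unpicklable(obj: Any, path: str = "obj", depth: int = 4) -> List[str]:
    """Return the deepest paths under `obj` whose values fail to pickle."""
    try:
        pickle.dumps(obj)
        return []  # this whole subtree pickles, nothing to report
    except Exception:  # TypeError, PicklingError, AttributeError, ...
        pass
    culprits: List[str] = []
    if depth > 0:
        for suffix, child in _children(obj):
            culprits.extend(find_unpicklable(child, path + suffix, depth - 1))
    # If no child explains the failure, the object itself is the culprit.
    return culprits or [path]


@contextlib.contextmanager
def detached_attributes(obj: Any, *names: str) -> Iterator[None]:
    """Temporarily remove instance attributes from `obj`; restore them on exit."""
    attrs = vars(obj)
    saved: Dict[str, Any] = {}
    for name in names:
        if name in attrs:
            saved[name] = attrs.pop(name)
    try:
        yield
    finally:
        attrs.update(saved)


# Demo with stand-in objects: a lock refuses to be pickled.
ns = types.SimpleNamespace(weights=[1, 2], lock=threading.Lock())
print(find_unpicklable(ns))  # ['obj.lock']
with detached_attributes(ns, "lock"):
    payload = pickle.dumps(ns)  # succeeds while the lock is detached
print(hasattr(ns, "lock"))  # True: the lock is back after the block
```

Applied to the objects in this thread, usage might look like `with detached_attributes(trainer.model.module, "_results"): torch.save(trainer.model, "model.pth")`, which is untested against Lightning 0.10. The search re-pickles every failing subtree, so it is meant for debugging rather than for the save path.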
The first is an object called Result: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/step_result.py#L26
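It doesn't define `__getstate__`/`__setstate__` or inherit any serialization hooks, so pickle's lookup of `__getstate__` falls back to `Result.__getattr__`, which returns `None` for unknown keys; pickle then tries to call that `None`, hence `TypeError: 'NoneType' object is not callable`.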
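The failure mode and a fix can be sketched with a hypothetical `LenientDict`, a dict subclass whose `__getattr__` returns `None` for unknown names, mimicking the behavior described above (it is not Lightning's `Result`). The sketch applies two fixes at once: refusing dunder lookups in `__getattr__`, and defining explicit `__getstate__`/`__setstate__` hooks. Either one alone would stop pickle from receiving `None` when it looks up `__getstate__`.

```python
import pickle
from typing import Any, Dict


class LenientDict(dict):
    """Hypothetical stand-in for a Result-like dict: missing attributes -> None."""

    def __getattr__(self, key: str) -> Any:
        # Guard dunder names: pickle probes attributes such as __getstate__
        # and calls whatever comes back, so answering None would fail with
        # "'NoneType' object is not callable" on Python versions where
        # object itself has no __getstate__ (before 3.11).
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        return self.get(key)

    # Explicit pickle hooks. The dict items are pickled by the default dict
    # reduction; these hooks only carry the instance attributes.
    def __getstate__(self) -> Dict[str, Any]:
        return dict(vars(self))

    def __setstate__(self, state: Dict[str, Any]) -> None:
        vars(self).update(state)
```

On Python 3.11 and later, `object` itself defines `__getstate__`, so this particular lookup no longer reaches `__getattr__`; the tracebacks in this thread come from Python 3.8.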
The second one is on the accelerator class, which stores processes that are not cleaned up. The same happens with `ddp_cpu`.
ddp: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/ddp_accelerator.py#L134
ddp_cpu: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py#L56
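One way to make such an accelerator picklable is to drop the live handles in `__getstate__` and reset them in `__setstate__`. The sketch below assumes a hypothetical `SpawningAccelerator` class; only the attribute names `interactive_ddp_procs` and `mp_queue` come from this thread, and locks stand in for the real process and queue objects, since they are likewise unpicklable.

```python
import pickle
import threading
from typing import Any, Dict, List


class SpawningAccelerator:
    """Hypothetical accelerator that keeps live handles after training.

    threading.Lock objects stand in for process and queue handles: like
    them, a lock refuses to be pickled.
    """

    # Attribute names from the thread; they only make sense in a live run.
    _TRANSIENT = ("interactive_ddp_procs", "mp_queue")

    def __init__(self) -> None:
        self.num_processes = 2
        self.interactive_ddp_procs: List[Any] = [threading.Lock()]
        self.mp_queue: Any = threading.Lock()

    def __getstate__(self) -> Dict[str, Any]:
        # Copy the instance dict and drop the handles that cannot be pickled.
        state = dict(vars(self))
        for name in self._TRANSIENT:
            state.pop(name, None)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        vars(self).update(state)
        # A restored object has no live workers: reset to empty placeholders.
        self.interactive_ddp_procs = []
        self.mp_queue = None
```

Resetting the attributes rather than leaving them out keeps code that reads them after unpickling from hitting an `AttributeError`.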
cc @rohitgr7 @williamFalcon, not entirely sure which direction we'd like to go. Saving via Lightning's `trainer.save_checkpoint` already works, because it excludes these unserialisable objects.
@SeanNaren, let's make those two picklable?
@williamFalcon sure, will do this!