Pytorch-lightning: How do you save a trained model in standard pytorch format?

Created on 11 Jun 2020 · 12 comments · Source: PyTorchLightning/pytorch-lightning

I've been googling how to save the model on its own so that anyone with torch can just load it and start making predictions, but I've found it difficult to find documentation on this. My assumption was that there would be some way to directly access the underlying PyTorch model and just pickle it, but I'm unsure how to do this.

question

All 12 comments

Since this works similarly to an nn.Module, have you tried torch.save(your_model)? where your_model is

class your_model(LightningModule):
    ...

But when I tried this method, most of the time it would throw some kind of error.
I think it would be better to save the state_dict(),
as PyTorch itself doesn't recommend directly saving the whole model.
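
As a minimal sketch of the state_dict approach (the LitModel class, layer sizes, and file names here are made up for illustration):

import torch
from torch import nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

model = LitModel()

# Save only the learned parameters, not the pickled class:
torch.save(model.state_dict(), "weights.pt")

# To restore, instantiate the class yourself, then load the weights:
restored = LitModel()
restored.load_state_dict(torch.load("weights.pt"))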

Trainer automatically saves all the model weights when it checkpoints.

To see everything it checkpoints, see:
https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=checkpoint
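
Roughly, saving and restoring via the Trainer looks like this (continuing the LitModel sketch above; the paths are assumptions, and the exact Trainer arguments depend on your PL version):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)  # checkpoints .ckpt files automatically during training
                    # (assumes the model defines its dataloaders)

# You can also write a checkpoint explicitly:
trainer.save_checkpoint("final.ckpt")

# Later, rebuild the full LightningModule from the checkpoint:
model = LitModel.load_from_checkpoint("final.ckpt")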

Traceback (most recent call last):
  File "main.py", line 35, in <module>
    torch.save(model,date_time.strftime("%Y-%m-%d %H:%M:%S"))
  File "/home/michael/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/serialization.py", line 328, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/home/michael/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/serialization.py", line 401, in _legacy_save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle : attribute lookup Lightning_Pipeline on abc failed

I got the above error when I tried to do this.

I believe what you described is a common error encountered when saving models this way.
See this:
https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

Please give more details about the code you used, your model definition, etc.
@mm04926412

@sambaths is right; it is better not to pickle your whole model class.
You should instead save and load like this:

torch.save(the_model.state_dict(), PATH)    # save only the weights
the_model = TheModelClass(*args, **kwargs)  # re-create the architecture
the_model.load_state_dict(torch.load(PATH)) # restore the weights

from the official PyTorch docs:
https://pytorch.org/docs/stable/notes/serialization.html#recommend-saving-models

But you don't even need to worry about this:
the checkpoints saved by PL already contain the state dict and can be loaded the same way.
You will always be able to use the LightningModule as an nn.Module outside the PL training scripts.
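
Concretely, a PL .ckpt file is just a dictionary written with torch.save, with the weights under the 'state_dict' key, so (checkpoint path assumed, LitModel as sketched above) you can pull them straight into the module:

import torch

ckpt = torch.load("final.ckpt", map_location="cpu")
print(ckpt.keys())  # includes 'state_dict', 'epoch', optimizer states, ...

model = LitModel()                         # the LightningModule from above
model.load_state_dict(ckpt["state_dict"])
model.eval()                               # ready for inference as a plain nn.Module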

Ah, I see, so the state dictionary stores the model's hyper-parameters and activation functions and everything else, not just the weights?

model.state_dict() 

This contains the model weights, specifically the Parameters that were learned while training the model (weights, biases, etc.).
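
You can inspect exactly what it holds; continuing the LitModel sketch above, for example:

for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# e.g. layer.weight (10, 784)
#      layer.bias   (10,)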

If you want to continue training after you have stopped it, you need other information beyond the weights (which is saved by default while checkpointing in PL);
see this link -> https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=checkpoint

If you have to reuse the state_dict(), you have to define the model first, then load the weights, like @awaelchli said above.
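
For resuming training, a rough sketch (resume_from_checkpoint is the Trainer argument as of PL versions around this time; the path is an assumption):

trainer = pl.Trainer(resume_from_checkpoint="final.ckpt")
trainer.fit(model)  # restores weights, optimizer state, epoch counter, etc.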

What I want to do is save the model for active production use, so it can be loaded in one line and I can type something like loaded_model.predict(data) without even needing to know what the network's architecture is.

If this can't be done easily, how are models deployed in pytorch lightning?

PL models can be deployed the same way PyTorch models are deployed.
This issue is a drawback of pickle, not of PL or PyTorch.

From PyTorch Docs:

_The serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors._

You need to preserve the conditions that existed while saving the model so that you can reload it without errors, which is a problem because, in most cases, these conditions change while we are developing the model.

This is the reason PyTorch itself doesn't recommend this.

I'm not trying to develop the model, though; the model is finished. I am trying to deploy it so it can become an object with a blackbox "predict" method in a different project that has no access to the original code in any way.

You can define the model's architecture in a separate .py file and import it along with the other necessities (if the model architecture is complex), or you can define the model then and there; a sketch follows below.

But the recommended and most efficient way to export models is using model.state_dict().
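
For example, you could ship a small model.py with the architecture next to the weights file; everything below (the file names, the Net class, the load_model helper) is a hypothetical sketch:

# model.py -- shipped alongside weights.pt
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

def load_model(path="weights.pt"):
    model = Net()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model

# In the consuming project:
#   from model import load_model
#   model = load_model()
#   with torch.no_grad():
#       prediction = model(data)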

Thanks for your help. In the end I exported it as TorchScript using torch.jit.trace and it seemed to work fine.
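
For reference, a minimal sketch of that TorchScript route (the example input shape is an assumption); the key property is that torch.jit.load needs no access to the original class definition:

import torch

model.eval()
example = torch.randn(1, 28 * 28)          # dummy input matching the model's forward
traced = torch.jit.trace(model, example)   # records the forward pass as a graph
traced.save("model_traced.pt")

# In the other project, no original code required:
loaded = torch.jit.load("model_traced.pt")
with torch.no_grad():
    out = loaded(torch.randn(1, 28 * 28))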
