Pytorch-lightning: save_weights_only parameter in ModelCheckpoint class doesn't appear to work

Created on 24 Oct 2019 · 8 Comments · Source: PyTorchLightning/pytorch-lightning

Common bugs:

  1. Tensorboard not showing in Jupyter notebook: see issue 79.
  2. PyTorch 1.1.0 vs 1.2.0 support: see FAQ.

Describe the bug
The save_weights_only parameter in the ModelCheckpoint class doesn't appear to work.

The documentation describes save_weights_only as follows:
save_weights_only: if True, then only the model's weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).

However, the save_weights_only parameter doesn't change what is saved; both settings produce the same checkpoint file.
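
For context, this mirrors the usual PyTorch split between saving a bare state_dict and saving a full training checkpoint. A minimal sketch of what I expect each mode to produce (file names are illustrative):

import torch

model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# save_weights_only=True: just the model parameters.
torch.save(model.state_dict(), "weights_only.ckpt")

# save_weights_only=False: parameters plus training state.
torch.save(
    {
        "epoch": 0,
        "state_dict": model.state_dict(),
        "optimizer_states": [optimizer.state_dict()],
    },
    "full_checkpoint.ckpt",
)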

To Reproduce
Steps to reproduce the behavior:

  1. I used the sample script from the official documentation:
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    weight_path = os.path.join(os.getcwd(), 'checkpoint')
    if not os.path.exists(weight_path):
        os.mkdir(weight_path)

    checkpoint_callback = ModelCheckpoint(
        filepath=weight_path,
        save_best_only=False,
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix='',
        save_weights_only=False
    )

    gpus = torch.cuda.device_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CoolSystem()
    model.to(device)
    trainer = Trainer(checkpoint_callback=checkpoint_callback,
                      max_nb_epochs=1, train_percent_check=0.1)
    trainer.fit(model)
  2. I changed only the save_weights_only parameter and the checkpoint directory between runs, because otherwise the model would try to restore weights from the existing checkpoint.

For example, I save the model in the "checkpoint" directory.
After training, I move the ckpt file to another directory called "test".

So the test directory contains two model files: save_weight_only_True and save_weight_only_False.

  3. I checked whether the two files differ using torch.load(PATH), but they don't.

  4. The two files have the same contents, as shown below (a comparison sketch follows the two dumps):
    save_weights_only_False

{'epoch': 0, 'global_step': 187, 'checkpoint_callback_best': inf, 'optimizer_states': [{'state': {4830682784: {'step': 187, 'exp_avg': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'exp_avg_sq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}, 4830682928: {'step': 187, 'exp_avg': tensor([ 2.8793e-12, -2.0038e-03,  1.6457e-03,  9.9219e-12, -1.4364e-02,
         7.6137e-03,  9.4487e-12,  1.0773e-11,  9.9046e-03, -3.7999e-03]), 'exp_avg_sq': tensor([7.2368e-08, 7.3899e-05, 1.2118e-04, 8.5934e-07, 1.5810e-04, 9.1419e-05,
        7.7932e-07, 1.0131e-06, 1.5449e-04, 2.0831e-04])}}, 'param_groups': [{'lr': 0.02, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [4830682784, 4830682928]}]}], 'lr_schedulers': [], 'state_dict': OrderedDict([('l1.weight', tensor([[ 0.0198, -0.0171, -0.0024,  ..., -0.0154,  0.0243, -0.0241],
        [ 0.0349, -0.0040, -0.0051,  ...,  0.0159,  0.0295,  0.0129],
        [ 0.0302, -0.0218, -0.0171,  ...,  0.0104,  0.0179, -0.0167],
        ...,
        [-0.0284,  0.0056,  0.0178,  ...,  0.0211, -0.0075,  0.0163],
        [ 0.0145,  0.0204,  0.0121,  ..., -0.0013, -0.0103,  0.0043],
        [ 0.0059, -0.0060,  0.0017,  ..., -0.0214,  0.0273, -0.0288]])), ('l1.bias', tensor([-0.1228,  0.3341, -0.0163, -0.1302,  0.1738,  0.3422, -0.1155, -0.1122,
        -0.5156, -0.1902]))])}

save_weights_only_True

{'epoch': 0, 'global_step': 187, 'checkpoint_callback_best': inf, 'optimizer_states': [{'state': {4877819552: {'step': 187, 'exp_avg': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'exp_avg_sq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}, 4877819696: {'step': 187, 'exp_avg': tensor([ 1.4936e-11, -1.6439e-03, -6.1600e-04, -4.5436e-03, -8.8385e-03,
         4.9613e-12,  1.6526e-11,  4.4405e-03,  1.8937e-12, -6.5171e-04]), 'exp_avg_sq': tensor([1.9473e-06, 8.8430e-05, 1.4157e-04, 1.1038e-04, 1.3500e-04, 2.1486e-07,
        2.3841e-06, 1.1693e-04, 3.1302e-08, 2.3050e-04])}}, 'param_groups': [{'lr': 0.02, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [4877819552, 4877819696]}]}], 'lr_schedulers': [], 'state_dict': OrderedDict([('l1.weight', tensor([[-0.0299,  0.0281, -0.0246,  ...,  0.0143,  0.0314,  0.0167],
        [ 0.0066, -0.0160,  0.0200,  ..., -0.0266,  0.0097,  0.0138],
        [ 0.0257,  0.0134, -0.0111,  ..., -0.0201,  0.0199, -0.0146],
        ...,
        [ 0.0260, -0.0082, -0.0049,  ...,  0.0277, -0.0070,  0.0275],
        [ 0.0089, -0.0003, -0.0051,  ..., -0.0086,  0.0285, -0.0252],
        [-0.0202,  0.0252,  0.0083,  ..., -0.0144, -0.0181,  0.0105]])), ('l1.bias', tensor([-0.1031,  0.2837,  0.0276, -0.1406,  0.2007, -0.1086, -0.1267,  0.2557,
        -0.1239, -0.0374]))])}
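
For reference, a minimal sketch of how the two checkpoints can be compared (the file paths are hypothetical):

import torch

# Hypothetical paths: the two checkpoints moved into the "test" directory.
ckpt_true = torch.load("test/save_weights_only_True.ckpt", map_location="cpu")
ckpt_false = torch.load("test/save_weights_only_False.ckpt", map_location="cpu")

# Both files expose the same top-level keys (epoch, global_step,
# optimizer_states, lr_schedulers, state_dict), which should not be
# the case when save_weights_only=True.
print(sorted(ckpt_true.keys()))
print(sorted(ckpt_false.keys()))
print(sorted(ckpt_true.keys()) == sorted(ckpt_false.keys()))  # prints True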

Expected behavior
If save_weights_only=True, I expect the saved file to contain only the state_dict, like this:

{'state_dict': OrderedDict([('l1.weight', tensor([[-0.0299,  0.0281, -0.0246,  ...,  0.0143,  0.0314,  0.0167],
        [ 0.0066, -0.0160,  0.0200,  ..., -0.0266,  0.0097,  0.0138],
        [ 0.0257,  0.0134, -0.0111,  ..., -0.0201,  0.0199, -0.0146],
        ...,
        [ 0.0260, -0.0082, -0.0049,  ...,  0.0277, -0.0070,  0.0275],
        [ 0.0089, -0.0003, -0.0051,  ..., -0.0086,  0.0285, -0.0252],
        [-0.0202,  0.0252,  0.0083,  ..., -0.0144, -0.0181,  0.0105]])), ('l1.bias', tensor([-0.1031,  0.2837,  0.0276, -0.1406,  0.2007, -0.1086, -0.1267,  0.2557,
        -0.1239, -0.0374]))])}
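
Until this is fixed, a possible workaround is to strip a full checkpoint down to the weights after the fact; a minimal sketch (paths are hypothetical):

import torch

# Load the full checkpoint and keep only the model weights,
# dropping optimizer and trainer state.
ckpt = torch.load("checkpoint/full.ckpt", map_location="cpu")
torch.save({"state_dict": ckpt["state_dict"]}, "checkpoint/weights_only.ckpt")

# The weights can later be loaded into the model directly:
# model.load_state_dict(torch.load("checkpoint/weights_only.ckpt")["state_dict"])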

Desktop (please complete the following information):

  • OS: MacBook Pro (15-inch, 2018); Mojave
  • Browser: chrome
  • Version: pytorch-lightning==0.5.2.1, torch==1.3.0.post2, torchvision==0.4.1.post2, test-tube==0.7.3

Additional context
I tried to find the reason why the save_weights_only parameter doesn't work.

I found the TrainerIOMixin class inside PyTorch Lightning, and it looks like the save_weights_only parameter was never implemented there.

Priority P0 bug / fix

All 8 comments

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@ssaru could you check it with the latest version on master?

I appreciate your reply.
I will check it.

@ssaru I assume that it was fixed; if not, please feel free to reopen... :robot:

@Borda It's not fixed yet.
dump_checkpoint still returns everything regardless of the save_weights_only parameter.

I just ran into this bug myself. I can work on it, but probably not for another week or so. I took a quick look through the code and it doesn't seem _too_ difficult to fix, though since I'm not familiar with the entire code base there might be some distant issue I haven't seen. I think dump_checkpoint would need an argument so that it could save nothing but the state_dict; there would also have to be corresponding checks in restore and restore_training_state for the extra checkpoint keys.
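
To illustrate, a rough sketch of the shape such a fix could take (simplified, standalone names; not the actual pytorch-lightning code):

import torch

def dump_checkpoint(model, optimizers, epoch, global_step, weights_only=False):
    # The model weights are always included.
    checkpoint = {"state_dict": model.state_dict()}

    if not weights_only:
        # Full training state, which the current implementation saves
        # unconditionally; restore_training_state would need to check
        # for these keys before using them.
        checkpoint["epoch"] = epoch
        checkpoint["global_step"] = global_step
        checkpoint["optimizer_states"] = [opt.state_dict() for opt in optimizers]
        checkpoint["lr_schedulers"] = []

    return checkpoint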

@rightaditya mind drafting a PR? We can help to finish it fast... :]

I also just ran into this and went ahead and created a draft PR. Saving only the weights is working; however, I haven't changed any logic regarding loading.
