Pytorch-lightning: Support checkpointing for Sub-Epoch period

Created on 24 Sep 2020 · 10 Comments · Source: PyTorchLightning/pytorch-lightning

Question

When period is set to a fractional value, checkpointing doesn't trigger correctly. Additionally, I think period should default to val_check_interval, if it doesn't already.

To Reproduce

Steps to reproduce the behavior:

Run any model with the checkpoint period set to a fractional value. Only the first checkpoint will be saved.

Expected behavior

A checkpoint should be saved every specified period.

Environment

  • Lightning Version: 0.9.0

  • PyTorch Version (e.g., 1.0): 1.6

  • OS (e.g., Linux): Ubuntu 16.04

  • How you installed PyTorch (conda, pip, source): pip

  • Build command you used (if compiling from source):

  • Python version: 3.7

  • CUDA/cuDNN version: 10.1

Labels: enhancement, help wanted

All 10 comments

I think you need to set period=0; then it should work. Just looking at the code:

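# Excerpt from ModelCheckpoint: skip saving if fewer than self.period epochs
# have passed since the last check. epoch is an integer, so the difference
# is 0 for every validation run within the same epoch.
if (
    self.epoch_last_check is not None
    and (epoch - self.epoch_last_check) < self.period
):
    # skipping in this term
    return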

Try it :)
period can only be an integer. Setting it to the val_check_interval does not make sense.
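To see why period=0 happens to work, here is a small, framework-free simulation of the skip condition quoted above. It assumes epoch_last_check is updated to the current epoch whenever a check is not skipped, and that validation runs four times per epoch (as with val_check_interval=0.25); simulate_saves is a hypothetical helper, not part of Lightning.

```python
from typing import List, Optional, Tuple


def simulate_saves(period: float, epochs: int, checks_per_epoch: int) -> List[Tuple[int, int]]:
    """Return (epoch, check_index) pairs at which a checkpoint would be saved.

    Mirrors the quoted condition: skip when fewer than `period` epochs have
    passed since the last check. `epoch_last_check` is assumed to be updated
    every time a check is not skipped.
    """
    epoch_last_check: Optional[int] = None
    saves = []
    for epoch in range(epochs):
        for check in range(checks_per_epoch):
            if epoch_last_check is not None and (epoch - epoch_last_check) < period:
                continue  # "skipping in this term"
            epoch_last_check = epoch
            saves.append((epoch, check))
    return saves


# Four validation runs per epoch, i.e. val_check_interval=0.25
print(len(simulate_saves(period=1, epochs=3, checks_per_epoch=4)))     # 3
print(len(simulate_saves(period=0.25, epochs=3, checks_per_epoch=4)))  # 3
print(len(simulate_saves(period=0, epochs=3, checks_per_epoch=4)))     # 12
```

Because epoch only changes once per epoch, any period greater than 0 (including 0.25) lets through just the first validation run of each epoch, while period=0 never skips.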

@awaelchli Thanks! This works, but I find it a bit unintuitive. I didn't get any warning when I set it to a fractional value (the same as val_check_interval), so I assumed that's what I had to do. Perhaps we could set it automatically when val_check_interval is below 1? Or raise a warning stating the correct value to set in this case?

Yes, I agree. period=0 only works because of an implementation detail, and it is not meant to be used like that. It's a hack. Sub-epoch checkpointing is not currently supported. We're looking into it. If you're feeling lucky, give it a try and send a draft PR? :rocket: It is a tricky one, though.

Let's classify this as a feature request instead of a bug?

I'll give it a look and see if I'm able to do it. Feature request sounds good, since it's working as intended.

@awaelchli what do you think about these options for checkpointing? I think this suite could be really helpful:
- Currently supported: Checkpoint every N epochs (after validation)
+ Checkpoint every N training batches
+ Checkpoint after a time period (timedelta specified by the user)
+ Support for checkpointing on training epoch end if validation steps aren't supported
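Two of the proposed triggers (every N training batches, and after a user-specified timedelta) reduce to a simple check that a callback could run after each training batch. The sketch below is framework-agnostic and hypothetical: CheckpointTrigger and should_save are illustrative names, not Lightning APIs, and global_step is assumed to count completed training batches.

```python
from datetime import datetime, timedelta
from typing import Optional


class CheckpointTrigger:
    """Hypothetical trigger combining two proposed policies: save every
    `every_n_steps` training batches and/or once `every_duration` of
    wall-clock time has elapsed since the last save."""

    def __init__(self, every_n_steps: Optional[int] = None,
                 every_duration: Optional[timedelta] = None) -> None:
        self.every_n_steps = every_n_steps
        self.every_duration = every_duration
        self.last_save_time: Optional[datetime] = None

    def should_save(self, global_step: int, now: datetime) -> bool:
        if self.last_save_time is None:
            # Start the clock on the first call.
            self.last_save_time = now
        by_steps = (
            self.every_n_steps is not None
            and global_step > 0
            and global_step % self.every_n_steps == 0
        )
        by_time = (
            self.every_duration is not None
            and now - self.last_save_time >= self.every_duration
        )
        if by_steps or by_time:
            # Either policy firing resets the wall-clock timer.
            self.last_save_time = now
            return True
        return False


start = datetime(2020, 9, 24, 12, 0, 0)
trigger = CheckpointTrigger(every_n_steps=100, every_duration=timedelta(minutes=30))
print(trigger.should_save(50, start))                           # False
print(trigger.should_save(100, start + timedelta(minutes=1)))   # True (step-based)
print(trigger.should_save(150, start + timedelta(minutes=40)))  # True (time-based)
```

Resetting the timer on every save is one design choice; a strictly periodic schedule (every 30 minutes from the start of training) would instead keep a fixed reference time.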

Yes, I think these are all valid use cases.
Given that the current ModelCheckpoint callback is quite complex, it may become hard or even impossible to maintain all these options in a single class. We could consider splitting the functionality into several callbacks. Combining these features would then mean passing several callbacks to the Trainer. But that brings new challenges, like clashing filenames.
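One way to sidestep filename clashes between several checkpoint callbacks is to give each callback its own filename prefix. The sketch below is a hypothetical naming scheme (checkpoint_path is an illustrative helper, not a Lightning function) showing that two callbacks firing at the same step then write to different files.

```python
import os


def checkpoint_path(dirpath: str, prefix: str, epoch: int, global_step: int) -> str:
    """Build a checkpoint path with a per-callback prefix.

    Hypothetical naming scheme: a prefix unique to each callback, plus the
    epoch and global step, keeps two callbacks that fire at the same point
    in training from writing to the same file.
    """
    filename = f"{prefix}-epoch={epoch}-step={global_step}.ckpt"
    return os.path.join(dirpath, filename)


# Two hypothetical callbacks (step-based and time-based) firing at the same step
step_ckpt = checkpoint_path("checkpoints", "every-n-steps", epoch=1, global_step=2000)
time_ckpt = checkpoint_path("checkpoints", "every-30-min", epoch=1, global_step=2000)
assert step_ckpt != time_ckpt
print(step_ckpt)
print(time_ckpt)
```

This only keeps file names apart; deciding which files each callback may delete (for example when keeping only the top k) would still need separate bookkeeping.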

Support for checkpointing on training epoch end if validation steps aren't supported

is that not already supported?

@awaelchli I believe this is fixed on master (at least for checkpointing with sub-epoch validation), since it now checks that we haven't already saved at the same global step, instead of the same epoch. Can you confirm?
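The difference between suppressing duplicate saves per epoch and per global step can be illustrated with a toy model. The sketch below is hypothetical (count_saves is not a Lightning function) and assumes 100 training batches per epoch with validation after every 25th batch.

```python
from typing import List, Optional


def count_saves(validation_steps: List[int], steps_per_epoch: int, key: str) -> int:
    """Count checkpoints saved when duplicate saves are suppressed by `key`.

    key="epoch": skip if a checkpoint was already saved in this epoch.
    key="step":  skip only if a checkpoint was already saved at this global step.
    """
    last: Optional[int] = None
    saves = 0
    for step in validation_steps:
        marker = step // steps_per_epoch if key == "epoch" else step
        if marker == last:
            continue  # already saved for this epoch (or this step)
        last = marker
        saves += 1
    return saves


# Two epochs of 100 batches, validating after every 25th batch (0-indexed steps)
steps = [s for s in range(200) if (s + 1) % 25 == 0]
print(count_saves(steps, steps_per_epoch=100, key="epoch"))  # 2
print(count_saves(steps, steps_per_epoch=100, key="step"))   # 8
```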

from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from torchvision.datasets.mnist import MNIST
from torchvision import transforms


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=3,
        val_check_interval=0.25,  # run validation 4 times per training epoch
        gpus=1,
        checkpoint_callback=ModelCheckpoint(
            filepath="lightning_logs/test/{epoch:d}-{valid_loss:.2f}",
            save_top_k=-1  # keep every checkpoint instead of only the best k
        )
    )
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)


if __name__ == '__main__':
    cli_main()

Yes! I just checked it. Above is the code I tested with val_check_interval=0.25; it saves 4 checkpoints per epoch.
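The count of four follows from how a fractional val_check_interval maps to a batch interval. A minimal sketch, assuming validation runs whenever (batch_idx + 1) is a multiple of int(num_batches * val_check_interval); this is a simplified model, not Lightning's exact code, and validation_batches is a hypothetical helper.

```python
import math
from typing import List


def validation_batches(num_samples: int, batch_size: int, val_check_interval: float) -> List[int]:
    """Return the batch indices after which validation would run in one epoch
    (simplified model: every int(num_batches * val_check_interval) batches)."""
    num_batches = math.ceil(num_samples / batch_size)  # DataLoader default drop_last=False
    val_check_batch = max(1, int(num_batches * val_check_interval))
    return [i for i in range(num_batches) if (i + 1) % val_check_batch == 0]


# MNIST split from the script above: 55000 training samples, batch_size=32
batches = validation_batches(55000, 32, 0.25)
print(len(batches), batches)  # 4 [428, 857, 1286, 1715]
```

With save_top_k=-1, each of those validation runs produces a checkpoint, which matches the 4 checkpoints per epoch observed above.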

Awesome, closing for now. The other features can likely be added separately.
