Pytorch-lightning: Save checkpoint and validate every n steps

Created on 7 Jul 2020  ·  13 Comments  ·  Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

How can I save a checkpoint and run validation every n steps?

I saw there is a val_check_interval option, but it doesn't seem to be meant for this purpose.

Labels: enhancement, help wanted, won't fix


All 13 comments

Setting check_val_every_n_epoch=n in the Trainer may do the trick.
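A minimal sketch of that suggestion (the 5 here is an arbitrary example value):

import pytorch_lightning as pl

trainer = pl.Trainer(check_val_every_n_epoch=5)  # run the validation loop only every 5 epochs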

Won’t that do validation every n epochs?

Oh, you are right, I misunderstood your question.

As I understand it, you want to validate the model every n steps within the same epoch. If that is the case, val_check_interval does exactly that:
https://github.com/PyTorchLightning/pytorch-lightning/blob/7ef73f242ad4e5f14e6c967c8639c3a65285a048/pytorch_lightning/trainer/evaluation_loop.py#L52-L68
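
For reference, a minimal sketch of the two ways val_check_interval is interpreted (the values are arbitrary):

import pytorch_lightning as pl

# A float is treated as a fraction of the training epoch: validate four times per epoch.
trainer = pl.Trainer(val_check_interval=0.25)

# An int is treated as a number of training batches: validate every 1000 batches within an epoch.
trainer = pl.Trainer(val_check_interval=1000)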

Not really.
First, val_check_interval can't be bigger than the number of training batches in an epoch.
Second, it counts batches within a single epoch, not global steps: validation runs when batch_idx % val_check_interval == 0, not when global_step % val_check_interval == 0.
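
To make the distinction concrete, a toy illustration (300 batches per epoch and a 1000-step interval are made-up numbers):

# Toy numbers: 300 training batches per epoch, desired validation interval of 1000 steps.
batches_per_epoch = 300
interval = 1000

for global_step in range(1, 2001):
    batch_idx = (global_step - 1) % batches_per_epoch
    # What val_check_interval keys on: the per-epoch batch index, which never reaches 1000 here.
    per_epoch_check = (batch_idx % interval == 0)   # True only at batch_idx == 0
    # What the question asks for: the global step counted across epochs.
    global_check = (global_step % interval == 0)    # True at steps 1000 and 2000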

My current workaround is to set check_val_every_n_epoch to float('inf') and add a callback like this:

import pytorch_lightning as pl

class ValEveryNSteps(pl.Callback):
    """Run the validation loop every `every_n_step` global steps."""

    def __init__(self, every_n_step):
        self.every_n_step = every_n_step

    def on_batch_end(self, trainer, pl_module):
        if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0:
            # run_evaluation was the Trainer's internal validation entry point at the time.
            trainer.run_evaluation(test_mode=False)
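
A minimal usage sketch of that workaround (the 1000 is just an example interval):

trainer = pl.Trainer(
    check_val_every_n_epoch=float('inf'),           # disable the epoch-based validation schedule
    callbacks=[ValEveryNSteps(every_n_step=1000)],  # run validation every 1000 global steps
)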

Ok, so your n is measured in global steps.

Let's turn this issue into a feature request: allow val_check_interval > len(train_dataloader)?

This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong. We need a way to checkpoint based on steps, or in between epochs.

Saving a checkpoint every N steps should really _not_ be tied to validation.

For some models it doesn't make sense to monitor a decreasing validation loss. For example: vanilla GANs expect a constantly shifting loss value between generator and discriminator. We need independent N-steps checkpointing.

> This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong.

The opposite also matters: if you have very short epochs, you don't want to spend time and disk space saving so many checkpoints.

Hi all, I believe I figured out how to save every N steps, independent of validation metrics. All you need to do is create a Callback that overrides on_batch_end in v0.8.5 (or on_train_batch_end in v0.9+). The code below will save to the same directory as the other checkpoints.

import os
import pytorch_lightning as pl

class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
        """
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Check if we should save a checkpoint after every train batch """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)

Trainer(callbacks=[CheckpointEveryNSteps(save_step_frequency=1000)])  # save_step_frequency is required; 1000 is just an example

I realize this answers a slightly different question than the original issue asked (this doesn't run validation), but I'll leave it here because N-step checkpointing is a common use case.
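
For anyone reading this on a newer release: the built-in ModelCheckpoint later gained step-based options (assumption: pytorch-lightning >= 1.3), which cover plain N-step checkpointing without a custom callback. A minimal sketch:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Assumption: pytorch-lightning >= 1.3, where ModelCheckpoint supports step-based triggers.
checkpoint_callback = ModelCheckpoint(
    every_n_train_steps=1000,  # save every 1000 training steps
    save_top_k=-1,             # keep all checkpoints rather than only the best one
)
trainer = Trainer(callbacks=[checkpoint_callback])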

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

