How to save a checkpoint and validate every n steps?
I saw there is a val_check_interval, but it seems it's not for that purpose.
Setting check_val_every_n_epoch=n in the Trainer may do the trick.
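For example (a minimal sketch; any other Trainer arguments are up to you):

import pytorch_lightning as pl

# Run the validation loop only once every 5 training epochs
trainer = pl.Trainer(check_val_every_n_epoch=5)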
Won’t that do validation every n epochs?
Oh, you are right, I misunderstood your question.
As per my understanding, you want to validate the model after every n steps within the same epoch. If I am correct, then val_check_interval does exactly that:
https://github.com/PyTorchLightning/pytorch-lightning/blob/7ef73f242ad4e5f14e6c967c8639c3a65285a048/pytorch_lightning/trainer/evaluation_loop.py#L52-L68
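For reference, a rough sketch of how val_check_interval is typically passed (the values here are just examples):

import pytorch_lightning as pl

# As a float: validate this fraction of the way through each training epoch
trainer = pl.Trainer(val_check_interval=0.25)  # i.e. 4 times per epoch

# As an int: validate every N training batches within an epoch
trainer = pl.Trainer(val_check_interval=100)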
Not really.
First, val_check_interval can't be bigger than the number of training batches.
Second, it counts steps within one epoch; that is, validation runs when batch_idx % val_check_interval == 0, not when global_step % val_check_interval == 0.
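To illustrate the difference in plain Python (this is not Lightning's actual loop; the numbers are made up):

val_check_interval = 2
n = 7  # the "every n global steps" the question is asking about
global_step = 0
for epoch in range(2):
    for batch_idx in range(5):  # pretend there are 5 training batches per epoch
        global_step += 1
        if (batch_idx + 1) % val_check_interval == 0:
            print(f"val_check_interval-style check: epoch={epoch} batch_idx={batch_idx}")
        if global_step % n == 0:
            print(f"what the question asks for: global_step={global_step}")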
My current workaround is to set check_val_every_n_epoch to float('inf') and add a callback like this:
import pytorch_lightning as pl

class ValEveryNSteps(pl.Callback):
    def __init__(self, every_n_step):
        self.every_n_step = every_n_step

    def on_batch_end(self, trainer, pl_module):
        # Run the validation loop every `every_n_step` global steps (skipping step 0)
        if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0:
            trainer.run_evaluation(test_mode=False)
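Wiring it up might look like this (a sketch; every_n_step=1000 is just an example value):

import pytorch_lightning as pl

trainer = pl.Trainer(
    check_val_every_n_epoch=float('inf'),  # disable the epoch-based validation check, per the workaround
    callbacks=[ValEveryNSteps(every_n_step=1000)],  # validate every 1000 global steps
)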
Ok, so your n is in global steps.
Let's make this issue into a feature request: allow val_check_interval > len(train_dataloader)?
This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong. We need a way to checkpoint based on steps, or in between epochs.
Saving a checkpoint every N steps should really _not_ be tied to validation.
For some models it doesn't make sense to monitor a decreasing validation loss. For example, vanilla GANs expect a constantly shifting loss value between generator and discriminator. We need independent N-step checkpointing.
This is an important feature. Especially for large datasets where an epoch may take a whole day, we might want to save a checkpoint in between epochs in case something goes wrong.
Also for the opposite: if you have very short epochs, you don't want to spend time or disk space saving so many checkpoints.
Hi all, I believe I figured out how to save every N steps, independent of validation metrics. All you need to do is create a Callback that overrides on_batch_end in v0.8.5 (or on_train_batch_end in v0.9+). The code below will save to the same directory as other checkpoints.
import os
import pytorch_lightning as pl


class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
        """
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check if we should save a checkpoint after every train batch."""
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
Trainer(callbacks=[CheckpointEveryNSteps(save_step_frequency=1000)])  # e.g. checkpoint every 1000 steps
I realize this answers a slightly different question than the original issue asked for (this doesn't validate), but I'll leave it here because N-step checkpointing is a common use case.