pytorch-lightning: Saving a checkpoint after every epoch with ModelCheckpoint when no metric is monitored

Created on 6 Dec 2019 · 12 comments · Source: PyTorchLightning/pytorch-lightning

I may have missed something, but based on the docs and code it seems that ModelCheckpoint does not allow this?

Labels: enhancement, help wanted, won't fix

All 12 comments

Yup. I guess we either return the epoch number as the thing to monitor, or we modify ModelCheckpoint to add this option.

Hmm. I guess monitoring the epoch number could work, but I think some modifications are needed to handle the case where no validation loop is initialized. What do you think?

This would also be super important for me. I had a quite complicated experiment running on an older version, relying on save_best_only=False to save every epoch without a validation step. I lost quite a bit of training before I realized it was no longer saving checkpoints. @williamFalcon is there a workaround, like putting in an empty validation step? (A sketch of that idea follows below.)
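(For anyone trying the empty-validation-step idea: a minimal sketch, assuming the 0.x-era API where validation_step and validation_end return dicts and val_dataloader is defined on the LightningModule; hook names changed across versions, so treat this as illustrative. Because the dummy metric never improves, pair it with save_top_k=-1 so the checkpoint saves unconditionally whenever the validation hook fires.)

import torch

# inside the LightningModule
def validation_step(self, batch, batch_idx):
    # dummy step: no real evaluation, just emit something for callback_metrics
    return {'val_loss': torch.tensor(0.0)}

def validation_end(self, outputs):
    # renamed validation_epoch_end in later versions
    avg = torch.stack([x['val_loss'] for x in outputs]).mean()
    return {'val_loss': avg}

def val_dataloader(self):
    # reuse the training dataloader so the validation loop has something to iterate over
    return self.train_dataloader()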

I also need such functionality. @simonjaq did you find a workaround for this problem?

Hi. I made a custom checkpoint callback: I copied all the code but changed `def on_validation_end(self):` to `def on_epoch_end(self):`. Then I call the checkpoint from the LightningModule's epoch hook:

    def on_epoch_end(self):
        trainer.checkpoint_callback.on_epoch_end()

This works quite well. Before starting the trainer I do this:

checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',
    save_top_k=10,
    monitor='g_loss',
    verbose=True,
    prefix='V0.13.8-RGB'
)

Same here. I'm training one epoch in about 30 minutes, so I'm only validating every 10, say, to save time. So I need to save every epoch without validating. @simonjaq can you point me in the right direction - which code did you copy? Was that callbacks.model_checkpoint or like here?

Hello. I took the whole code from pytorch_lightning/callbacks/model_checkpoint.py and changed line 189 to on_epoch_end.

(continues below code block)

"""
Callbacks
=========
Callbacks supported by Lightning
"""

import os
import shutil
import logging as log
import warnings

import numpy as np


class Callback(object):
    """Abstract base class used to build new callbacks."""

    def __init__(self):
        self._trainer = None

    def set_trainer(self, trainer):
        """Make a link to the trainer, so different things like `trainer.current_epoch`,
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass

    def on_epoch_end(self):
        """Called when the epoch ends."""
        pass

    def on_batch_begin(self):
        """Called when the training batch begins."""
        pass

    def on_batch_end(self):
        """Called when the training batch ends."""
        pass

    def on_train_begin(self):
        """Called when the train begins."""
        pass

    def on_train_end(self):
        """Called when the train ends."""
        pass

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        pass

    def on_validation_end(self):
        """Called when the validation loop ends."""
        pass

    def on_test_begin(self):
        """Called when the test begins."""
        pass

    def on_test_end(self):
        """Called when the test ends."""
        pass


_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class ModelCheckpoint(Callback):
    r"""
    Save the model after every epoch.
    Args:
        filepath (str): path to save the model file.
            Can contain named formatting options to be auto-filled.
            Example::
                # save epoch and val_loss in name
                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
                # saves file like: /path/epoch_2-val_loss_0.2.hdf5
        monitor (str): quantity to monitor.
        verbose (bool): verbosity mode, 0 or 1.
        save_top_k (int): if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if `save_top_k == 0`, no models are saved.
            if `save_top_k == -1`, all models are saved.
            Please note that the monitors are checked every `period` epochs.
            if `save_top_k >= 2` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode (str): one of {auto, min, max}.
            If `save_top_k != 0`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only (bool): if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period (int): Interval (number of epochs) between checkpoints.
    Example::
        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint
        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)
        # saves checkpoints to my_path whenever 'val_loss' has a new min
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_top_k=1, save_weights_only=False,
                 mode='auto', period=1, prefix=''):
        super(ModelCheckpoint, self).__init__()
        if (
            save_top_k and
            os.path.isdir(filepath) and
            len(os.listdir(filepath)) > 0
        ):
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                "All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model = ''
        self.best = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.kth_value = np.Inf
            self.mode = 'min'
        elif mode == 'max':
            self.monitor_op = np.greater
            self.kth_value = -np.Inf
            self.mode = 'max'
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.kth_value = -np.Inf
                self.mode = 'max'
            else:
                self.monitor_op = np.less
                self.kth_value = np.Inf
                self.mode = 'min'

    def _del_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        try:
            shutil.rmtree(filepath)
        except OSError:
            os.remove(filepath)

    def _save_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        # delegate the saving to the model
        self.save_function(filepath)

    def check_monitor_top_k(self, current):
        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_epoch_end(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG

        logs = self._trainer.callback_metrics
        epoch = self._trainer.current_epoch
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
            version_cnt = 0
            while os.path.isfile(filepath):
                # this epoch called before
                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
                version_cnt += 1

            if self.save_top_k != -1:
                current = logs.get(self.monitor)

                if current is None:
                    warnings.warn(
                        f'Can save best model only with {self.monitor} available,'
                        ' skipping.', RuntimeWarning)
                else:
                    if self.check_monitor_top_k(current):

                        # remove kth
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            delpath = self.kth_best_model
                            self.best_k_models.pop(self.kth_best_model)
                            self._del_model(delpath)

                        self.best_k_models[filepath] = current
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            # monitor dict has reached k elements
                            if self.mode == 'min':
                                self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
                            else:
                                self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
                            self.kth_value = self.best_k_models[self.kth_best_model]

                        if self.mode == 'min':
                            self.best = min(self.best_k_models.values())
                        else:
                            self.best = max(self.best_k_models.values())
                        if self.verbose > 0:
                            log.info(
                                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                                f' {filepath} as top {self.save_top_k}')
                        self._save_model(filepath)

                    else:
                        if self.verbose > 0:
                            log.info(
                                f'\nEpoch {epoch:05d}: {self.monitor}'
                                f' was not in top {self.save_top_k}')

            else:
                if self.verbose > 0:
                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                self._save_model(filepath)

In my train loop I call the checkpoint:

class DiscriminatorGenerator(pl.LightningModule):

    # ... all my training code ...

    def on_epoch_end(self):
        # trigger the (modified) checkpoint callback at the end of every epoch;
        # `trainer` here is the Trainer instance defined below
        trainer.checkpoint_callback.on_epoch_end()

My training block looks like this:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger


checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',  
    save_top_k=10,  
    monitor='g_loss',
    #save_best_only=False,
    verbose=True,
    prefix='V0.13.8-RGB'
)



#default logger used by trainer
logger = TensorBoardLogger(
    save_dir='./logs',
    version=252,
    name='lightning_logs'
)

trainer = Trainer(
    logger=logger,
    min_epochs=2000,
    max_epochs=5000,
    checkpoint_callback=checkpoint_callback,
    amp_level='O1', use_amp=True,
    gpus=1,
    weights_summary='full'

)

This works for me. Note that I work in a Jupyter notebook and just insert the modified callback somewhere at the beginning of the notebook. This should also work if you import your modified callback from a module instead.
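(For example, if the modified class lives in a small module next to the notebook; the module name here is made up:)

# hypothetical module custom_checkpoint.py holding the modified class from above
from custom_checkpoint import ModelCheckpoint  # instead of defining it in the notebook

checkpoint_callback = ModelCheckpoint(filepath='./checkpoints/AD_15', save_top_k=10,
                                      monitor='g_loss', verbose=True, prefix='V0.13.8-RGB')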

Hopefully this is an improvement @williamFalcon (but it still doesn't allow saving all models independently of validation).

Just for anyone else: I couldn't get the above to work; the pl versions are different, and it seemed messy to put the trainer inside the model. I'm now saving every epoch, while still validating only every n > 1 epochs, using the custom callback below. It doesn't require adjusting callbacks/model_checkpoint.py. It's fairly hacky and rebuilds the filenames, but it works.

class Non_val_epoch_saves(pl.Callback):
    def __init__(self, iteration, filepath):
        self.iteration = iteration
        self.filepath = filepath
        # bump the version suffix if this iteration already has checkpoints on disk
        self.ver = int(self.iteration[-1])
        if any(self.iteration in x for x in os.listdir(self.filepath)):
            self.ver += 1

    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if 'avg_val_loss' in metrics:
            avl = f"{metrics['avg_val_loss']:.3f}"
        else:
            avl = 'NA'
        tl = metrics.get(trainer.checkpoint_callback.monitor)
        current_tl = f'{tl:0.3f}'
        self.name = (self.iteration[:-1] + str(self.ver) +
                     '_epo=' + str(trainer.current_epoch) +
                     '_tloss=' + current_tl +
                     '_avloss=' + avl + '.ckpt')
        # reuse the stock ModelCheckpoint's save machinery
        trainer.checkpoint_callback._save_model(
            filepath=os.path.join(self.filepath, self.name))

which is called like:
iteration = '18Mar_v0'

callback_dir = os.path.join(DATADIR, 'dev_test_models/ckpts_' + iteration + '/')

# save_top_k=0 means the stock callback itself saves nothing;
# it is only used for its monitor and _save_model machinery
callback = ModelCheckpoint(
    filepath=callback_dir,
    monitor='loss',
    verbose=1,
    save_top_k=0,
    save_weights_only=False,
    mode='min',
    period=1,
    prefix=''
)

trainer = Trainer(
    accumulate_grad_batches=6,
    callbacks=[Non_val_epoch_saves(
        iteration=iteration,
        filepath=callback_dir
    )],
    checkpoint_callback=callback,
    check_val_every_n_epoch=2,
    distributed_backend='ddp')

What I did was:

class ModelCheckpointAtEpochEnd(pl.Callback):
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        # expose the epoch index so it can be monitored / used in checkpoint names
        metrics['epoch'] = trainer.current_epoch
        if trainer.disable_validation:
            # no validation loop, so fire the checkpoint hook manually
            trainer.checkpoint_callback.on_validation_end(trainer, pl_module)

Then add this callback to the trainer as well, alongside the regular checkpoint_callback, and let the checkpoint callback do its thing. It would be nice if this were added inside Lightning.
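(For completeness, the wiring could look roughly like this; paths and arguments are illustrative, the key points are passing the helper via callbacks and keeping a checkpoint_callback that monitors the injected 'epoch' key:)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints',   # illustrative path
    monitor='epoch',            # the key injected by ModelCheckpointAtEpochEnd above
    mode='max',                 # a later epoch always counts as an improvement
    save_top_k=-1               # -1 keeps every checkpoint; a positive k keeps only the latest k
)

trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    callbacks=[ModelCheckpointAtEpochEnd()]   # triggers the save when validation is disabled
)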

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Please reopen this issue.
