I may have missed something, but based on the docs and code it seems that ModelCheckpoint does not allow this?
Yup. I guess we either return the epoch number as the quantity to monitor, or we modify ModelCheckpoint to add this option.
Hmm. I guess monitoring the epoch number could work, but I think some modifications should be made to handle the case where there's no validation loop initialized. What do you think?
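For concreteness, the "monitor the epoch number" route would look roughly like the sketch below. The `'epoch'` key, the return format, and `compute_loss` are assumptions that depend on the pl version, and it still only saves when the checkpoint callback actually runs (i.e. when there is a validation loop):
```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class MyModule(pl.LightningModule):
    # ... model / dataloaders elided ...

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # 'epoch' grows monotonically, so with mode='max' every new epoch
        # counts as an "improvement" for ModelCheckpoint
        return {'loss': loss, 'log': {'epoch': self.current_epoch}}

checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints',  # placeholder path
    monitor='epoch',           # monitor the epoch counter instead of a loss
    mode='max',
    save_top_k=-1,             # keep every checkpoint
)
```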
This would also be super important for me. I had a quite complicated experiment running on an older version, relying on `save_best_only=False` to save every epoch without a validation step. I lost quite a bit of training before I realized it was not saving checkpoints anymore. @williamFalcon, is there a workaround? Like putting in an empty validation step?
I also need such functionality. @simonjaq did you find a workaround for this problem?
Hi. I made a custom checkpoint. I copied all the code but changed `def on_validation_end(self):` to `def on_epoch_end(self):`.
Then I'm calling the checkpoint in the Lightning loop:
```python
def on_epoch_end(self):
    trainer.checkpoint_callback.on_epoch_end()
```
This works quite well. Before starting the trainer I do this:
```python
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',
    save_top_k=10,
    monitor='g_loss',
    verbose=True,
    prefix='V0.13.8-RGB'
)
```
Same here. I'm training one epoch in about 30 minutes, so I'm only validating every 10 epochs, say, to save time. So I need to save every epoch without validating. @simonjaq can you point me in the right direction - which code did you copy? Was that callbacks.model_checkpoint or like here?
Hello.
I took the whole code from `/pytorch_lightning/callbacks/model_checkpoint.py` and changed `on_validation_end` on line 189 to `on_epoch_end`:
"""
Callbacks
=========
Callbacks supported by Lightning
"""
import os
import shutil
import logging as log
import warnings
import numpy as np
class Callback(object):
"""Abstract base class used to build new callbacks."""
def __init__(self):
self._trainer = None
def set_trainer(self, trainer):
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
`trainer.batch_idx`, `trainer.global_step` can be used."""
self._trainer = trainer
def on_epoch_begin(self):
"""Called when the epoch begins."""
pass
def on_epoch_end(self):
"""Called when the epoch ends."""
pass
def on_batch_begin(self):
"""Called when the training batch begins."""
pass
def on_batch_end(self):
"""Called when the training batch ends."""
pass
def on_train_begin(self):
"""Called when the train begins."""
pass
def on_train_end(self):
"""Called when the train ends."""
pass
def on_validation_begin(self):
"""Called when the validation loop begins."""
pass
def on_validation_end(self):
"""Called when the validation loop ends."""
pass
def on_test_begin(self):
"""Called when the test begins."""
pass
def on_test_end(self):
"""Called when the test ends."""
pass
_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
class ModelCheckpoint(Callback):
r"""
Save the model after every epoch.
Args:
filepath (str): path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, 0 or 1.
save_top_k (int): if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if `save_top_k == 0`, no models are saved.
if `save_top_k == -1`, all models are saved.
Please note that the monitors are checked every `period` epochs.
if `save_top_k >= 2` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
If `save_top_k != 0`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(filepath='my_path')
Trainer(checkpoint_callback=checkpoint_callback)
# saves checkpoints to my_path whenever 'val_loss' has a new min
"""
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_top_k=1, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
if (
save_top_k and
os.path.isdir(filepath) and
len(os.listdir(filepath)) > 0
):
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_check = 0
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
self.kth_best_model = ''
self.best = 0
if mode not in ['auto', 'min', 'max']:
warnings.warn(
f'ModelCheckpoint mode {mode} is unknown, '
'fallback to auto mode.', RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
self.kth_value = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.kth_value = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.kth_value = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.kth_value = np.Inf
self.mode = 'min'
def _del_model(self, filepath):
dirpath = os.path.dirname(filepath)
# make paths
os.makedirs(dirpath, exist_ok=True)
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _save_model(self, filepath):
dirpath = os.path.dirname(filepath)
# make paths
os.makedirs(dirpath, exist_ok=True)
# delegate the saving to the model
self.save_function(filepath)
def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
def on_epoch_end(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
logs = self._trainer.callback_metrics
epoch = self._trainer.current_epoch
self.epochs_since_last_check += 1
if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
if self.save_top_k != -1:
current = logs.get(self.monitor)
if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
# remove kth
if len(self.best_k_models.keys()) == self.save_top_k:
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)
self.best_k_models[filepath] = current
if len(self.best_k_models.keys()) == self.save_top_k:
# monitor dict has reached k elements
if self.mode == 'min':
self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
else:
self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model]
if self.mode == 'min':
self.best = min(self.best_k_models.values())
else:
self.best = max(self.best_k_models.values())
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)`
In my train loop I call the checkpoint:
```python
class DiscriminatorGenerator(pl.LightningModule):
    ...  # all my training code ...

    def on_epoch_end(self):
        trainer.checkpoint_callback.on_epoch_end()
```
My training block looks like this:
```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',
    save_top_k=10,
    monitor='g_loss',
    # save_best_only=False,
    verbose=True,
    prefix='V0.13.8-RGB'
)

# default logger used by trainer
logger = TensorBoardLogger(
    save_dir='./logs',
    version=252,
    name='lightning_logs'
)

trainer = Trainer(
    logger=logger,
    min_epochs=2000,
    max_epochs=5000,
    checkpoint_callback=checkpoint_callback,
    amp_level='O1', use_amp=True,
    gpus=1,
    weights_summary='full'
)
```
This works for me. Note that I work in a Jupyter notebook and just insert the modified callback somewhere at the beginning of the notebook. This should also work by importing your modified callback.
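As a side note: in the modified `on_epoch_end` version above, `save_top_k=-1` takes the unconditional branch and saves every `period` epochs, so a config like this should keep a checkpoint from every epoch (same placeholder path and prefix as before; untested in this exact form):
```python
# Sketch only: with the modified callback above, save_top_k=-1 skips the
# top-k comparison entirely and saves every `period` epochs.
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints/AD_15',  # placeholder path
    save_top_k=-1,                   # keep every checkpoint
    monitor='g_loss',                # not used to gate saving in this branch
    verbose=True,
    prefix='V0.13.8-RGB'
)
```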
Hopefully this is an improvement, @williamFalcon (but it still doesn't allow saving all models independently of validation).
Just for anyone else: I couldn't get the above to work (the pl versions are different), and it seemed to get messy putting the trainer into the model. I'm now saving every epoch, while still only validating every n > 1 epochs, using the custom callback below. It doesn't require adjusting callbacks.model_checkpoint.py. It's fairly hacky and redoes the filenames, but it works.
```python
class Non_val_epoch_saves(pl.Callback):
    def __init__(self, iteration, filepath):
        self.iteration = iteration
        self.filepath = filepath
        self.ver = int(self.iteration[-1])
        if any(self.iteration in x for x in os.listdir(self.filepath)):
            self.ver += 1

    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if 'avg_val_loss' in metrics:
            avl = metrics['avg_val_loss']
            avl = f'{avl:.3f}'
        else:
            avl = 'NA'
        tl = metrics.get(trainer.checkpoint_callback.monitor)
        current_tl = f'{tl:0.3f}'
        self.name = self.iteration[:-1] + str(self.ver) + '_epo=' + \
            str(trainer.current_epoch) + \
            '_tloss' + '=' + \
            current_tl + '_' + \
            'avloss=' + avl + \
            '.ckpt'
        trainer.checkpoint_callback._save_model(filepath=os.path.join(
            self.filepath, self.name)
        )
```
which is called like:
```python
iteration = '18Mar_v0'
callback_dir = os.path.join(DATADIR, 'dev_test_models/ckpts_' + iteration + '/')

callback = ModelCheckpoint(
    filepath=callback_dir,
    monitor='loss',
    verbose=1,
    save_top_k=0,
    save_weights_only=False,
    mode='min',
    period=1,
    prefix=''
)

trainer = Trainer(
    accumulate_grad_batches=6,
    callbacks=[Non_val_epoch_saves(
        iteration=iteration,
        filepath=callback_dir
    )],
    checkpoint_callback=callback,
    check_val_every_n_epoch=2,
    distributed_backend='ddp')
```
What I did was:
```python
class ModelCheckpointAtEpochEnd(pl.Callback):
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        metrics['epoch'] = trainer.current_epoch
        if trainer.disable_validation:
            trainer.checkpoint_callback.on_validation_end(trainer, pl_module)
```
Then I add this callback to the trainer too and let the checkpoint_callback do its thing. It would be nice if this were added inside Lightning.
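Wiring it up looks roughly like this (a minimal sketch; the filepath and monitored key are placeholders):
```python
# Both callbacks are passed to the Trainer: the regular checkpoint callback
# does the actual saving, and ModelCheckpointAtEpochEnd triggers it at the
# end of every epoch when validation is disabled.
checkpoint_callback = ModelCheckpoint(
    filepath='./checkpoints',  # placeholder path
    monitor='loss',            # placeholder metric from callback_metrics
    save_top_k=-1,
)

trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    callbacks=[ModelCheckpointAtEpochEnd()],
)
```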
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Please reopen this issue.