Currently EarlyStopping's state is updated after the checkpoint callback runs, so what gets saved in the checkpoint is the last epoch's state.
This is somewhat related to #1463 so I am going to use the same code.
Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import pytorch_lightning as pl
class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
model = CoolSystem()
checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)
class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience - self.wait, self.patience))
        return ret
early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=checkpoint_callback,
                  early_stop_callback=early_stopping)
trainer.fit(model)
Let the model train until convergence, then reload the model and see how it continues:
trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=None,
                  resume_from_checkpoint='model_ckpt/_ckpt_epoch_7.ckpt',
                  early_stop_callback=early_stopping)
trainer.fit(model)
The early_stopping callback would print:
EarlyStoppingPrinting before on_train_start
self.wait = 4
...
and the model keeps training.
The early_stopping callback should print:
EarlyStoppingPrinting before on_train_start
self.wait = 5
...
and the model should not be trained again at all, since self.wait >= self.patience.
If the model is loaded from an interrupted save, then it should still train after resuming, but with corrected self.wait.
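For completeness, the saved file can be inspected directly to see which early-stopping values it actually contains (it is just a torch.save'd dict; no particular key names are assumed, the snippet simply prints whatever is there):

import torch

# Dump the checkpoint contents to compare the stored early-stopping state
# with the self.wait value printed by EarlyStoppingPrinting on resume.
ckpt = torch.load('model_ckpt/_ckpt_epoch_7.ckpt', map_location='cpu')
print(list(ckpt.keys()))
for key, value in ckpt.items():
    if 'early' in str(key).lower() or 'wait' in str(key).lower():
        print(key, '->', value)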
This was run on Google Colab:
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD
Somewhat related to #1463.
@lizhitwo thanks for this very detailed bug report! looking into it...
at a high level, i think that we should keep the concern of the callback state contained within the callback itself. we can follow the pytorch convention of having methods for state_dict() and load_state_dict(). the trainer can just call those methods rather than reaching in and saving individual attributes. (this more so addresses #1463 but i plan to fix both in a single PR)
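a minimal sketch of what that could look like (class and attribute names here are illustrative, not the actual Lightning API):

# illustrative only: an early-stopping callback that owns its own state and
# exposes it via the familiar state_dict()/load_state_dict() convention
class StatefulEarlyStopping:
    def __init__(self, monitor='val_loss', patience=5):
        self.monitor = monitor
        self.patience = patience
        self.wait = 0
        self.best = float('inf')
        self.stopped_epoch = 0

    def state_dict(self):
        # everything the trainer needs to persist, in one place
        return {'wait': self.wait, 'best': self.best, 'stopped_epoch': self.stopped_epoch}

    def load_state_dict(self, state_dict):
        self.wait = state_dict['wait']
        self.best = state_dict['best']
        self.stopped_epoch = state_dict['stopped_epoch']

the trainer would then call state_dict() when writing a checkpoint and load_state_dict() when restoring, instead of copying individual attributes.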
@Borda do you know why we need on_train_start? why can't we just set the values in __init__?
edit: going to remove on_train_start
the core problem for this issue is that the CheckpointCallback runs during on_validation_end, which occurs before the EarlyStoppingCallback runs during on_epoch_end. the checkpoint callback is not run again after early stopping halts training. the checkpoint includes a state dict of the early stopping values, and as @lizhitwo points out, the last saved checkpoint contains the early stopping state of the second-to-last epoch.
we could move the checkpoint callback to also run during on_epoch_end, but this might not always be desired (e.g. if you run validation multiple times per epoch and want all checkpoints).
we could also just write the checkpoint callback to re-run at the end of an epoch, but not sure how we want to handle saving the k best models in this case
cc @PyTorchLightning/core-contributors any suggestions? i'm not a huge fan of either of these approaches
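to make the ordering concrete, here's a toy callback (assuming the 0.7.x Callback hooks and the callbacks= Trainer argument, and reusing CoolSystem from the report above) that just prints when each relevant hook fires:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class HookOrderPrinter(Callback):
    # prints the order in which the two hooks fire each epoch: the checkpoint
    # callback saves during on_validation_end, while early stopping only
    # updates its counters later in on_epoch_end, so the saved state lags
    def on_validation_end(self, trainer, pl_module):
        print('on_validation_end  <- ModelCheckpoint saves here')

    def on_epoch_end(self, trainer, pl_module):
        print('on_epoch_end       <- EarlyStopping updates wait/best here')

trainer = Trainer(max_epochs=2, callbacks=[HookOrderPrinter()])
trainer.fit(CoolSystem())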
I would like to propose that we do the following: when early stopping ends training, save one final checkpoint that carries the up-to-date early stopping state, under a distinct name (e.g. ckpt_epoch_7_early_stopping.ckpt).
@lizhitwo would this be a suitable solution for your needs?
This would be correct only when the training ends normally. When the user hits Ctrl+C to interrupt, the previous checkpoints that are already saved are still lagging one epoch behind.
I would also advise against putting more undocumented checkpoint naming conventions into Lightning. I currently am already confused why Lightning overrides my checkpoint name unless I put e.g. {epoch} in them.
Question: is there a reason why the early stopping callback can't be split into two, with the status update moved before the checkpoint? You can update its state before checkpointing, then checkpoint, and then query its state to decide whether early stopping should be performed, using e.g. a should_early_stop property or something.
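A rough sketch of what I mean (the method names and should_early_stop are made up, just to show the split):

# made-up sketch: split early stopping into (1) a state update that runs as
# soon as the monitored metric is known, before any checkpoint is written,
# and (2) a property the trainer queries afterwards to decide whether to stop
class TwoPhaseEarlyStopping:
    def __init__(self, monitor='val_loss', patience=5):
        self.monitor = monitor
        self.patience = patience
        self.wait = 0
        self.best = float('inf')

    def update_state(self, metrics):
        # called before the checkpoint callback, so the saved state is current
        current = metrics[self.monitor]
        if current < self.best:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1

    @property
    def should_early_stop(self):
        # queried after checkpointing to decide whether training stops
        return self.wait >= self.patience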
yeah you're right, we need a better solution here. i think we can also improve how the checkpoint naming is done but i'll leave that for a separate issue.
Question: is there a reason why early stopping callback can't be split into two, and the status update moved before checkpoint?
it's a good question. this is a bit challenging because our checkpoint callback is currently set up to run every time we iterate over the validation set. this can happen once per epoch, multiple times per epoch, or once every n epochs, depending on how the user has defined various arguments in their Trainer.
here's how i think we should proceed:
1. have the user explicitly specify the quantity to monitor in the training_step output or validation_step output (this is cleaner than guessing imo)
2. switch from having early stopping return a value to instead set a Trainer attribute
3. if the monitored quantity comes from training, run the early stopping check in a hook (on_batch_end) which always runs before the checkpoint callback
4. if the monitored quantity comes from validation, run the early stopping check after validation (on_validation_end) but _before_ the checkpoint callback runs

cc @lizhitwo and @Borda want to weigh in if this is a good plan?
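to make the validation path concrete, a rough pseudo-sketch of the ordering (simplified, not the actual trainer code; update_state and should_early_stop are the hypothetical names from the sketch above):

# simplified pseudo-code of the proposed ordering for a validation-monitored
# metric: update early stopping state, set the trainer flag, and only then
# run the checkpoint callback so the saved state is never stale
def on_validation_loop_end(trainer, val_metrics):
    trainer.early_stop_callback.update_state(val_metrics)      # state is now current
    trainer.should_stop = trainer.early_stop_callback.should_early_stop
    trainer.checkpoint_callback.on_validation_end(trainer, trainer.get_model())
    # the training loop checks trainer.should_stop afterwards and exits cleanly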
I think that in some cases you want to monitor both train and valid...
switch from having early stopping return a value to instead set a Trainer attribute
do you mean a Trainer state like we discussed several times before?
I think that there is one more thing we shall think about, and that is the "unit" for evaluation, e.g. whether the patience is counted in epochs or in batches/steps.
cc: @PyTorchLightning/core-contributors
I think that in some cases you want to monitor both train and valid...
could you provide an example? in 90% of the cases the user is going to want to monitor something like val_loss. it seems very odd to condition the early stopping criterion on something which bridges both train and valid
do you mean a Trainer state like we discussed several times before?
i'm thinking of setting an attribute.
if self.enable_early_stop:
    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
        # stop training
        stop = should_stop and met_min_epochs
        if stop:
            self.run_training_teardown()
            return

to

if self.should_stop:  # the checks against met_min_epochs and met_min_steps happen in the callback
    return
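and on the callback side, something roughly like this would set the flag and take over the min-epoch/min-step checks (a sketch only; _update_state is a hypothetical helper and the attribute names are illustrative):

# illustrative sketch: the callback decides whether to stop and sets a flag on
# the trainer; the training loop only reads trainer.should_stop
class FlagSettingEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        self._update_state(trainer.callback_metrics)   # hypothetical state update
        met_min_epochs = trainer.current_epoch >= (trainer.min_epochs or 0)
        # a met_min_steps check would live here as well
        if self.wait >= self.patience and met_min_epochs:
            trainer.should_stop = True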
I think this works, as long as the early stopping state is updated as soon as the monitored metric is available and before the checkpoint callback runs; then no matter when the checkpoint is done, the early-stopping's state is clean.
~Although this would break compatibility with older code, so maybe set the attribute inside Trainer according to early stopping's return value, instead of in early stopping callback which may be overridden a lot?~
I think monitoring both train and val is better left to users, since they need to specify how the criterion is computed anyway. They can choose to compute it during either val or train and add it to the logged metrics in one of them.
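For example, a user could do something like this (a sketch; it assumes the key returned from validation_epoch_end ends up in the metrics the callbacks see, and self.last_train_loss is a value the user tracks themselves in training_step):

# user-side sketch: combine a tracked training loss with the validation loss
# and expose the result under a single key for EarlyStopping to monitor
def validation_epoch_end(self, outputs):
    val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    combined = 0.5 * val_loss + 0.5 * self.last_train_loss   # tracked by the user
    return {'val_loss': val_loss, 'combined_loss': combined}

early_stopping = EarlyStopping(monitor='combined_loss', patience=5)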
yes that was a behavioral regression introduced by #1528 i will fix it as well, thanks for catching it! we clearly need better tests to spot these errors
@jeremyjordan submit asap so we can get it in 0.7.4? @lizhitwo thanks for catching that!
actually... this may not be a trivial fix
In (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L459) we need to check it whenever we save weights; if it is true there, we need to stop training but still make sure all the remaining actions get called.
In (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L369) we actually want to exit.
But anyhow, @jeremyjordan if you figure this out let's get this into 0.7.4
yes i'm hoping to get this all wrapped up this weekend, it's a bit of a tricky one