Early stopping does not work anymore. When I downgrade from 0.7.1 or the current dev version to 0.6.0, early stopping works again with the same code.
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# dg, src.settings and RecurrentModel are helpers from the reporter's own
# codebase and are used below as-is (their imports are omitted here).

def main(hparams):
    if hparams.early_stopping == 'yes':
        early_stopping = EarlyStopping(
            monitor='batch/mean_absolute_loss',
            min_delta=hparams.min_delta,
            patience=hparams.patience,
            mode='min'
        )
    else:
        early_stopping = False
    model = MemoryTest(hparams)
    trainer = pl.Trainer(
        val_percent_check=0,
        early_stop_callback=early_stopping,
        default_save_path=src.settings.LOG_DIR,
        max_epochs=hparams.epochs
    )
    trainer.fit(model)
class MemoryTest(pl.LightningModule):
    # Main testing unit for experiments on recurrent cells
    def __init__(self, hp):
        super(MemoryTest, self).__init__()
        self.predict_col = hp.predict_col
        self.n_datasamples = hp.n_datasamples
        self.dataset = hp.dataset
        if self.dataset == 'rand':  # compare strings with ==, not "is"
            self.seq_len = None
        else:
            self.seq_len = hp.seq_len
        self.hparams = hp
        self.learning_rate = hp.learning_rate
        self.training_losses = []
        self.final_loss = None
        self.model = RecurrentModel(1, hp.n_cells, hp.n_layers, celltype=hp.celltype)

    def forward(self, input, input_len):
        return self.model(input, input_len)

    def training_step(self, batch, batch_idx):
        x, y, input_len = batch
        features_y = self.forward(x, input_len)
        loss = F.mse_loss(features_y, y)
        mean_absolute_loss = F.l1_loss(features_y, y)
        self.training_losses.append(mean_absolute_loss.item())
        neptune_logs = {'batch/train_loss': loss, 'batch/mean_absolute_loss': mean_absolute_loss}
        return {'loss': loss, 'batch/mean_absolute_loss': mean_absolute_loss, 'log': neptune_logs}

    def on_epoch_end(self):
        train_loss_mean = np.mean(self.training_losses)
        self.final_loss = train_loss_mean
        self.training_losses = []  # reset for next epoch

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    @pl.data_loader
    def train_dataloader(self):
        train_dataset = dg.RandomDataset(self.predict_col, self.n_datasamples)
        if self.dataset == 'rand_fix':
            train_dataset = dg.RandomDatasetFix(self.predict_col, self.n_datasamples, self.seq_len)
        if self.dataset == 'correlated':
            train_dataset = dg.CorrelatedDataset(self.predict_col, self.n_datasamples)
        train_loader = DataLoader(dataset=train_dataset, batch_size=1)
        return train_loader
    @staticmethod
    def add_model_specific_args(parent_parser):
        # model specific
        model_parser = ArgumentParser(parents=[parent_parser])
        model_parser.add_argument('--learning_rate', default=1e-2, type=float)
        model_parser.add_argument('--n_layers', default=1, type=int)
        model_parser.add_argument('--n_cells', default=5, type=int)
        model_parser.add_argument('--celltype', default='LSTM', type=str)
        # training specific (for this model)
        model_parser.add_argument('--epochs', default=500, type=int)
        model_parser.add_argument('--patience', default=5, type=int)
        model_parser.add_argument('--min_delta', default=0.1, type=float)
        model_parser.add_argument('--early_stopping', default='yes', type=str)
        # data specific
        model_parser.add_argument('--n_datasamples', default=1000, type=int)
        model_parser.add_argument('--seq_len', default=10, type=int)
        model_parser.add_argument('--dataset', default='rand', type=str)
        model_parser.add_argument('--predict_col', default=1, type=int)
        return model_parser
Expected behavior: early stopping takes effect again.
@Dunrar would you check it against the actual master branch?
@Borda do you mean the bleeding edge version via pip install git+git://github.com/PyTorchLightning/pytorch-lightning.git?
Okay, I tried that but early stopping still does not work
The code sample you provided does not define a validation step/end/dataloader.
I would expect that early stopping does not work without one. How could it?
If no val step is present, it uses the training step for early stopping.
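For reference, wiring a validation loop into the module above would look roughly like this (a sketch only; the hook names follow the 0.7.x API, the dataset is reused from the snippet, and the monitor key would then be val_loss rather than batch/mean_absolute_loss):

    # Methods to add inside MemoryTest (sketch, 0.7.x-style hooks assumed).
    def validation_step(self, batch, batch_idx):
        x, y, input_len = batch
        features_y = self.forward(x, input_len)
        return {'val_loss': F.l1_loss(features_y, y)}

    def validation_epoch_end(self, outputs):
        # average the per-batch validation losses over the epoch
        val_loss_mean = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': val_loss_mean, 'log': {'val/mean_absolute_loss': val_loss_mean}}

    def val_dataloader(self):
        # reuses the same random dataset purely for illustration
        return DataLoader(dataset=dg.RandomDataset(self.predict_col, self.n_datasamples), batch_size=1)

Note that the Trainer in the repro also passes val_percent_check=0, which disables validation entirely, so that would need a non-zero value for a loop like this to run at all.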
oh, my bad! Then I will have a closer look at this issue.
@awaelchli A little update: in training_loop.py, the line "if self.enable_early_stop and not self.disable_validation and is_val_epoch:" is to blame. Just deleting the "not self.disable_validation" and "is_val_epoch" checks solves the problem in my case, but there is probably more to take into consideration.
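That condition means early stopping is only evaluated on epochs where validation actually ran, so with validation disabled (no val loop and val_percent_check=0) it is never checked. Roughly, the change being described is (a sketch of the edit, not an exact patch):

    # in training_loop.py (0.7.x), the condition guarding the early-stop check
    # before: only evaluated when a validation epoch ran
    if self.enable_early_stop and not self.disable_validation and is_val_epoch:
        ...
    # suggested change from this thread: drop the validation-related checks
    if self.enable_early_stop:
        ...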
I also got to that point when I looked at it two days ago; I will have more time to look at it soon. If I remember correctly, the tests didn't pass, and I was tracking down at which point the change was introduced to figure out why it is there.
@Dunrar Thanks for the help. Your suggestion worked and I was able to make a test so that it doesn't break in the future :)
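For anyone hitting this later, the kind of regression test that covers it looks roughly like the following (a hypothetical sketch, not the actual test added to the repo): train a module that defines no validation loop, keep the loss constant, and assert that the run stops well before max_epochs.

    # Hypothetical regression-test sketch (0.7.x-era API assumed): early stopping
    # must trigger even when the LightningModule has no validation loop.
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping

    class NoValModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.mse_loss(self(x), y)
            return {'loss': loss, 'log': {'loss': loss}}

        def configure_optimizers(self):
            # lr=0 keeps the loss constant, so early stopping has to kick in
            return torch.optim.SGD(self.parameters(), lr=0.0)

        def train_dataloader(self):
            x, y = torch.randn(32, 8), torch.randn(32, 1)
            return DataLoader(TensorDataset(x, y), batch_size=4)

    def test_early_stopping_without_val_loop():
        model = NoValModel()
        stopping = EarlyStopping(monitor='loss', patience=2, min_delta=0.0, mode='min')
        trainer = pl.Trainer(max_epochs=50, early_stop_callback=stopping, val_percent_check=0)
        trainer.fit(model)
        # with a constant loss, training should stop long before 50 epochs
        assert trainer.current_epoch < 49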
cheers!
@awaelchli Thank you!