Pytorch-lightning: Mismatch between on_validation_epoch_end and on_train_epoch_end

Created on 4 Oct 2020 · 2 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

The behaviour of the callback hooks on_train_epoch_end and on_validation_epoch_end does not match. While on_train_epoch_end can access metrics logged in validation_epoch_end for the same epoch, the opposite is not true. More concretely, when accessing trainer.callback_metrics from on_validation_epoch_end, the values logged in training_epoch_end correspond to the previous epoch rather than the current one.

To Reproduce

On master:

import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
        return {"val_loss": loss}

    def training_epoch_end(self, outputs):
        loss_val = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss_epoch', loss_val)

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('val_loss_epoch', loss_val)

class CustomCallback(Callback):

    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch
        print(f"Epoch {epoch}: {metrics}")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Val_epoch_end: {trainer.callback_metrics}")

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Train_epoch_end: {trainer.callback_metrics}")
        print("\n")


model = LitModel()

dataset = MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
train, val = random_split(dataset, [9000, 1000])

train_loader = DataLoader(train, batch_size=32)
val_loader = DataLoader(val, batch_size=32)

# train
trainer = pl.Trainer(gpus=1, progress_bar_refresh_rate=0, max_epochs=10, 
                     num_sanity_val_steps=0, callbacks=[CustomCallback()])
trainer.fit(model, train_loader, val_loader)

Observed behaviour

Val_epoch_end: {'loss': tensor(0.3406, device='cuda:0'), 'val_loss': tensor(0.1225, device='cuda:0'), 'val_loss_epoch': tensor(0.2968, device='cuda:0')}
Epoch 0: {'loss': tensor(0.3406, device='cuda:0'), 'val_loss': tensor(0.1225, device='cuda:0'), 'val_loss_epoch': tensor(0.2968, device='cuda:0'), 'train_loss_epoch': tensor(0.6143, device='cuda:0')}
Train_epoch_end: {'loss': tensor(0.3406, device='cuda:0'), 'val_loss': tensor(0.1225, device='cuda:0'), 'val_loss_epoch': tensor(0.2968, device='cuda:0'), 'train_loss_epoch': tensor(0.6143, device='cuda:0')}


Val_epoch_end: {'loss': tensor(0.2575, device='cuda:0'), 'val_loss': tensor(0.0525, device='cuda:0'), 'val_loss_epoch': tensor(0.2291, device='cuda:0'), 'train_loss_epoch': tensor(0.6143, device='cuda:0')}
Epoch 1: {'loss': tensor(0.2575, device='cuda:0'), 'val_loss': tensor(0.0525, device='cuda:0'), 'val_loss_epoch': tensor(0.2291, device='cuda:0'), 'train_loss_epoch': tensor(0.2746, device='cuda:0')}
Train_epoch_end: {'loss': tensor(0.2575, device='cuda:0'), 'val_loss': tensor(0.0525, device='cuda:0'), 'val_loss_epoch': tensor(0.2291, device='cuda:0'), 'train_loss_epoch': tensor(0.2746, device='cuda:0')}


Val_epoch_end: {'loss': tensor(0.1669, device='cuda:0'), 'val_loss': tensor(0.0316, device='cuda:0'), 'val_loss_epoch': tensor(0.1905, device='cuda:0'), 'train_loss_epoch': tensor(0.2746, device='cuda:0')}
Epoch 2: {'loss': tensor(0.1669, device='cuda:0'), 'val_loss': tensor(0.0316, device='cuda:0'), 'val_loss_epoch': tensor(0.1905, device='cuda:0'), 'train_loss_epoch': tensor(0.2091, device='cuda:0')}
Train_epoch_end: {'loss': tensor(0.1669, device='cuda:0'), 'val_loss': tensor(0.0316, device='cuda:0'), 'val_loss_epoch': tensor(0.1905, device='cuda:0'), 'train_loss_epoch': tensor(0.2091, device='cuda:0')}

Expected behavior

Though I imagine there are not many situations where this has an impact (accessing training metrics from the on_validation_epoch_end hook would be unusual), I find the behaviour somewhat confusing. I suspect it is due to the order in which the hooks are executed. Is this intended behaviour, or should the two hooks behave consistently?

bug / fix help wanted


All 2 comments

From #2816, working out when training_epoch_end and validation_epoch_end are called explains the behaviour: the on_validation_epoch_end hook runs before training_epoch_end, so the aggregates from the current training epoch are not yet available and callback_metrics still holds those of the previous epoch. Not sure whether this warrants a fix, as it is not necessarily a bug, though it is somewhat counterintuitive at first glance. A workaround is sketched below.
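For anyone who needs both epoch-level aggregates for the same epoch: reading them from on_train_epoch_end, which fires after training_epoch_end has logged (per the output above), avoids the off-by-one. A minimal sketch reusing the metric keys from the example above; the EpochSummaryCallback name is purely illustrative:

from pytorch_lightning.callbacks import Callback


class EpochSummaryCallback(Callback):
    """Reads both epoch-level aggregates once, after training_epoch_end
    and validation_epoch_end have both logged for the current epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        # Both keys now refer to the epoch that just finished.
        train_loss = metrics.get("train_loss_epoch")
        val_loss = metrics.get("val_loss_epoch")
        print(f"epoch {trainer.current_epoch}: train={train_loss}, val={val_loss}")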

Not every epoch takes 2 minutes :)

In many lines of research one needs to run validation multiple times within an epoch.

For example:
Train --------------------------------------
Valid -C--------------C-----------------C-

This comes up with things like BERT + NLP, ImageNet, and other huge datasets where an epoch might take days.

This behavior is expected.

Second, callback_metrics is not meant to be accessed directly (though it can be). Instead, logged_metrics and prog_bar_metrics are there in case people want the metrics.
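For reference, those collections are exposed on the Trainer, so a callback can read them instead of callback_metrics. A minimal sketch, assuming the attribute names trainer.logged_metrics and trainer.progress_bar_metrics as exposed around Lightning 1.0 (verify against your version):

from pytorch_lightning.callbacks import Callback


class LoggedMetricsCallback(Callback):
    """Prints the metrics recorded via self.log(...) instead of the
    internal callback_metrics dict."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Everything logged with self.log(..., logger=True), keyed by name.
        print(f"logged:   {trainer.logged_metrics}")
        # Only the entries logged with self.log(..., prog_bar=True).
        print(f"prog_bar: {trainer.progress_bar_metrics}")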

