Pytorch-lightning: Metrics are not reset when using self.log()

Created on 22 Nov 2020 · 17 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

Metrics are not reset when using self.log() unless the user explicitly calls self.metric.compute().
See the MSE metric example in the Colab notebook linked below.
Printing the internal states at epoch end shows that the metric states are not reset. Calling self.metric.compute() explicitly resolves the issue (uncomment the line in training_epoch_end in the linked Colab).

Note: I didn't test whether the reset occurs when logging on_step.

https://colab.research.google.com/drive/10cTIhVkxdgKZ23WwAHiCSAlU7K1a3NRr?usp=sharing
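To make the symptom concrete, here is a minimal pure-Python sketch of an accumulating metric. `ToyMSE` and `run` are hypothetical stand-ins, not Lightning's `MeanSquaredError`: they only mirror the two states printed in the report (`sum_squared_error` and `total`) and, as described later in this thread for this version, reset the state when `compute()` is called.

```python
class ToyMSE:
    """Hypothetical accumulating MSE metric (not the Lightning class)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum_squared_error = 0.0
        self.total = 0

    def update(self, preds, target):
        # Accumulate state across batches; nothing is cleared here.
        for p, t in zip(preds, target):
            self.sum_squared_error += (p - t) ** 2
            self.total += 1

    def compute(self):
        # Assumption mirroring the thread: computing the epoch value
        # also resets the accumulated state.
        value = self.sum_squared_error / self.total
        self.reset()
        return value


def run(epochs, batches, call_compute):
    """Return the metric's `total` as seen at each epoch end."""
    metric = ToyMSE()
    totals = []
    for _ in range(epochs):
        for _ in range(batches):
            metric.update([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0])
        totals.append(metric.total)
        if call_compute:
            metric.compute()
    return totals


print(run(3, 5, call_compute=False))  # [20, 40, 60]: state keeps growing
print(run(3, 5, call_compute=True))   # [20, 20, 20]: reset every epoch
```

With 5 batches of 4 samples per epoch, a correctly reset metric always reports 20 samples at epoch end; the growing totals are the pattern reported above.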

Expected behavior

According to the documentation, the metric should be reset even when using self.log() only.

Labels: Metrics · help wanted · tests / CI · with code

All 17 comments

@itsikad Thanks. Could you please share the colab notebook? I'd like to take a look. Thanks.

@junwen-austin
It should be available now. Thanks!

There is probably no good way to automate this. The metric returns a tensor, so the log function has no way of telling that it should call compute at the end of the epoch (the metrics devs will correct me if I'm wrong haha). The docs here
https://pytorch-lightning.readthedocs.io/en/stable/metrics.html
show an example of how to log both on step and on epoch end, resetting at epoch end through a call to compute().

@awaelchli Thanks for your fast response.
The docs show two equivalent approaches:

  1. Call the metric's .forward() in <mode>_step() and explicitly call .compute() in <mode>_<step/epoch>_end().
  2. Use self.log('metric name', metric, on_step=True/False, on_epoch=True/False), which calls .compute() internally, and so calling .compute() explicitly in '_
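The difference between logging the returned value and logging the metric object can be sketched with a toy logger. `ToyMean`, `ToyLogger`, its `log`/`epoch_end` methods and the duck-typed `hasattr(value, "compute")` check are all hypothetical, not Lightning's implementation; they only illustrate the intended design: when `log` receives the metric object it keeps a handle and can call `compute()` at epoch end (which, as described above, also resets the state), whereas a value returned by the metric's forward is just a number with no link back to that state.

```python
class ToyMean:
    """Hypothetical accumulating mean metric."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def __call__(self, value):
        # Like a metric forward: update state and return the batch value.
        self.total += value
        self.count += 1
        return value

    def compute(self):
        result = self.total / self.count
        # Reset on compute, as described in the thread.
        self.total, self.count = 0.0, 0
        return result


class ToyLogger:
    """Hypothetical logger that only knows what it is handed."""

    def __init__(self):
        self.metric_objects = {}
        self.plain_values = {}
        self.epoch_logs = {}

    def log(self, name, value):
        if hasattr(value, "compute"):
            # A metric object: keep a handle so compute() can run at epoch end.
            self.metric_objects[name] = value
        else:
            # A plain number: there is no way to reach the metric's state.
            self.plain_values.setdefault(name, []).append(value)

    def epoch_end(self):
        for name, metric in self.metric_objects.items():
            self.epoch_logs[name] = metric.compute()
        for name, values in self.plain_values.items():
            self.epoch_logs[name] = sum(values) / len(values)
        self.plain_values.clear()


by_value, by_object = ToyMean(), ToyMean()
logger = ToyLogger()
for x in [1.0, 2.0, 3.0]:
    logger.log("by_value", by_value(x))  # logs the returned number
    by_object(x)
    logger.log("by_object", by_object)   # logs the metric itself
logger.epoch_end()

print(logger.epoch_logs)                # {'by_object': 2.0, 'by_value': 2.0}
print(by_value.count, by_object.count)  # 3 0 -> only by_object was reset
```

Both paths report the same epoch value, but only the metric passed as an object gets reset; the bug in this issue is that, on 1.0.7, even the object path did not reset between epochs.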

Note: according to the documentation, in DP the metric .forward() call should be within <mode>_step_end(); however, this issue also occurs in DDP.

Oh ^^ You are doing the correct thing then. I did not know you could pass the metric object into the log method :)

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics.regression import MeanSquaredError

tmpdir = os.getcwd()


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = MeanSquaredError()

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        # Update the metric state with this batch
        self.metric(output, torch.ones_like(output))
        # Log the metric object itself; it should be computed and reset at epoch end
        self.log('mse', self.metric, on_step=False, on_epoch=True)
        print(f'epoch: {self.current_epoch}, batch_idx: {batch_idx}, mse: {loss}')
        return {"loss": loss, "batch_mse": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        avg_mse = torch.stack([x["loss"] for x in outputs]).mean()
        # Uncommenting the explicit compute() call below works around the bug
        # print(f'Epoch {self.current_epoch} end,  metric: {self.metric.compute()}, mean: {avg_mse}')
        # Inspect the internal states: on 1.0.7 they keep growing from epoch to epoch
        print(f'Sum squared error: {self.metric.sum_squared_error}, Total samples: {self.metric.total}')

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


def main():
    num_samples = 10000

    train = RandomDataset(32, num_samples)
    train = DataLoader(train, batch_size=4)

    val = RandomDataset(32, num_samples)
    val = DataLoader(val, batch_size=4)

    test = RandomDataset(32, num_samples)
    test = DataLoader(test, batch_size=4)

    # init model
    model = BoringModel()

    # Initialize a trainer (gpus=1 requires a GPU)
    trainer = pl.Trainer(
        max_epochs=3,
        gpus=1,
        progress_bar_refresh_rate=0,
        limit_train_batches=5,
        limit_val_batches=5,
        limit_test_batches=0,
    )

    trainer.fit(model, train, val)

    # trainer.test(test_dataloaders=test)


if __name__ == "__main__":
    main()
```

A copy of the Google Colab notebook, for reproducing locally.

Installing from the latest master, I see total correctly reset to 0, while on 1.0.7 I see total increasing from epoch to epoch. Can you confirm?

Use
! pip install --upgrade git+https://github.com/PyTorchLightning/pytorch-lightning@master
in the notebook, then reload the instance and rerun the cells.

Thanks @awaelchli, I was going crazy yesterday trying to figure out why I could sometimes reproduce it and sometimes not.
Since we have not made any changes to metrics since 1.0.7, do you have any idea what could have fixed it?

@awaelchli Indeed, both total and sum_squared_error are reset between epochs.

@teddykoker @SkafteNicki don't we have a test for something like this? I thought we were testing for this.

This is the test we have; I guess it doesn't necessarily cover this? https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/metrics/test_metric_lightning.py#L55-L88

@teddykoker Agreed, it only runs for 1 epoch.
We could probably change it so that we actually test that the metric is being reset.

Here is the code that resets the metric: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/step_result.py#L318-L324. It should get called every epoch.

It must have been fixed in one of these self.log PRs #3813 that @tchaton worked on.

Just checked: if we change to epoch=2, the test fails on 1.0.7 but not on master.
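Why a single-epoch test cannot catch this: the first epoch starts from freshly initialized state whether or not the reset runs, so a missing reset only shows up from the second epoch on. Below is a sketch of the check, using a hypothetical `fit` loop whose `reset_each_epoch` switch stands in for the 1.0.7 vs master behaviour; `CountingMetric`, `fit` and `metric_is_reset` are illustrative names, not the actual test in `tests/metrics/test_metric_lightning.py`.

```python
class CountingMetric:
    """Hypothetical metric that just counts the samples it has seen."""

    def __init__(self):
        self.total = 0

    def update(self, n):
        self.total += n

    def reset(self):
        self.total = 0


def fit(max_epochs, batches, batch_size, reset_each_epoch):
    """Toy training loop; returns the metric's `total` at each epoch end."""
    metric = CountingMetric()
    seen = []
    for _ in range(max_epochs):
        for _ in range(batches):
            metric.update(batch_size)
        seen.append(metric.total)
        if reset_each_epoch:
            metric.reset()
    return seen


def metric_is_reset(max_epochs, reset_each_epoch):
    # Every epoch should report only that epoch's samples (5 batches x 4).
    totals = fit(max_epochs, 5, 4, reset_each_epoch)
    return all(t == 5 * 4 for t in totals)


# With one epoch, the buggy and fixed loops are indistinguishable...
print(metric_is_reset(1, False), metric_is_reset(1, True))  # True True
# ...with two epochs, the missing reset is caught.
print(metric_is_reset(2, False), metric_is_reset(2, True))  # False True
```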

Awesome, this should cover us for the future :)

TODO: just need to update the test.
