Pytorch-lightning: Problem to handle val_loss from validation_step

Created on 28 Feb 2020  ·  6 comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug

The current master branch does not handle a dict result from validation_step.

Code sample

With such function:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        return {'val_loss': self.criterion(y_pred, y.view(-1, 1))}

I got this error:

    RuntimeError: Early stopping conditioned on metric `val_loss` which is not available. Available metrics are:

If I change this function to the next one:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        return self.criterion(y_pred, y.view(-1, 1))

Then early stopping seems to start working.

But in both cases the ModelCheckpoint callback does not work and shows a warning:

    RuntimeWarning: Can save best model only with val_loss available, skipping.
      ' skipping.', RuntimeWarning)
Labels: help wanted, question

All 6 comments

Hi! thanks for your contribution!, great first issue!

This is not a bug.
Solution: You should implement validation_end and return a dict containing 'val_loss' there.

EarlyStopping and ModelCheckpoint do not monitor the metrics returned by validation_step. What they need is the error over the whole validation set, which you compute by collecting the per-batch metrics from validation_step and reducing them in validation_end.
https://pytorch-lightning.readthedocs.io/en/0.6.0/lightning-module.html
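As a rough illustration of why the error appears (hypothetical code, not Lightning's actual implementation): the callback looks its monitored key up in the metrics produced by validation_end, and a missing key raises exactly the RuntimeError from the report.

```python
def lookup_monitored_metric(monitor, callback_metrics):
    # Hypothetical sketch of the lookup EarlyStopping performs; the real
    # Lightning code differs, but the failure mode is the same: if the
    # monitored key is absent from the metrics dict, it raises.
    if monitor not in callback_metrics:
        raise RuntimeError(
            f'Early stopping conditioned on metric `{monitor}` which is not '
            f'available. Available metrics are: {",".join(callback_metrics)}'
        )
    return callback_metrics[monitor]
```

Since validation_step results are never placed in that dict, only values returned from validation_end can be monitored.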

See also this basic example:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/lightning_module_template.py
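A minimal sketch of such a validation_end, assuming validation_step returns {'val_loss': ...} per batch as in the report (written as a plain function here for clarity; in practice it is a method on the LightningModule):

```python
import torch

def validation_end(outputs):
    # outputs is a list of the dicts returned by validation_step,
    # one entry per validation batch; reduce them to a single value
    avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    return {'val_loss': avg_loss}
```

The returned 'val_loss' is then visible to EarlyStopping and ModelCheckpoint.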

@awaelchli Thank you for the explanation.

Can you please also help me understand the logger? Will it automatically log all metrics from training_end and validation_end, or do I need to log them manually in these functions?

See here: https://pytorch-lightning.readthedocs.io/en/latest/experiment_reporting.html
In the output dict, add another entry called "log" whose value is a dict of all the metrics you would like to log. Example:

    def validation_end(self, outputs):
        loss = some_loss()
        ...

        logs = {'val_loss': loss}
        output = {
            'val_loss': loss,  # for early stopping, model checkpoint
            'log': logs,  # will be consumed by the logger
        }
        return output

@Borda could you remove the bug label?

@awaelchli Thank you for the help.
I am closing the issue as this is really not a bug.

