Pytorch-lightning: [metrics] Automatic reduction of metrics from several validation steps

Created on 26 Mar 2020  ·  8 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🚀 Feature

As discussed on Slack, it would be useful to implement this. More detail below.

Motivation

To avoid the user having to write boilerplate like this (note that `torch.cat` expects a sequence of tensors, so a list comprehension is needed rather than a bare generator):

    logits = torch.cat([x['logits'] for x in output])
    labels = torch.cat([x['labels'] for x in output])
    # and so on ...

Pitch

Something like this:

    def collate_metrics(self, output):
        """
        Collate the output from several validation steps.

        Tensors with at least one dimension are concatenated along dim 0;
        scalar (0-dim) tensors are reduced by their mean.
        """
        collated_output = {}
        keys = output[0].keys()
        for key in keys:
            tensor_dim = output[0][key].dim()
            if tensor_dim > 0:
                collated_output[key] = torch.cat([x[key] for x in output])
            else:
                # Reduce scalars by mean (torch.stack handles a list of
                # 0-dim tensors cleanly, unlike torch.tensor)
                collated_output[key] = torch.stack([x[key] for x in output]).mean()
        return collated_output
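As a sketch, the method above can be exercised standalone on dummy step outputs (the shapes and key names here are purely illustrative):

```python
import torch


def collate_metrics(output):
    """Collate a list of per-step dicts into a single dict.

    Tensors with at least one dimension are concatenated along dim 0;
    scalar (0-dim) tensors are reduced by their mean.
    """
    collated_output = {}
    for key in output[0].keys():
        if output[0][key].dim() > 0:
            collated_output[key] = torch.cat([x[key] for x in output])
        else:
            collated_output[key] = torch.stack([x[key] for x in output]).mean()
    return collated_output


# Two fake validation steps, each with a batch of logits and a scalar loss.
steps = [
    {"logits": torch.zeros(4, 10), "loss": torch.tensor(0.5)},
    {"logits": torch.ones(4, 10), "loss": torch.tensor(1.5)},
]
result = collate_metrics(steps)
print(tuple(result["logits"].shape))  # logits concatenated along the batch dimension
print(result["loss"].item())          # scalar losses averaged
```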

Alternatives

Alternatively, users can just add the above method to their LightningModule and use it directly.

Labels: Metrics, discussion, enhancement, help wanted

All 8 comments

I think this is cool. Things that come to my mind:

  • concatenate or stack seems reasonable as a default collate, but I would leave out the mean - it's too specific.
  • let the user override the collate method
  • if there is a collate, it should apply not only to validation but also to training_end and test_end, right? Then the question is whether we let the user override each of these.
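A minimal sketch of the "let the user override the collate method" idea, assuming a base class that defines a default hook which subclasses can replace (the class and method names here are hypothetical, not actual Lightning API):

```python
import torch


class BaseModule:
    """Illustrative stand-in for a module with an overridable collate hook."""

    def collate_metrics(self, outputs):
        # Default: concatenate non-scalar tensors, stack scalars (no mean,
        # as suggested above, since a mean reduction is too specific).
        collated = {}
        for key in outputs[0]:
            values = [x[key] for x in outputs]
            collated[key] = torch.cat(values) if values[0].dim() > 0 else torch.stack(values)
        return collated


class MyModule(BaseModule):
    def collate_metrics(self, outputs):
        # User override: additionally reduce stacked scalars by mean.
        collated = super().collate_metrics(outputs)
        return {
            k: v.mean() if outputs[0][k].dim() == 0 else v
            for k, v in collated.items()
        }


outputs = [{"loss": torch.tensor(0.2)}, {"loss": torch.tensor(0.4)}]
print(MyModule().collate_metrics(outputs)["loss"].item())
```

The same hook could then be invoked from each of the epoch-end methods, so the override applies to validation, test, and training alike.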

I think that a collate_metrics function would need at least:

  • filtering - to control which keys it is applied to
  • to be usable in several places - end of epoch, end of training, etc.

The code could get quite messy quickly.
I quite like this approach https://github.com/PyTorchLightning/pytorch-lightning/issues/973#issuecomment-592795508

I guess that if we add metrics as classes, as discussed in #973, we could define a custom reduction method for each one, right?

cc: @justusschock @SkafteNicki

@Borda I think we will integrate this into an automated metric-calculation plan that also allows a different collation per metric.

@justusschock can we close this issue? was it fixed?

@edenlightning not yet. We haven't gotten to implementing accumulation for metrics yet. That will be v2 of metrics.

With PR #3245 merged, this is now solved. Each metric now has an `aggregated` property that contains the aggregated metric value over the data seen so far. In practice you can use it like this in Lightning:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        ypred = self(x)
        loss = self.loss_fn(ypred, y)
        val = self.metric(ypred, y)  # updates the metric's internal state
        return loss  # no need to return the value of the metric

    def validation_epoch_end(self, validation_step_outputs):
        aggregated_metric = self.metric.aggregated
        return aggregated_metric
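To illustrate the accumulation behaviour behind an `aggregated` property, here is a tiny self-contained stand-in (this is not the actual pytorch-lightning `Metric` class; the class name and the running-mean strategy are purely illustrative):

```python
import torch


class RunningMeanMetric:
    """Toy metric that accumulates per-batch values and exposes their mean.

    Illustrative only: real Lightning metrics handle state reset,
    DDP synchronisation, and more.
    """

    def __init__(self):
        self.total = torch.tensor(0.0)
        self.count = 0

    def __call__(self, preds, target):
        value = (preds == target).float().mean()  # per-batch accuracy
        self.total = self.total + value
        self.count += 1
        return value  # per-batch value, as returned inside validation_step

    @property
    def aggregated(self):
        # Mean over all batches seen so far, read at epoch end.
        return self.total / max(self.count, 1)


metric = RunningMeanMetric()
metric(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))  # 2/3 correct
metric(torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1]))  # all correct
print(round(metric.aggregated.item(), 4))
```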

Closing this issue.

