As discussed on Slack, it could be cool to implement this. More detail below.
To avoid the user having to do this:

```python
logits = torch.cat([x['logits'] for x in output])
labels = torch.cat([x['labels'] for x in output])
# and so on ...
```
Something like this:

```python
import torch

def collate_metrics(self, output):
    """Collate the output from several validation steps."""
    collated_output = {}
    keys = output[0].keys()
    for key in keys:
        tensor_dim = output[0][key].dim()
        if tensor_dim > 0:
            # Non-scalar tensors: concatenate along the batch dimension
            collated_output[key] = torch.cat([x[key] for x in output])
        else:
            # Scalars (0-dim tensors): reduce by mean
            collated_output[key] = torch.stack([x[key] for x in output]).mean()
    return collated_output
```
You could just add the above to the LightningModule and use it anywhere.
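For instance, a minimal sketch of how it could be wired into `validation_epoch_end` (the dict keys here are just illustrative):

```python
def validation_epoch_end(self, outputs):
    # `outputs` is the list of dicts returned by validation_step
    collated = self.collate_metrics(outputs)
    logits = collated['logits']      # concatenated across all validation steps
    val_loss = collated['val_loss']  # 0-dim losses reduced by mean
    return {'val_loss': val_loss}
```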
I think this is cool. Things that come to my mind:

- I think that the `collate_metrics` function would need at least:
- The code will start to get quite messy soon.
- I quite like this approach: https://github.com/PyTorchLightning/pytorch-lightning/issues/973#issuecomment-592795508
I guess that if we add metrics as a class, as discussed in #973, we could define a custom reduction method for each metric, right?
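A minimal sketch of what that could look like, assuming a hypothetical `Metric` base class with a `reduce` hook (these names are illustrative, not the actual API from #973):

```python
import torch

class Metric:
    """Hypothetical base class: the collation step calls `reduce` on the
    per-step values, so each subclass can pick its own reduction."""
    def reduce(self, values):
        # Default: average the per-step scalar values
        return torch.stack(values).mean()

class TotalSamples(Metric):
    def reduce(self, values):
        # Custom reduction: a count should be summed, not averaged
        return torch.stack(values).sum()
```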
cc: @justusschock @SkafteNicki
@Borda I think we will integrate this into a plan for automated metric calculation that also supports a different collation per metric.
@justusschock can we close this issue? Was it fixed?
@edenlightning not yet. We haven't gotten to implementing accumulation for metrics yet. That will be v2 of metrics.
With PR #3245 merged, this is now solved. Each metric now has an `aggregated` property that contains the aggregated metric value over the data seen so far. In practice you can use it like this in Lightning:
```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    ypred = self(x)
    loss = self.loss_fn(ypred, y)
    # Calling the metric also updates its internal aggregated state
    val = self.metric(ypred, y)
    return loss  # no need to return the value of the metric

def validation_epoch_end(self, validation_step_outputs):
    # The metric has accumulated results over all validation steps
    aggregated_metric = self.metric.aggregated
    return aggregated_metric
```
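If you also want the value in the logs, you could log it there as well (a sketch; assumes a Lightning version that supports `self.log`, and `val_metric` is just an illustrative key):

```python
def validation_epoch_end(self, validation_step_outputs):
    # Log the metric accumulated over the whole validation epoch
    self.log('val_metric', self.metric.aggregated)
```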
Closing this issue.