It would be useful if the current metrics like Accuracy/Recall supported a mask argument.
For example, when I work on a sequence labeling task and pad sequences to a maximum length, I do not want to compute metrics at the padding positions.
I guess a simple modification would work for accuracy (here is my adaptation of the original):
```python
from typing import Any, Optional

import torch

from pytorch_lightning.metrics.functional.classification import accuracy
from pytorch_lightning.metrics.metric import TensorMetric


class MaskedAccuracy(TensorMetric):
    """
    Computes the accuracy classification score, ignoring masked-out positions.

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> mask = torch.tensor([1, 1, 1, 0])
        >>> metric = MaskedAccuracy(num_classes=4)
        >>> metric(pred, target, mask)
        tensor(1.)
    """

    def __init__(
        self,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
        reduce_group: Any = None,
        reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='accuracy',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            mask: only calculate metrics where mask == 1

        Return:
            A Tensor with the classification score.
        """
        mask_fill = mask == 0
        # Fill padding positions with the dummy class -1, which never matches
        # any real class in 0..num_classes-1. Use the out-of-place masked_fill
        # so the caller's tensors are not mutated.
        pred = pred.masked_fill(mask=mask_fill, value=-1)
        target = target.masked_fill(mask=mask_fill, value=-1)
        return accuracy(pred=pred, target=target,
                        num_classes=self.num_classes, reduction=self.reduction)
```
Looks nice! @YuxianMeng want to send it as a PR?
Cc: @justusschock @SkafteNicki
My pleasure :) A small question: should this PR contain only the masked accuracy metric, or other masked metrics as well?
I would say all. In fact, it would be nice to have an abstract function/class that does the masking, so that the new metrics are created just by applying it, for example:
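Something along these lines, as a minimal hypothetical sketch (the `masked` helper is an assumption, not an existing Lightning API, and the functional metrics are assumed to take `pred`/`target` as their first two arguments):

```python
import functools

def masked(metric_fn):
    """Turn a functional metric into one that skips positions where mask == 0."""
    @functools.wraps(metric_fn)
    def wrapper(pred, target, mask, **kwargs):
        keep = mask.bool()
        # Drop the masked-out positions before calling the wrapped metric.
        return metric_fn(pred[keep], target[keep], **kwargs)
    return wrapper

# masked_accuracy = masked(accuracy)
# masked_accuracy(pred, target, mask, num_classes=4)
```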
Does it make sense? @justusschock @SkafteNicki
@YuxianMeng But with your implementation, you calculate it also for the values you set to -1 I think.
What you instead need to do is accuracy(pred[mask], target[mask]) which is why I wouldn't add extras for them to be honest. We can't include every special case here and masking tensors is not much overhead, which is why I'd prefer not to include this into the metrics package. Thoughts @SkafteNicki ?
@justusschock As for accuracy, actually only the non-negative classes are counted, so the positions filled with -1 do not contribute. I thought about using accuracy(pred[mask], target[mask]), but it may cause speed problems when training on TPU.
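To illustrate the concern (a rough sketch, under the assumption that XLA recompiles for every new tensor shape): boolean indexing produces a data-dependent output shape, while masked_fill keeps the shape static:

```python
import torch

pred = torch.tensor([0, 1, 2, 3])
mask = torch.tensor([1, 1, 1, 0])

# Static shape: output is always (4,), regardless of the mask contents.
filled = pred.masked_fill(mask == 0, -1)   # tensor([ 0,  1,  2, -1])

# Data-dependent shape: output is (3,) here but changes with the mask,
# which can force recompilation under XLA/TPU.
indexed = pred[mask.bool()]                # tensor([0, 1, 2])
```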
I agree with @Borda that this should be an abstract function/class. The simplest option, in my opinion, would be a class that the user can wrap their already existing metric with: masked_accuracy = MaskedMetric(Accuracy()). This would add an additional argument to the call: value = masked_accuracy(pred, target, mask). The alternative, re-writing each metric to include this feature, is not feasible at the moment.
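A minimal sketch of what such a wrapper could look like (hypothetical: MaskedMetric is not part of the metrics package, and DDP syncing is ignored for brevity):

```python
import torch

class MaskedMetric:
    """Wraps an existing metric so positions where mask == 0 are ignored."""

    def __init__(self, metric):
        self.metric = metric

    def __call__(self, pred: torch.Tensor, target: torch.Tensor,
                 mask: torch.Tensor) -> torch.Tensor:
        keep = mask.bool()
        # Boolean indexing removes the padded positions before the
        # wrapped metric ever sees them.
        return self.metric(pred[keep], target[keep])

# masked_accuracy = MaskedMetric(Accuracy())
# value = masked_accuracy(pred, target, mask)
```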
@YuxianMeng mind sending a PR? I guess @SkafteNicki or @justusschock could help/guide you through it.
I think I speak for both of us when I say that we'd for sure do that and really appreciate the PR :)
Yes just ping us in the PR when you are ready, and we will assist you.
Working on it, I will let you know when I'm ready :)
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
This issue has been closed. Has the masked metrics feature landed, or has nobody worked on it yet?
It was closed due to no activity, so it is still not a part of lightning. @hadim please feel free to pick it up and send a PR :]