Basic tests of binary precision seem to fail:
import torch
from sklearn.metrics import precision_score
from ignite.metrics import Precision

precision = Precision(average=True)
y_pred = torch.rand(10, 1)
y = torch.randint(0, 2, size=(10,)).type(torch.LongTensor)
precision.update((y_pred, y))
np_y = y.numpy().ravel()
np_y_pred = (y_pred.numpy().ravel() > 0.5).astype('int')
# The two values should match, but they do not
precision_score(np_y, np_y_pred), precision.compute()
@jasonkriss could you please confirm this?
EDIT: Another failing test: https://github.com/pytorch/ignite/pull/333#issuecomment-442643530
The error is probably in the binary-to-categorical mapping, which counts class 0 the same way as class 1. In the binary case we should ignore class 0.
A check for "binary or categorical" input is also missing for the case where the user mixes both across several updates.
@vfdev-5 this might be happening because we map binary into categorical thus creating precision for 2 classes and then average it.
If we treat binary as just one class and round it, we should get the same answer.
The reason we're getting different answers is that ignite treats it as 2 classes while sklearn treats it as 1 class.
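To make the discrepancy concrete, here is a small pure-Python sketch. It uses neither ignite nor sklearn, and the labels and the helper name `precision_for_class` are made up for illustration. It compares the precision of class 1 alone, which is what sklearn's binary `precision_score` reports by default, with the mean over classes 0 and 1, which is what averaging after the binary-to-categorical mapping would give.

```python
def precision_for_class(y_true, y_pred, cls):
    """Precision of one class: true positives / predicted positives."""
    predicted = sum(1 for p in y_pred if p == cls)
    correct = sum(1 for t, p in zip(y_true, y_pred) if p == cls and t == cls)
    return correct / predicted if predicted else 0.0


# Made-up labels, already thresholded to 0/1
y_true = [0, 1, 1, 0, 1]
y_pred = [1, 1, 0, 0, 1]

# Binary view: only class 1 counts
binary_precision = precision_for_class(y_true, y_pred, 1)

# Two-class view: class 0 is scored as well, then both are averaged
two_class_mean = (precision_for_class(y_true, y_pred, 0) + binary_precision) / 2

print(binary_precision, two_class_mean)
```

Here class 1 alone gives 2/3, while the two-class mean is pulled to 7/12 by the class-0 precision of 1/2, so the two numbers only agree by coincidence.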
Works on my local machine. Will try with all the binary tests.
EDIT: I just saw your edit lol
@anmolsjoshi I thought to add an attribute _is_binary like this:
class Precision(Metric):

    def __init__(self, output_transform=lambda x: x, is_multilabel=False, average=False, threshold_function=None):
        self._average = average
        if is_multilabel:
            if threshold_function is None:
                self._threshold = torch.round
            else:
                if callable(threshold_function):
                    self._threshold = threshold_function
                else:
                    raise ValueError("threshold_function must be a callable function.")
            if not self._average:
                warnings.warn('average should be True for multilabel cases. Precision._average updated'
                              ' to True. Average is calculated across samples, instead of classes.', UserWarning)
                self._average = True
            self.update = self._update_multilabel
        else:
            self.update = self._update_multiclass
        super(Precision, self).__init__(output_transform=output_transform)
        self._is_binary = None

    def reset(self):
        self._all_positives = None
        self._true_positives = None
        self._is_binary = None

    ...

    def _update_multiclass(self, output):
        y_pred, y = output
        dtype = y_pred.type()
        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
                             "shape of (batch_size, num_categories, ...) or (batch_size, ...).")
        if y.ndimension() == 1 or y.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y = y.squeeze(dim=1).view(-1) if (y.ndimension() > 1) else y.view(-1)
        if y_pred.ndimension() == 1 or y_pred.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y_pred = y_pred.squeeze(dim=1).view(-1) if (y_pred.ndimension() > 1) else y_pred.view(-1)
        y_shape = y.shape
        y_pred_shape = y_pred.shape
        if y.ndimension() + 1 == y_pred.ndimension():
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
        if not (y_shape == y_pred_shape):
            raise ValueError("y and y_pred must have compatible shapes.")
        if y_pred.ndimension() == y.ndimension():
            # Maps Binary Case to Categorical Case with 2 classes
            y_pred = y_pred.unsqueeze(dim=1)
            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
            if self._is_binary is None:
                self._is_binary = True
            elif not self._is_binary:
                raise ValueError("A binary y (shape={}) or y_pred (shape={}) values encountered while previous "
                                 "values are categorical.".format(y.shape, y_pred.shape))
        y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
        indices = torch.max(y_pred, dim=1)[1].view(-1)
        y_pred = to_onehot(indices, num_classes=y_pred.size(1))
        y_pred = y_pred.type(dtype)
        y = y.type(dtype)
        correct = y * y_pred
        all_positives = y_pred.sum(dim=0)
        if self._is_binary:
            correct = correct[:, 1, ...]
            all_positives = all_positives[1, ...]
        if correct.sum() == 0:
            true_positives = torch.zeros_like(all_positives)
        else:
            true_positives = correct.sum(dim=0)
        if self._all_positives is None:
            self._all_positives = all_positives
            self._true_positives = true_positives
        else:
            self._all_positives += all_positives
            self._true_positives += true_positives
@anmolsjoshi IMO we should prioritize fixing this bug before merging the multilabel case.
We also need to add stronger tests.
@vfdev-5 I agree that we should correct this bug.
What are your thoughts on splitting the Metric (Accuracy, Precision and Recall) as shown below:
Have a _check_output function that decides whether the problem is binary, multiclass, or multilabel. The output of this function is y_pred, y, type. The type can be determined from the shapes, which removes the need for the is_multilabel argument.
In the binary or multilabel case, apply threshold_function: torch.round by default (when None is passed), or a user-supplied callable. I honestly think we need to use a threshold for binary problems rather than converting them to a multiclass problem.
In one common update function, binary and multilabel are handled the same way, and multiclass handling stays as it is.
The only change needed then would be in how self._true_positives and self._all_positives are accumulated.
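As a rough illustration of the thresholding idea for the binary case, here is a minimal pure-Python sketch. `BinaryPrecisionSketch` is a hypothetical stand-in, not ignite's `Precision`: it works on plain lists, and its default cutoff (`p > 0.5`) only approximates what `torch.round` would do on tensors.

```python
class BinaryPrecisionSketch:
    """Hypothetical stand-in for a thresholded binary precision metric."""

    def __init__(self, threshold_function=None):
        if threshold_function is not None and not callable(threshold_function):
            raise ValueError("threshold_function must be a callable function.")
        # Assumption: a plain 0.5 cutoff stands in for torch.round here
        self._threshold = threshold_function or (lambda p: 1 if p > 0.5 else 0)
        self.reset()

    def reset(self):
        self._true_positives = 0
        self._all_positives = 0

    def update(self, output):
        y_pred, y = output
        labels = [self._threshold(p) for p in y_pred]
        # Only predicted class-1 samples are counted; class 0 is ignored
        self._all_positives += sum(labels)
        self._true_positives += sum(l * t for l, t in zip(labels, y))

    def compute(self):
        if self._all_positives == 0:
            return 0.0
        return self._true_positives / self._all_positives


metric = BinaryPrecisionSketch()
metric.update(([0.9, 0.2, 0.7], [1, 0, 0]))  # 2 predicted positives, 1 correct
metric.update(([0.6, 0.1], [1, 1]))          # 1 predicted positive, 1 correct
print(metric.compute())
```

Because only class-1 counts are accumulated, the result over both updates is 2 true positives out of 3 predicted positives, which is the number a binary precision score is expected to report.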
@vfdev-5 see code below
I think if we can find a method to see if self.output_type changes during training, we can raise an error.
class Precision(Metric):

    def __init__(self, output_transform=lambda x: x, average=False, threshold_function=None):
        self._average = average
        if threshold_function is not None:
            if callable(threshold_function):
                self._threshold = threshold_function
            else:
                raise ValueError("threshold_function must be a callable function.")
        else:
            self._threshold = torch.round
        super(Precision, self).__init__(output_transform=output_transform)

    def _check_output(self, output):
        y_pred, y = output
        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
                             "shape of (batch_size, num_categories, ...) or (batch_size, ...).")
        if y.ndimension() == 1 or y.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y = y.squeeze(dim=1).view(-1) if (y.ndimension() > 1) else y.view(-1)
        if y_pred.ndimension() == 1 or y_pred.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y_pred = y_pred.squeeze(dim=1).view(-1) if (y_pred.ndimension() > 1) else y_pred.view(-1)
        y_shape = y.shape
        y_pred_shape = y_pred.shape
        if y.ndimension() + 1 == y_pred.ndimension():
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
        if not (y_shape == y_pred_shape):
            raise ValueError("y and y_pred must have compatible shapes.")
        # Infer the problem type from the shapes
        if y.ndimension() + 1 == y_pred.ndimension():
            self.output_type = 'multiclass'
            self.axis = 0
            self.update_func = sum
        elif y_pred.shape == y.shape and not y.ndimension() == 1:
            self.output_type = 'multilabel'
            self.axis = 1
            self.update_func = torch.cat
            if y_pred.ndimension() > 2:
                num_classes = y_pred.size(1)
                y_pred = torch.transpose(y_pred, 1, 0).contiguous().view(num_classes, -1).transpose(1, 0)
                y = torch.transpose(y, 1, 0).contiguous().view(num_classes, -1).transpose(1, 0)
            if not self._average:
                warnings.warn('average should be True for multilabel cases. Precision._average updated'
                              ' to True. Average is calculated across samples, instead of classes.', UserWarning)
                self._average = True
        elif y_pred.shape == y.shape and y.ndimension() == 1:
            self.output_type = 'binary'
            self.axis = 0
            self.update_func = sum
        else:
            raise ValueError()
        return y_pred, y

    def update(self, output):
        y_pred, y = self._check_output(output)
        dtype = y_pred.type()
        if self.output_type == 'multiclass':
            y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
            indices = torch.max(y_pred, dim=1)[1].view(-1)
            y_pred = to_onehot(indices, num_classes=y_pred.size(1))
        else:
            y_pred = self._threshold(y_pred)
            # x == x ** 2 holds only for 0 and 1, so reject anything else
            # (the condition was inverted in the first draft of this snippet)
            if not torch.equal(y, y ** 2):
                raise ValueError("y must contain only 0s and 1s.")
            if not torch.equal(y_pred, y_pred ** 2):
                raise ValueError("y_pred must contain only 0s and 1s after thresholding.")
        y_pred = y_pred.type(dtype)
        y = y.type(dtype)
        correct = y * y_pred
        all_positives = y_pred.sum(dim=self.axis)
        if correct.sum() == 0:
            true_positives = torch.zeros_like(all_positives)
        else:
            true_positives = correct.sum(dim=self.axis)
        if self._all_positives is None:
            self._all_positives = all_positives
            self._true_positives = true_positives
        else:
            self._all_positives = self.update_func([self._all_positives, all_positives])
            self._true_positives = self.update_func([self._true_positives, true_positives])

    def compute(self):
        if self._all_positives is None:
            raise NotComputableError('Precision must have at least one example before it can be computed')
        result = self._true_positives / self._all_positives
        # NaN != NaN, so this zeroes out entries where 0 / 0 occurred
        result[result != result] = 0.0
        if self._average:
            return result.mean().item()
        else:
            return result
@anmolsjoshi if we unify all cases, we should make sure that the implementation is bullet-proof and that all tricky cases are covered when the user tries to mix types (binary, categorical, multilabel) in updates during the same computation session. IMO, the implementation becomes very difficult to follow.
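One possible way to guard against mixed types is to remember the type seen on the first update and reject any later change. The sketch below is only an illustration: `OutputTypeGuard` and its method names are hypothetical, and it takes the type as a string rather than inferring it from tensor shapes.

```python
class OutputTypeGuard:
    """Hypothetical helper: remembers the first output type and rejects changes."""

    VALID_TYPES = ("binary", "multiclass", "multilabel")

    def __init__(self):
        self.reset()

    def reset(self):
        # Called at the start of every computation session
        self._type = None

    def check(self, output_type):
        if output_type not in self.VALID_TYPES:
            raise ValueError("Unknown output type: {}".format(output_type))
        if self._type is None:
            self._type = output_type
        elif self._type != output_type:
            raise ValueError("Input type changed from {} to {} between updates."
                             .format(self._type, output_type))
        return output_type


guard = OutputTypeGuard()
guard.check("binary")
guard.check("binary")          # same type again is fine
try:
    guard.check("multiclass")  # mixing types in one session is rejected
except ValueError as err:
    mixed_error = str(err)
print(mixed_error)
```

Resetting the guard together with the metric state means a new session is free to use a different type, while a change in the middle of a session fails loudly.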
@vfdev-5 I agree with your implementation. I'll incorporate it into the current PR #333 and update the binary tests accordingly for Precision and Recall i.e. not map y_pred to 2 classes when using sklearn.
Maybe in the future we can think of a more sophisticated approach to unify all cases.
Once this is merged we should probably cut a release. I think for bugfixes we should release often, just to avoid anyone running for too long with a buggy version
@alykhantejani, so will the new release version be 0.1.2 or 0.1.1post1?
If 0.1.2, then we also need to remove BinaryAccuracy and CategoricalAccuracy, as mentioned in the warnings.
We could follow the pytorch path here: 0.2 can contain the backwards-incompatible changes, and this release can be 0.1.2. We can switch the warning message to say 0.2, and all the tickets currently assigned to 0.1.2 can be moved to 0.2, which can be the next release (pending no further bug fixes).
wdyt?
@alykhantejani sounds good, let's do it like you propose