Ignite: Bug in binary precision

Created on 29 Nov 2018 · 11 Comments · Source: pytorch/ignite

Basic tests of binary precision seem to fail:

import torch
from sklearn.metrics import precision_score
from ignite.metrics import Precision

precision = Precision(average=True)

y_pred = torch.rand(10, 1)
y = torch.randint(0, 2, size=(10,)).type(torch.LongTensor)

precision.update((y_pred, y))

np_y = y.numpy().ravel()
np_y_pred = (y_pred.numpy().ravel() > 0.5).astype('int')

precision_score(np_y, np_y_pred), precision.compute()

@jasonkriss could you please confirm this ?

EDIT: Another failing test: https://github.com/pytorch/ignite/pull/333#issuecomment-442643530

Probably, the error is in the binary-to-categorical mapping, which counts class 0 the same way as class 1. In the binary case we should ignore class 0.

There is also no check that the input stays either binary or categorical if the user mixes both across several updates.
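A minimal sketch of such a consistency check, assuming the input kind can be decided from tensor ranks (after squeezing a singleton class dimension): the first update fixes the kind, and a later update of the other kind raises. InputKindGuard is a hypothetical helper, not part of ignite; it works on plain integers so it runs without torch.

```python
class InputKindGuard:
    """Hypothetical helper: remembers whether the first update was binary or
    categorical and rejects updates of the other kind until reset()."""

    def __init__(self):
        self.reset()

    def reset(self):
        # None means "no update seen yet in this computation session"
        self._is_binary = None

    def check(self, y_ndim, y_pred_ndim):
        # Same rank for y and y_pred -> binary; y_pred with an extra class axis -> categorical
        is_binary = y_ndim == y_pred_ndim
        if self._is_binary is None:
            self._is_binary = is_binary
        elif self._is_binary != is_binary:
            previous = "binary" if self._is_binary else "categorical"
            current = "binary" if is_binary else "categorical"
            raise ValueError(f"Received {current} input while previous updates were {previous}.")
        return is_binary


guard = InputKindGuard()
guard.check(1, 1)    # first update: binary, accepted
# guard.check(1, 2)  # a categorical update now would raise ValueError
```

Calling reset() at the start of each computation session (as a metric's reset does) clears the remembered kind.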

0.1.2 bug


All 11 comments

@vfdev-5 this might be happening because we map binary inputs to a categorical problem with 2 classes, compute precision for each class and then average them.

If we treat binary as just one class and round the predictions, we should get the same answer.

The reason we're getting different answers is that ignite treats it as 2 classes, while sklearn treats it as 1 class.
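To make the difference concrete, here is a minimal pure-Python sketch on a hand-picked 5-sample example (the helper names are hypothetical). binary_precision scores only the positive class, which is what sklearn's precision_score does by default for binary targets (average='binary', pos_label=1), while two_class_macro_precision averages class-0 and class-1 precision, mirroring the binary-to-2-class mapping.

```python
def per_class_precision(y_true, y_pred, cls):
    # Ground-truth hits among the samples predicted as `cls`
    hits = [t == cls for t, p in zip(y_true, y_pred) if p == cls]
    # No prediction for this class means 0/0; report 0.0 in that case
    return sum(hits) / len(hits) if hits else 0.0


def binary_precision(y_true, y_pred):
    # Positive class only (sklearn's default for binary problems)
    return per_class_precision(y_true, y_pred, 1)


def two_class_macro_precision(y_true, y_pred):
    # Mean over class 0 and class 1, as happens after mapping binary to 2 classes
    return (per_class_precision(y_true, y_pred, 0) + per_class_precision(y_true, y_pred, 1)) / 2


y_true = [1, 0, 1, 1, 0]
y_pred = [1, 1, 0, 1, 0]  # already thresholded predictions
print(binary_precision(y_true, y_pred))           # about 0.667 (2 / 3)
print(two_class_macro_precision(y_true, y_pred))  # about 0.583 (7 / 12)
```

Class 1 is predicted three times with two hits (2/3), class 0 twice with one hit (1/2), so averaging both classes pulls the result away from the positive-class value sklearn reports.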

Works on my local machine. Will try with all the binary tests.

EDIT: I just saw your edit lol

@anmolsjoshi I was thinking of adding an attribute _is_binary, like this:

class Precision(Metric):
    def __init__(self, output_transform=lambda x: x, is_multilabel=False, average=False, threshold_function=None):
        self._average = average
        if is_multilabel:
            if threshold_function is None:
                self._threshold = torch.round
            else:
                if callable(threshold_function):
                    self._threshold = threshold_function
                else:
                    raise ValueError("threshold_function must be a callable function.")
            if not self._average:
                warnings.warn('average should be True for multilabel cases. Precision._average updated'
                              ' to True. Average is calculated across samples, instead of classes.', UserWarning)
                self._average = True
            self.update = self._update_multilabel
        else:
            self.update = self._update_multiclass
        super(Precision, self).__init__(output_transform=output_transform)
        self._is_binary = None

    def reset(self):
        self._all_positives = None
        self._true_positives = None
        self._is_binary = None

...

    def _update_multiclass(self, output):
        y_pred, y = output
        dtype = y_pred.type()

        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
                             "shape of (batch_size, num_categories, ...) or (batch_size, ...).")

        if y.ndimension() == 1 or y.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y = y.squeeze(dim=1).view(-1) if (y.ndimension() > 1) else y.view(-1)

        if y_pred.ndimension() == 1 or y_pred.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y_pred = y_pred.squeeze(dim=1).view(-1) if (y_pred.ndimension() > 1) else y_pred.view(-1)

        y_shape = y.shape
        y_pred_shape = y_pred.shape

        if y.ndimension() + 1 == y_pred.ndimension():
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]

        if not (y_shape == y_pred_shape):
            raise ValueError("y and y_pred must have compatible shapes.")

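        if y_pred.ndimension() == y.ndimension():
            # Maps Binary Case to Categorical Case with 2 classes
            y_pred = y_pred.unsqueeze(dim=1)
            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
            if self._is_binary is None:
                self._is_binary = True
            elif not self._is_binary:
                raise ValueError("Binary y (shape={}) or y_pred (shape={}) encountered while previous "
                                 "values were categorical.".format(y.shape, y_pred.shape))
        else:
            # Categorical case: record it so that a later binary update is rejected
            if self._is_binary is None:
                self._is_binary = False
            elif self._is_binary:
                raise ValueError("Categorical y (shape={}) and y_pred (shape={}) encountered while previous "
                                 "values were binary.".format(y.shape, y_pred.shape))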

        y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
        indices = torch.max(y_pred, dim=1)[1].view(-1)
        y_pred = to_onehot(indices, num_classes=y_pred.size(1))

        y_pred = y_pred.type(dtype)
        y = y.type(dtype)

        correct = y * y_pred
        all_positives = y_pred.sum(dim=0)

        if self._is_binary:
            correct = correct[:, 1, ...]
            all_positives = all_positives[1, ...]

        if correct.sum() == 0:
            true_positives = torch.zeros_like(all_positives)
        else:
            true_positives = correct.sum(dim=0)
        if self._all_positives is None:
            self._all_positives = all_positives
            self._true_positives = true_positives
        else:
            self._all_positives += all_positives
            self._true_positives += true_positives
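The accumulation at the end of _update_multiclass can be illustrated without torch. Below is a hypothetical pure-Python accumulator that, as proposed for the binary case, counts only the positive class: it keeps running true-positive and predicted-positive counts across update() calls, so several small batches give the same result as one big batch. The 0.5 threshold mirrors the > 0.5 in the reproduction snippet; returning 0.0 when nothing was predicted positive is an assumption of this sketch.

```python
class StreamingBinaryPrecision:
    """Hypothetical accumulator: positive-class precision over many batches."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.reset()

    def reset(self):
        self.true_positives = 0
        self.all_positives = 0

    def update(self, y_pred, y):
        # y_pred holds probabilities, y holds 0/1 labels; class 0 is never counted
        for p, t in zip(y_pred, y):
            if p > self.threshold:
                self.all_positives += 1
                if t == 1:
                    self.true_positives += 1

    def compute(self):
        # 0/0 (no positive predictions yet) is reported as 0.0 here
        if self.all_positives == 0:
            return 0.0
        return self.true_positives / self.all_positives


metric = StreamingBinaryPrecision()
metric.update([0.9, 0.8, 0.2], [1, 0, 1])
metric.update([0.7, 0.1], [1, 0])
print(metric.compute())  # 2 true positives out of 3 predicted positives
```

Summing counts and dividing once at the end (rather than averaging per-batch precisions) is what keeps the batched result equal to the single-pass one.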

@anmolsjoshi IMO we should prioritize fixing this bug before merging the multilabel case.
And we need to add stronger tests.

@vfdev-5 I agree that we should correct this bug.

What are your thoughts on splitting the Metric (Accuracy, Precision and Recall) as shown below:

  • Have a _check_output function that decides whether the problem is binary, multiclass or multilabel. It returns y_pred, y and the type, which can be determined from the shapes. This removes the need for the is_multilabel argument.

  • In the binary or multilabel case, set the threshold to torch.round by default (threshold_function=None) or to the user-supplied callable. I honestly think we should use a threshold for binary problems rather than converting them to a multiclass problem.

  • In one common update function, binary and multilabel inputs are handled similarly, and multiclass handling stays the same.

  • The only remaining change would be how self._true_positives and self._all_positives are accumulated.
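The shape-based dispatch from the first bullet can be sketched on plain shape tuples. infer_output_type is a hypothetical helper; its rules follow the _check_output code in this thread (one extra dimension on y_pred means multiclass, equal 1-D shapes mean binary, equal higher-rank shapes mean multilabel), simplified so that only a (N, 1) singleton is squeezed to (N,).

```python
def _squeeze_singleton(shape):
    # Treat (N, 1) as (N,): a single-column target or prediction is a binary input
    shape = tuple(shape)
    return shape[:1] if len(shape) == 2 and shape[1] == 1 else shape


def infer_output_type(y_shape, y_pred_shape):
    """Return 'binary', 'multiclass' or 'multilabel' from shapes alone (hypothetical)."""
    y_shape = _squeeze_singleton(y_shape)
    y_pred_shape = _squeeze_singleton(y_pred_shape)

    if len(y_pred_shape) == len(y_shape) + 1:
        # y_pred is (N, C, ...) and y is (N, ...): drop the class axis before comparing
        if (y_pred_shape[0],) + y_pred_shape[2:] != y_shape:
            raise ValueError("y and y_pred must have compatible shapes.")
        return "multiclass"

    if y_pred_shape != y_shape:
        raise ValueError("y and y_pred must have compatible shapes.")
    return "binary" if len(y_shape) == 1 else "multilabel"


print(infer_output_type((10,), (10, 1)))   # binary
print(infer_output_type((10,), (10, 5)))   # multiclass
print(infer_output_type((10, 5), (10, 5)))  # multilabel
```

One caveat of deciding purely by shape: a multilabel problem with a single label column, (N, 1), is indistinguishable from a binary one, which is probably acceptable since the two are computed the same way.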

@vfdev-5 see code below

I think if we can find a way to detect whether self.output_type changes during training, we can raise an error.

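import warnings

import torch

from ignite._utils import to_onehot  # module path in ignite 0.1.x
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric


class Precision(Metric):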
    def __init__(self, output_transform=lambda x: x, average=False, threshold_function=None):
        self._average = average

        if threshold_function is not None:
            if callable(threshold_function):
                self._threshold = threshold_function
            else:
                raise ValueError("threshold_function must be a callable function.")
        else:
            self._threshold = torch.round
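        super(Precision, self).__init__(output_transform=output_transform)

    def reset(self):
        # Metric.__init__ calls reset(), so the accumulators must be defined here
        self._all_positives = None
        self._true_positives = None
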
    def _check_output(self, output):

        y_pred, y = output

        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
                             "shape of (batch_size, num_categories, ...) or (batch_size, ...).")

        if y.ndimension() == 1 or y.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y = y.squeeze(dim=1).view(-1) if (y.ndimension() > 1) else y.view(-1)

        if y_pred.ndimension() == 1 or y_pred.shape[1] == 1:
            # Binary Case, flattens y and num_classes is equal to 1.
            y_pred = y_pred.squeeze(dim=1).view(-1) if (y_pred.ndimension() > 1) else y_pred.view(-1)

        y_shape = y.shape
        y_pred_shape = y_pred.shape

        if y.ndimension() + 1 == y_pred.ndimension():
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]

        if not (y_shape == y_pred_shape):
            raise ValueError("y and y_pred must have compatible shapes.")

        if y.ndimension() + 1 == y_pred.ndimension():
            self.output_type = 'multiclass'
            self.axis = 0
            self.update_func = sum

        elif y_pred.shape == y.shape and not y.ndimension() == 1:
            self.output_type = 'multilabel'
            self.axis = 1
            self.update_func = torch.cat

            if y_pred.ndimension() > 2:
                num_classes = y_pred.size(1)
                y_pred = torch.transpose(y_pred, 1, 0).contiguous().view(num_classes, -1).transpose(1, 0)
                y = torch.transpose(y, 1, 0).contiguous().view(num_classes, -1).transpose(1, 0)

            if not self._average:
                warnings.warn('average should be True for multilabel cases. Precision._average updated'
                              ' to True. Average is calculated across samples, instead of classes.', UserWarning)
                self._average = True

        elif y_pred.shape == y.shape and y.ndimension() == 1:
            self.output_type = 'binary'
            self.axis = 0
            self.update_func = sum

        else:
            raise ValueError("Unable to determine the output type from the shapes of y and y_pred.")

        return y_pred, y

    def update(self, output):
        y_pred, y = self._check_output(output)
        dtype = y_pred.type()

        if self.output_type == 'multiclass':
            y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
            indices = torch.max(y_pred, dim=1)[1].view(-1)
            y_pred = to_onehot(indices, num_classes=y_pred.size(1))
        else:
            y_pred = self._threshold(y_pred)

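            # Only 0 and 1 satisfy x == x ** 2, so these checks reject non-binary values
            if not torch.equal(y, y ** 2):
                raise ValueError("y must contain only 0s and 1s for binary and multilabel inputs.")

            if not torch.equal(y_pred, y_pred ** 2):
                raise ValueError("Thresholded y_pred must contain only 0s and 1s for binary and multilabel inputs.")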

        y_pred = y_pred.type(dtype)
        y = y.type(dtype)

        correct = y * y_pred
        all_positives = y_pred.sum(dim=self.axis)

        if correct.sum() == 0:
            true_positives = torch.zeros_like(all_positives)
        else:
            true_positives = correct.sum(dim=self.axis)

        if self._all_positives is None:
            self._all_positives = all_positives
            self._true_positives = true_positives
        else:

            self._all_positives = self.update_func([self._all_positives, all_positives])
            self._true_positives = self.update_func([self._true_positives, true_positives])

    def compute(self):
        if self._all_positives is None:
            raise NotComputableError('Precision must have at least one example before it can be computed')

        result = self._true_positives / self._all_positives
        result[result != result] = 0.0
        if self._average:
            return result.mean().item()
        else:
            return result
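In the multilabel branch (axis=1, update_func=torch.cat) each sample contributes its own precision, and compute() averages those values across samples rather than across classes, as the warning text says. A minimal pure-Python sketch of that "average over samples" quantity, with a hypothetical function name and 0/1 lists standing in for tensors:

```python
def samples_average_precision(y_true, y_pred):
    """Per-sample precision, averaged over samples (hypothetical sketch)."""
    per_sample = []
    for true_row, pred_row in zip(y_true, y_pred):
        predicted = sum(pred_row)                                   # predicted labels in this sample
        correct = sum(t * p for t, p in zip(true_row, pred_row))   # correctly predicted labels
        # 0/0 (no predicted label) becomes 0.0, like the NaN masking in compute()
        per_sample.append(correct / predicted if predicted else 0.0)
    return sum(per_sample) / len(per_sample)


y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 1, 0], [0, 1, 0]]
print(samples_average_precision(y_true, y_pred))  # 0.75
```

Here the first sample has precision 1/2 and the second 1/1, giving 0.75; this should correspond to sklearn's precision_score(..., average='samples') on the same arrays.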

@anmolsjoshi if we unify all cases, we should make sure that the implementation is bullet-proof and that all tricky cases are covered, e.g. when the user mixes types (binary, categorical, multilabel) across updates during the same computation session. IMO, this makes the implementation very difficult to follow.

@vfdev-5 I agree with your implementation. I'll incorporate it into the current PR #333 and update the binary tests for Precision and Recall accordingly, i.e. not map y_pred to 2 classes when using sklearn.

Maybe in the future we can think of a more sophisticated approach to unify all cases.

Once this is merged we should probably cut a release. I think for bugfixes we should release often, just to avoid anyone running for too long with a buggy version

@alykhantejani, so will the new release be 0.1.2 or 0.1.1post1?
If 0.1.2, then we also need to remove BinaryAccuracy and CategoricalAccuracy, as mentioned in the warnings.

So we could follow the PyTorch path here: 0.2 can contain the backwards-incompatible changes and this can be 0.1.2. We can switch the warning message to say 0.2, move all the current tickets assigned to 0.1.2 to 0.2, and that can be the next release (pending no further bug fixes).

wdyt?

@alykhantejani sounds good, let's do it like you propose
