Pytorch-lightning: Accuracy metric raises an error when all class indices in the target tensor are 0

Created on 16 Aug 2020 · 9 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

version: 0.9.0rc12

To Reproduce

Code sample

import torch
from pytorch_lightning.metrics.functional import accuracy

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])

# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)

RuntimeError

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-5d9063513e76> in <module>
      5 
      6 # calculates accuracy across all GPUs and all Nodes used in training
----> 7 accuracy(pred, target)

~/anaconda3/envs/pl/lib/python3.7/site-packages/pytorch_lightning/metrics/functional/classification.py in accuracy(pred, target, num_classes, reduction)
    268     """
    269     if not (target > 0).any() and num_classes is None:
--> 270         raise RuntimeError("cannot infer num_classes when target is all zero")
    271 
    272     tps, fps, tns, fns, sups = stat_scores_multiple_classes(

RuntimeError: cannot infer num_classes when target is all zero

Expected output

tensor(0.2500)
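For reference, the expected value does not depend on the number of classes at all: overall (micro-averaged) accuracy is simply the fraction of positions where pred equals target. A minimal sketch in plain PyTorch, assuming integer class-index tensors of the same shape (this is not the Lightning implementation):

```python
import torch

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])

# Element-wise comparison yields a bool tensor; the mean of its float
# version is the fraction of correct predictions.
micro_accuracy = (pred == target).float().mean()
print(micro_accuracy)  # tensor(0.2500)
```

Only the first position matches, so the result is 1/4, which is the expected output above.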
Labels: Metrics, bug / fix, help wanted, waiting on author

All 9 comments

Hi! Thanks for your contribution, great first issue!

@justusschock @SkafteNicki This is a recurring confusion; we have been asked about this several times now. Do we need to change the behavior here, or add a note/warning to the docs?

In sklearn, for example, there is no such error, and it returns 0.25.
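The scikit-learn behavior can be checked with sklearn.metrics.accuracy_score, which takes (y_true, y_pred) and never asks for a class count. A short sketch using the same data as the reproduction, as plain lists:

```python
from sklearn.metrics import accuracy_score

# Same values as the reproduction; note that y_true comes first.
y_true = [0, 0, 0, 0]
y_pred = [0, 1, 2, 3]

score = accuracy_score(y_true, y_pred)
print(score)  # 0.25
```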

We can think about that. I just thought that with only one class, most metrics are either not well defined or not descriptive.

Thoughts @SkafteNicki ?

Another option would be to convert it to a warning instead of an error.
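A rough sketch of that option, assuming num_classes is inferred as the largest class index seen in either pred or target, plus one (consistent with the "predicted (4) and target (1)" counts in the warning reported later in this thread). The helper name infer_num_classes is hypothetical; this is not Lightning's actual code:

```python
import warnings
from typing import Optional

import torch


def infer_num_classes(pred: torch.Tensor, target: torch.Tensor,
                      num_classes: Optional[int] = None) -> int:
    """Hypothetical helper: warn instead of raising on an all-zero target."""
    if num_classes is not None:
        # An explicit value always wins.
        return num_classes
    if not (target > 0).any():
        # This is where the RuntimeError is raised today; a warning
        # lets the metric computation continue.
        warnings.warn(
            "target is all zero; inferring num_classes from pred and target",
            UserWarning,
        )
    # Largest class index present in either tensor, plus one.
    return int(max(pred.max().item(), target.max().item())) + 1
```

With the tensors from the reproduction, this emits one warning and returns 4 instead of failing.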

We should probably change it. With so many people being confused about it, I think it is a strong indicator that we are not doing what people expect, and we risk people not using the metrics if they are counter-intuitive.

A batch that contains only a single class is common when the validation dataset is not shuffled.
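To illustrate, here is a small pure-Python sketch, assuming a validation set whose labels are stored grouped by class and read in fixed-size batches with shuffling disabled (the labels are made up for the example):

```python
# Made-up validation labels, stored grouped by class (0s, then 1s, then 2s).
labels = [0] * 8 + [1] * 8 + [2] * 8
batch_size = 4

# Consecutive fixed-size chunks, as a loader with shuffling disabled yields them.
batches = [labels[i:i + batch_size] for i in range(0, len(labels), batch_size)]

# Count the batches that contain only a single class.
single_class = sum(1 for b in batches if len(set(b)) == 1)
print(len(batches), single_class)  # 6 6

# The first batch is all class 0, exactly the input that trips the
# "target is all zero" check.
print(batches[0])  # [0, 0, 0, 0]
```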

Even after passing the num_classes argument, it still throws a warning.

@pchandra90 is this still an actual issue? Mind testing on master?

It still raises an error; it should be changed to a simple warning. @pchandra90, are you interested in sending a PR?

@Borda, I tested master and there is no change in the issue status. Here is the code to reproduce:

import pytorch_lightning as pl
import torch

from pytorch_lightning.metrics.functional import accuracy

print('PyTorch Lightning Version: {}'.format(pl.__version__))

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])

print('Is the num_classes argument specified: No')
accuracy(pred, target)

RuntimeError:

PyTorch Lightning Version: 0.9.1rc3
Is the num_classes argument specified: No
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-2c64fef4d70c> in <module>
     10 
     11 print('Is the num_classes argument specified: No')
---> 12 accuracy(pred, target)

~/anaconda3/envs/pl_dev/lib/python3.7/site-packages/pytorch_lightning/metrics/functional/classification.py in accuracy(pred, target, num_classes, class_reduction)
    268     """
    269     if not (target > 0).any() and num_classes is None:
--> 270         raise RuntimeError("cannot infer num_classes when target is all zero")
    271 
    272     tps, fps, tns, fns, sups = stat_scores_multiple_classes(

RuntimeError: cannot infer num_classes when target is all zero


import pytorch_lightning as pl
import torch

from pytorch_lightning.metrics.functional import accuracy

print('PyTorch Lightning Version: {}'.format(pl.__version__))

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])

print('Is the num_classes argument specified: Yes')
accuracy(pred, target, num_classes=10)

Output with the warning:

PyTorch Lightning Version: 0.9.1rc3
Is the num_classes argument specified: Yes
/home/prakash/anaconda3/envs/pl_dev/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: You have set 10 number of classes if different from predicted (4) and target (1) number of classes
  warnings.warn(*args, **kwargs)
tensor(0.2500)
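Until the behavior changes, one possible workaround when num_classes is set deliberately is to filter that specific UserWarning around the call. A sketch using only the standard warnings module; call_without_class_count_warning is a hypothetical helper name, and the message pattern is copied from the warning text above, so it may need adjusting for other Lightning versions:

```python
import warnings
from typing import Any, Callable


def call_without_class_count_warning(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn while ignoring the 'You have set N number of classes' UserWarning."""
    with warnings.catch_warnings():
        # `message` is a regex matched against the start of the warning text;
        # the previous filter state is restored when the block exits.
        warnings.filterwarnings(
            "ignore",
            message=r"You have set \d+ number of classes",
            category=UserWarning,
        )
        return fn(*args, **kwargs)
```

For example, call_without_class_count_warning(accuracy, pred, target, num_classes=10) would still return the accuracy value, with only that one message suppressed.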

@rohitgr7, yes, I can send a PR to fix the issue. However, I cannot do it before the weekend.
