PyTorch Lightning version: 0.9.0rc12
from pytorch_lightning.metrics.functional import accuracy
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])
# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-16-5d9063513e76> in <module>
5
6 # calculates accuracy across all GPUs and all Nodes used in training
----> 7 accuracy(pred, target)
~/anaconda3/envs/pl/lib/python3.7/site-packages/pytorch_lightning/metrics/functional/classification.py in accuracy(pred, target, num_classes, reduction)
268 """
269 if not (target > 0).any() and num_classes is None:
--> 270 raise RuntimeError("cannot infer num_classes when target is all zero")
271
272 tps, fps, tns, fns, sups = stat_scores_multiple_classes(
RuntimeError: cannot infer num_classes when target is all zero
Expected output: tensor(0.2500)
Hi! Thanks for your contribution, great first issue!
@justusschock @SkafteNicki This is a recurring point of confusion; we have been asked about it several times now. Do we need to change the behavior here, or add a note/warning to the docs?
In sklearn, for example, there is no such error: it simply returns 0.25.
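For reference, plain element-wise accuracy needs no class inference at all. A minimal, dependency-free sketch of the sklearn-style behavior described above (`simple_accuracy` is a hypothetical helper, not library code):

```python
# Minimal sketch of the sklearn-style behavior: accuracy is just the
# fraction of matching elements, so no num_classes inference is needed
# even when the target is all zeros.
def simple_accuracy(pred, target):
    assert len(pred) == len(target) and len(pred) > 0
    return sum(p == t for p, t in zip(pred, target)) / len(pred)

print(simple_accuracy([0, 1, 2, 3], [0, 0, 0, 0]))  # 0.25
```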
We can think about that. My reasoning was that with only one class present, most metrics either aren't well defined or aren't descriptive.
Thoughts @SkafteNicki ?
Another option would be to convert it to a warning instead of an error.
We should probably change it. With so many people being confused about it, I think it is a strong indicator that we are not doing what people expect, and we risk that people will not use the metrics if they are counter-intuitive.
A batch containing a single class is a common scenario, e.g. for a validation dataset without shuffling.
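To illustrate that scenario with a toy sketch (plain Python, not Lightning code): when the validation set is sorted by label and not shuffled, the early batches contain only class 0, which is exactly the case that trips the `num_classes` inference.

```python
# Toy illustration: an unshuffled, label-sorted validation set yields
# early batches containing only class 0, the case that makes
# num_classes inference fail.
targets = [0, 0, 0, 0, 1, 1, 2, 2]  # labels grouped by class
batch_size = 4
batches = [targets[i:i + batch_size] for i in range(0, len(targets), batch_size)]
print(batches[0])  # first batch is all zeros
```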
Even after passing the num_classes argument, it still throws a warning.
@pchandra90 is this still an actual issue? Mind testing master?
It still raises an error; it should be changed to a simple warning. @pchandra90 interested in sending a PR?
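One way the proposed fix could look, as a hypothetical sketch (this is not the actual library code, and `infer_num_classes` is a made-up helper name): emit a warning when the target is all zeros and fall back to inferring the class count from the labels that do appear, instead of raising.

```python
import warnings

def infer_num_classes(pred, target, num_classes=None):
    # Hypothetical sketch of the proposed behavior: if num_classes is
    # not given and target contains only zeros, warn and infer the
    # class count from the labels seen in pred and target, rather than
    # raising a RuntimeError as the current implementation does.
    if num_classes is not None:
        return num_classes
    if not any(t > 0 for t in target):
        warnings.warn("target is all zero; inferring num_classes from pred")
    return max(max(pred), max(target)) + 1

print(infer_num_classes([0, 1, 2, 3], [0, 0, 0, 0]))  # 4, with a warning
```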
@Borda, I tested master and there is no change in the issue status. Here is the code to reproduce:
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
print('PyTorch Lightning Version: {}'.format(pl.__version__))
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])
print('Is the num_classes argument specified: No')
accuracy(pred, target)
Output with the RuntimeError:
PyTorch Lightning Version: 0.9.1rc3
Is the num_classes argument specified: No
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-2c64fef4d70c> in <module>
10
11 print('Is the num_classes argument specified: No')
---> 12 accuracy(pred, target)
~/anaconda3/envs/pl_dev/lib/python3.7/site-packages/pytorch_lightning/metrics/functional/classification.py in accuracy(pred, target, num_classes, class_reduction)
268 """
269 if not (target > 0).any() and num_classes is None:
--> 270 raise RuntimeError("cannot infer num_classes when target is all zero")
271
272 tps, fps, tns, fns, sups = stat_scores_multiple_classes(
RuntimeError: cannot infer num_classes when target is all zero
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
print('PyTorch Lightning Version: {}'.format(pl.__version__))
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 0, 0, 0])
print('Is the num_classes argument specified: Yes')
accuracy(pred, target, num_classes=10)
Output with the warning:
PyTorch Lightning Version: 0.9.1rc3
Is the num_classes argument specified: Yes
/home/prakash/anaconda3/envs/pl_dev/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: You have set 10 number of classes if different from predicted (4) and target (1) number of classes
warnings.warn(*args, **kwargs)
tensor(0.2500)
@rohitgr7, yes, I can send a PR to fix the issue. However, I can't do it before the weekend.