When using Accuracy.update() with both inputs having a second dimension of size 1 (in my case torch.Size([256, 1])), the raised error message is misleading.
To reproduce
from ignite.metrics import Accuracy
import torch
acc = Accuracy(is_multilabel=True)
acc.update((torch.zeros((256,1)), torch.zeros((256,1))))
ValueError: y and y_pred must have same shape of (batch_size, num_categories, ...).
In this case y and y_pred do have the same shape; the issue is that this is not an accepted multilabel input (see the y.shape[1] != 1 condition in the following code block from _check_shape in _BaseClassification). This should be indicated in the error message (or the if statement changed).
What is the argument for not allowing a y.shape[1] of 1?
if self._is_multilabel and not (y.shape == y_pred.shape and y.ndimension() > 1 and y.shape[1] != 1):
    raise ValueError("y and y_pred must have same shape of (batch_size, num_categories, ...).")
conda, pip, source): conda
Ok, I'm still working through this. Sorry for having posted too early. Maybe the issue is that the two lines
self._check_shape((y_pred, y))
self._check_type((y_pred, y))
should be in reverse order!?
Maybe someone could have a look at it.
@niowniow thanks for posting !
I agree with your first message: the error message is misleading for inputs like the one in your example. Maybe even the condition y.shape[1] != 1 is a bit wrong and should be y.shape[1] > 1 (to cover cases like t = torch.zeros(10, 0, 2)).
I think, we just need to improve the error message to
"y and y_pred must have same shape of (batch_size, num_categories, ...) and num_categories > 1"
What do you think ?
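To make the proposal concrete, here is a minimal sketch of the check with the stricter y.shape[1] > 1 condition and the extended message. The function name check_multilabel_shape is hypothetical and it works on plain shape tuples (an assumption made so the snippet runs without torch or ignite); it is not the actual _check_shape implementation.

```python
def check_multilabel_shape(y_pred_shape, y_shape):
    """Hypothetical standalone version of the multilabel shape check.

    Takes plain shape tuples instead of tensors (assumption for illustration).
    """
    same_shape = tuple(y_shape) == tuple(y_pred_shape)
    has_category_dim = len(y_shape) > 1
    # `> 1` rejects both a single category and an empty category dimension.
    if not (same_shape and has_category_dim and y_shape[1] > 1):
        raise ValueError(
            "y and y_pred must have same shape of "
            "(batch_size, num_categories, ...) and num_categories > 1."
        )


# A regular multilabel shape passes silently.
check_multilabel_shape((256, 3), (256, 3))

# The shapes discussed in this thread are rejected with the extended message.
for bad_shape in [(256, 1), (10, 0, 2)]:
    try:
        check_multilabel_shape(bad_shape, bad_shape)
    except ValueError as err:
        print(bad_shape, "->", err)
```

With the original != 1 condition, the (10, 0, 2) shape would pass the check, which is the case the > 1 variant is meant to cover.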
@niowniow Thank you for the question. The point is num_categories should be greater than one wrt multilabel classification, shouldn't it ?
@vfdev-5 You are too fast!! :)
@vfdev-5 That would probably do it. However, I don't know if it holds for y.shape[1] > 1
@sdesrozis Yes, it should be greater than one. However, I ran multiple tests with varying number of categories. Some happen to have only one category. And I didn't check for it before instantiating Accuracy.
But would there be a difference in the output of the binary case and the multilabel case with num_categories==1?
@sdesrozis Yes, it should be greater than one. However, I ran multiple tests with varying number of categories. Some happen to have only one category. And I didn't check for it before instantiating Accuracy.
Ok it makes sense, thank you for that clarification.
But would there be a difference in the output of the binary case and the multilabel case with num_categories==1?
Actually we don't have the same implementation but I agree that should not be different. Let's see what we can do.
If num_categories changes, another error is raised:
https://github.com/pytorch/ignite/blob/01a2bc0f79ca03565dd0c112b73d6508bc81b070/ignite/metrics/accuracy.py#L88-L91
We discussed this issue, and the behavior of Accuracy with is_multilabel=True in the case of num_categories == 1 won't change. Actually, the case num_categories == 1 is already covered by binary (i.e. simply is_multilabel=False here). We want to maintain a clear usage and not have multiple ways (some possibly bad wrt performance) to do the same thing. I agree that is_multilabel=True and num_categories == 1 means binary, but our implementation does not fit that.
Btw, I can help you with your specific needs. Feel free to share snippets. As you can see, @vfdev-5 is very fast; my challenge is to be faster :)
Thanks. It's fine for me to have another if-condition checking if my input should use binary.
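For anyone landing here with the same setup, the extra if-condition could look roughly like the sketch below. The helper name use_multilabel is hypothetical, and it inspects a plain shape tuple (an assumption so the snippet runs without torch or ignite). Whether (batch_size, 1) tensors need further reshaping for the binary path is not covered in this thread.

```python
def use_multilabel(y_shape):
    """Hypothetical helper: True only if there is a category dimension
    holding more than one category."""
    return len(y_shape) > 1 and y_shape[1] > 1


for shape in [(256, 1), (256, 5)]:
    flag = use_multilabel(shape)
    # With ignite installed, the metric could then be created as
    # Accuracy(is_multilabel=flag), so a single category falls back to binary.
    print(shape, "-> is_multilabel =", flag)
```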
@niowniow Perfect!