I'm new to Ignite and want to check whether there are any built-in multiclass accuracy methods, e.g. one that gives the accuracy of each class. I don't see one in the docs, and the multilabel option for Accuracy doesn't seem to work for me.
```python
import torch
import torch.nn.functional as F
import ignite.metrics

def output_converter(output):
    # Convert (N, C) class scores and (N,) integer (int64) targets into
    # one-hot (N, C) tensors on the same device, for use with a multilabel metric
    y_pred, y = output
    num_classes = y_pred.shape[1]
    y_pred_result = F.one_hot(torch.argmax(y_pred, dim=1), num_classes)
    y_result = F.one_hot(y, num_classes)
    return y_pred_result, y_result

temp = ignite.metrics.Accuracy(output_transform=output_converter, is_multilabel=True)
```
This is how I used the multilabel option, but it gives me a single overall accuracy instead of an accuracy for each class. What should I do to fix this?
Thanks!
@CDWJustin thank you for this question!
You are right: there is actually no such built-in metric. I discussed exactly this point with @vfdev-5 a few weeks ago.
However, as you did, it's possible to use output_transform to create class-wise metrics:
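```python
from ignite.metrics import Accuracy, Precision, Recall

def get_single_label_output_fn(c):
    # Build an output_transform that keeps only column c of (N, C) multilabel outputs
    def wrapper(output):
        # Assumes the engine output is a dict with binary (0/1) "y_pred" and "y" of shape (N, C)
        y_pred, y = output["y_pred"], output["y"]
        return y_pred[:, c], y[:, c]
    return wrapper

num_classes = 10  # placeholder: config.num_classes in my setup
val_metrics = {}
for i in range(num_classes):
    for name, cls in zip(["Accuracy", "Precision", "Recall"], [Accuracy, Precision, Recall]):
        val_metrics["{}/{}".format(name, i)] = cls(output_transform=get_single_label_output_fn(i))
```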
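For intuition, here is a minimal pure-PyTorch sketch of what those per-class metrics measure, assuming binary (0/1) multilabel tensors of shape (N, C): each column is scored as an independent binary problem, just like the column slicing in `get_single_label_output_fn`. `per_class_binary_metrics` is a hypothetical helper, not part of ignite. For ordinary inputs its values should match what the per-class `Accuracy`/`Precision`/`Recall` instances report, but it returns 0 for precision when a class is never predicted and 0 for recall when a class never occurs, and ignite may handle those edge cases differently.

```python
import torch


def per_class_binary_metrics(y_pred, y):
    # y_pred, y: binary (0/1) tensors of shape (N, C); column c is scored
    # as its own binary problem, like get_single_label_output_fn(c) does
    y_pred, y = y_pred.bool(), y.bool()
    tp = (y_pred & y).sum(dim=0).float()
    fp = (y_pred & ~y).sum(dim=0).float()
    fn = (~y_pred & y).sum(dim=0).float()
    accuracy = (y_pred == y).float().mean(dim=0)
    # clamp(min=1) makes the sketch return 0 instead of dividing by zero
    precision = tp / (tp + fp).clamp(min=1)
    recall = tp / (tp + fn).clamp(min=1)
    return accuracy, precision, recall


y_pred = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
y = torch.tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])

acc, prec, rec = per_class_binary_metrics(y_pred, y)
print(acc.tolist())   # [0.75, 0.75, 1.0]
print(prec.tolist())  # [1.0, 0.5, 1.0]
print(rec.tolist())   # [0.5, 1.0, 1.0]
```

On this toy batch, one sample of class 0 is predicted as class 1, so class 0 keeps a precision of 1.0 but gets a recall of 0.5, while class 1 gets a precision of 0.5.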
We have an ongoing PR about nested metrics, #968 (related to issue #959). It should make it much easier to build a clean solution to your question.
HTH
@CDWJustin the absence of this option is also related to the question of how you would define it.
In scikit-learn, for example, there is no such option either. See how accuracy is defined in scikit-learn: https://scikit-learn.org/stable/modules/model_evaluation.html#accuracy-score
However, yes, we can one-hot encode the targets, interpret the multiclass problem as a multilabel one, and use the code from @sdesrozis's answer.
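To make the definitional point concrete, here is a minimal pure-PyTorch sketch; the helper names are hypothetical. It assumes, as I read ignite's implementation, that `Accuracy(is_multilabel=True)` computes exact-match (subset) accuracy, like scikit-learn's `accuracy_score` on multilabel data: a sample counts as correct only if all of its labels match. On one-hot encoded multiclass outputs, that is just the ordinary overall accuracy, which is why a single number comes out. "Per-class accuracy" can then mean either the one-vs-rest accuracy of each column (what the column-wise metrics give on one-hot tensors) or the per-class recall, and the two differ:

```python
import torch
import torch.nn.functional as F


def to_one_hot(scores, y):
    # (N, C) class scores and (N,) integer targets -> two (N, C) one-hot tensors
    num_classes = scores.shape[1]
    return F.one_hot(scores.argmax(dim=1), num_classes), F.one_hot(y, num_classes)


def subset_accuracy(y_pred_oh, y_oh):
    # Exact match: a sample is correct only if its whole label row matches
    return torch.all(y_pred_oh == y_oh, dim=1).float().mean()


def per_class_recall(y_pred_oh, y_oh):
    # Share of the samples of true class c that were predicted as c
    tp = (y_pred_oh * y_oh).sum(dim=0).float()
    support = y_oh.sum(dim=0).float().clamp(min=1)  # avoid 0/0 for absent classes
    return tp / support


scores = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 2.2, 0.4],
                       [0.3, 0.2, 1.9]])
y = torch.tensor([0, 0, 1, 2])

y_pred_oh, y_oh = to_one_hot(scores, y)
exact = subset_accuracy(y_pred_oh, y_oh)
plain = (scores.argmax(dim=1) == y).float().mean()
one_vs_rest = (y_pred_oh == y_oh).float().mean(dim=0)
recall = per_class_recall(y_pred_oh, y_oh)

print(exact.item(), plain.item())  # 0.75 0.75
print(one_vs_rest.tolist())        # [0.75, 0.75, 1.0]
print(recall.tolist())             # [0.5, 1.0, 1.0]
```

Here both per-class definitions agree for class 2 but disagree for classes 0 and 1, because one-vs-rest accuracy also rewards true negatives.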
@CDWJustin PR #968 has just been merged. It could help you write your metric with a specific output_transform and the built-in Accuracy. Let me know if you need help.
Thanks so much! Will try this out!
Thanks so much! Problem solved! I'll close the issue then!