Flair: Confusion Matrix for Multi-class prediction

Created on 12 Feb 2019 · 6 Comments · Source: flairNLP/flair

Hi guys,

I would like to share a concern regarding the prediction results output.
Currently, true negatives are being calculated for multi-class classification, which will produce duplicated counts and artificially increase the overall accuracy.
Please take a look at the following portion of the log:

2019-02-07 22:45:48,581 Testing using best model ...
2019-02-07 22:45:48,664 MICRO_AVG: acc 0.9963 - f1-score 0.9448
2019-02-07 22:45:48,665 MACRO_AVG: acc 0.9963 - f1-score 0.9227
2019-02-07 22:45:48,665 label_1 tp: 5 - fp: 1 - fn: 0 - tn: 139 - precision: 0.8333 - recall: 1.0000 - accuracy: 0.9931 - f1-score: 0.9091
2019-02-07 22:45:48,665 label_2 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,665 label_3 tp: 5 - fp: 1 - fn: 0 - tn: 139 - precision: 0.8333 - recall: 1.0000 - accuracy: 0.9931 - f1-score: 0.9091
2019-02-07 22:45:48,665 label_4 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,665 label_5 tp: 5 - fp: 1 - fn: 0 - tn: 139 - precision: 0.8333 - recall: 1.0000 - accuracy: 0.9931 - f1-score: 0.9091
2019-02-07 22:45:48,665 label_6 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,665 label_7 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_8 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_9 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_10 tp: 5 - fp: 1 - fn: 0 - tn: 139 - precision: 0.8333 - recall: 1.0000 - accuracy: 0.9931 - f1-score: 0.9091
2019-02-07 22:45:48,666 label_11 tp: 4 - fp: 0 - fn: 1 - tn: 140 - precision: 1.0000 - recall: 0.8000 - accuracy: 0.9931 - f1-score: 0.8889
2019-02-07 22:45:48,666 label_12 tp: 4 - fp: 0 - fn: 1 - tn: 140 - precision: 1.0000 - recall: 0.8000 - accuracy: 0.9931 - f1-score: 0.8889
2019-02-07 22:45:48,666 label_13 tp: 4 - fp: 1 - fn: 1 - tn: 139 - precision: 0.8000 - recall: 0.8000 - accuracy: 0.9862 - f1-score: 0.8000
2019-02-07 22:45:48,666 label_14 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_15 tp: 5 - fp: 1 - fn: 0 - tn: 139 - precision: 0.8333 - recall: 1.0000 - accuracy: 0.9931 - f1-score: 0.9091
2019-02-07 22:45:48,666 label_16 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_17 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_18 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,666 label_19 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_20 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_21 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_22 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_23 tp: 4 - fp: 0 - fn: 1 - tn: 140 - precision: 1.0000 - recall: 0.8000 - accuracy: 0.9931 - f1-score: 0.8889
2019-02-07 22:45:48,667 label_24 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_25 tp: 4 - fp: 0 - fn: 1 - tn: 140 - precision: 1.0000 - recall: 0.8000 - accuracy: 0.9931 - f1-score: 0.8889
2019-02-07 22:45:48,667 label_26 tp: 3 - fp: 0 - fn: 2 - tn: 140 - precision: 1.0000 - recall: 0.6000 - accuracy: 0.9862 - f1-score: 0.7500
2019-02-07 22:45:48,667 label_27 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_28 tp: 5 - fp: 0 - fn: 0 - tn: 140 - precision: 1.0000 - recall: 1.0000 - accuracy: 1.0000 - f1-score: 1.0000
2019-02-07 22:45:48,667 label_29 tp: 4 - fp: 0 - fn: 1 - tn: 140 - precision: 1.0000 - recall: 0.8000 - accuracy: 0.9931 - f1-score: 0.8889
2019-02-07 22:45:48,667 label_30 tp: 0 - fp: 2 - fn: 0 - tn: 143 - precision: 0.0000 - recall: 0.0000 - accuracy: 0.9862 - f1-score: 0.0000

The test set contains 145 samples (5 per label), but if we sum up all [tp, fp, fn, tn] values we get 4350 (145 x 30 labels) samples.
So, I think reporting these results in a confusion matrix format would be more appropriate, since it avoids duplicating the sample counts. Something like this:

| | label_1 | label_2 | ... | label_30 |
| --- | --- | --- | --- | --- |
| label_1 | 5 | 0 | ... | 0 |
| label_2 | 0 | 4 | ... | 2 |
| ... | ... | ... | ... | ... |
| label_30 | 0 | 0 | ... | 0 |

Here, rows denote the true labels and columns denote the predicted labels. Moreover, from this table we can compute Precision, Recall, F1 score and Accuracy values.
You can refer to the scikit-learn page about Confusion Matrix for more information.
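
To make the duplication concrete, here is a minimal sketch that re-checks three rows copied from the log above. The dictionary layout and variable names are only illustrative. It shows that every label's four counts sum to the full test set, and what the per-label accuracy looks like with and without `tn` in the ratio.

```python
# Per-label counts (tp, fp, fn, tn) copied from three rows of the log above.
log_rows = {
    "label_1": (5, 1, 0, 139),
    "label_13": (4, 1, 1, 139),
    "label_30": (0, 2, 0, 143),
}
n_samples = 145  # size of the test set
n_labels = 30    # number of labels in the log

results = {}
for label, (tp, fp, fn, tn) in log_rows.items():
    # Every sample is counted once per label, so each row sums to the test-set size.
    assert tp + fp + fn + tn == n_samples
    acc_with_tn = (tp + tn) / (tp + fp + fn + tn)
    # Same ratio with tn left out, for comparison.
    denom = tp + fp + fn
    acc_without_tn = tp / denom if denom > 0 else 0.0
    results[label] = (round(acc_with_tn, 4), round(acc_without_tn, 4))
    print(f"{label}: with tn {acc_with_tn:.4f} - without tn {acc_without_tn:.4f}")

# Summed over all labels, each sample therefore appears n_labels times.
total_counted = n_samples * n_labels
print("total counted:", total_counted)
```

The "with tn" values reproduce the accuracies printed in the log (0.9931, 0.9862, 0.9862), including for label_30, which has no true positives at all.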

Please let me know your thoughts on this.
Thank you!

bug

All 6 comments

Hello @frtacoa thank you for reporting this! This looks like a bug for TN in the multi-label classification evaluation. We'll take a look!

Hello @frtacoa I think the current code might be correct after all, since true positives etc. are always counted per class. So if there are 30 classes, the count across all classes should be data points x 30. This way, we can calculate the accuracy per class and then average over all accuracies.

I've looked into some papers and I see lots of people only report F-score for multi-class classification, so maybe accuracy is not a good measure here?

Hi @alanakbik , thank you for the follow-up.

> I've looked into some papers and I see lots of people only report F-score for multi-class classification, so maybe accuracy is not a good measure here?

In most cases, you would get the same value for accuracy and the micro-averaged f1-score. But the disadvantage of using accuracy is that it doesn't take into account how many samples each class has: if a majority class is much bigger than the other classes and is correctly predicted most of the time, you will get a high overall accuracy regardless of how the smaller classes are predicted.
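
As a rough illustration of this point, the sketch below uses a made-up, deliberately imbalanced toy set (the `major`/`minor` labels and the helper `f1` are hypothetical, not taken from flair or scikit-learn). A model that always predicts the majority class gets a high accuracy and an equally high micro-averaged F1, while the macro-averaged F1 exposes the class that is never predicted.

```python
from collections import Counter

# Hypothetical, deliberately imbalanced toy data: 9 "major" samples, 1 "minor".
y_true = ["major"] * 9 + ["minor"]
y_pred = ["major"] * 10  # the model always predicts the majority class

labels = sorted(set(y_true))
pairs = list(zip(y_true, y_pred))
tp = Counter(t for t, p in pairs if t == p)  # correct predictions, keyed by class
fp = Counter(p for t, p in pairs if t != p)  # wrong predictions, keyed by predicted class
fn = Counter(t for t, p in pairs if t != p)  # wrong predictions, keyed by true class

def f1(tp_c, fp_c, fn_c):
    # F1 = 2*tp / (2*tp + fp + fn); defined here as 0 when the denominator is 0.
    denom = 2 * tp_c + fp_c + fn_c
    return 2 * tp_c / denom if denom > 0 else 0.0

accuracy = sum(tp.values()) / len(y_true)
micro_f1 = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
macro_f1 = sum(f1(tp[c], fp[c], fn[c]) for c in labels) / len(labels)

print(f"accuracy {accuracy:.4f} - micro F1 {micro_f1:.4f} - macro F1 {macro_f1:.4f}")
```

Here accuracy and micro F1 are both 0.9, while macro F1 drops below 0.5 because the `minor` class contributes an F1 of 0.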

Actually, there are no problems from the point of view of your code. Calculating the true/false positives/negatives seems to be working fine. The problem comes when calculating the accuracy, even accuracy per class.
Below, I try to explain what happens when you use and don't use the true negatives in the calculation.

This example is taken from scikit-learn and modified to explain True Positives (TPs), False Positives (FPs), False Negatives (FNs) and True Negatives (TNs) for multi-class prediction, and why we should not include TNs when calculating accuracy.

Suppose we have a list of true labels y_true and a list of predicted labels y_pred:

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]

labels = list(set(y_true))
labels.sort()

print("Total labels: %s -> %s" % (len(labels), labels))
Total labels: 3 -> ['ant', 'bird', 'cat']

We can visualize the prediction details in a confusion matrix as follows:

df = pd.DataFrame(
    data=confusion_matrix(y_true, y_pred, labels=labels),
    columns=labels,
    index=labels
)
df
| | ant | bird | cat |
| --- | --- | --- | --- |
| ant | 2 | 0 | 0 |
| bird | 0 | 0 | 1 |
| cat | 1 | 0 | 2 |

Using this matrix we can compute Precision, Recall, F1-Score and Accuracy, both local (per class) and global.

First, we compute the local metrics:

#
# Local (metrics per class)
#
tps = {}
fps = {}
fns = {}
precision_local = {}
recall_local = {}
f1_local = {}
accuracy_local = {}
for label in labels:
    tps[label] = df.loc[label, label]
    fps[label] = df[label].sum() - tps[label]
    fns[label] = df.loc[label].sum() - tps[label]
    tp, fp, fn = tps[label], fps[label], fns[label]

    precision_local[label] = tp / (tp + fp) if (tp + fp) > 0. else 0.
    recall_local[label] = tp / (tp + fn) if (tp + fn) > 0. else 0.
    p, r = precision_local[label], recall_local[label]

    f1_local[label] = 2. * p * r / (p + r) if (p + r) > 0. else 0.
    accuracy_local[label] = tp / (tp + fp + fn) if (tp + fp + fn) > 0. else 0.

print("#-- Local measures --#")
print("True Positives:", tps)
print("False Positives:", fps)
print("False Negatives:", fns)
print("Precision:", precision_local)
print("Recall:", recall_local)
print("F1-Score:", f1_local)
print("Accuracy:", accuracy_local)
#-- Local measures --#
True Positives: {'ant': 2, 'bird': 0, 'cat': 2}
False Positives: {'ant': 1, 'bird': 0, 'cat': 1}
False Negatives: {'ant': 0, 'bird': 1, 'cat': 1}
Precision: {'ant': 0.6666666666666666, 'bird': 0.0, 'cat': 0.6666666666666666}
Recall: {'ant': 1.0, 'bird': 0.0, 'cat': 0.6666666666666666}
F1-Score: {'ant': 0.8, 'bird': 0.0, 'cat': 0.6666666666666666}
Accuracy: {'ant': 0.6666666666666666, 'bird': 0.0, 'cat': 0.5}

Then, we compute the global measures:

#
# Global
#
micro_averages = {}
macro_averages = {}

correct_predictions = sum(tps.values())
den = sum(list(tps.values()) + list(fps.values()))
micro_averages["Precision"] = 1. * correct_predictions / den if den > 0. else 0.

den = sum(list(tps.values()) + list(fns.values()))
micro_averages["Recall"] = 1. * correct_predictions / den if den > 0. else 0.

micro_avg_p, micro_avg_r = micro_averages["Precision"], micro_averages["Recall"]
micro_averages["F1-score"] = 2. * micro_avg_p * micro_avg_r / (micro_avg_p + micro_avg_r) if (micro_avg_p + micro_avg_r) > 0. else 0.

macro_averages["Precision"] = np.mean(list(precision_local.values()))
macro_averages["Recall"] = np.mean(list(recall_local.values()))

macro_avg_p, macro_avg_r = macro_averages["Precision"], macro_averages["Recall"]
macro_averages["F1-Score"] = 2. * macro_avg_p * macro_avg_r / (macro_avg_p + macro_avg_r) if (macro_avg_p + macro_avg_r) > 0. else 0.

total_predictions = df.values.sum()
accuracy_global = correct_predictions / total_predictions if total_predictions > 0. else 0.

print("#-- Global measures --#")
print("Micro-Averages:", micro_averages)
print("Macro-Averages:", macro_averages)
print("Correct predictions:", correct_predictions)
print("Total predictions:", total_predictions)
print("Accuracy:", accuracy_global)

#-- Global measures --#
Micro-Averages: {'Precision': 0.6666666666666666, 'Recall': 0.6666666666666666, 'F1-score': 0.6666666666666666}
Macro-Averages: {'Precision': 0.4444444444444444, 'Recall': 0.5555555555555555, 'F1-Score': 0.49382716049382713}
Correct predictions: 4
Total predictions: 6
Accuracy: 0.6666666666666666

So far, we have not computed and used TNs values. Now, let's see what happens if we include these:

#
# TN
#
tns = {}
for label in set(y_true):
    tns[label] = len(y_true) - (tps[label] + fps[label] + fns[label])

print("True Negatives:", tns)
True Negatives: {'ant': 3, 'bird': 5, 'cat': 2}

Here, we are assuming that for each class the TNs can be computed as follows: tn_c = N - (tp_c + fp_c + fn_c),
where N is the total number of predictions and c is the target class.

Then, we can compute the accuracy that includes TNs and compare against the previous results:

accuracy_local_new = {}
for label in labels:
    tp, fp, fn, tn = tps[label], fps[label], fns[label], tns[label]
    accuracy_local_new[label] = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0. else 0.

total_true = sum(list(tps.values()) + list(tns.values()))
total_predictions = sum(list(tps.values()) + list(tns.values()) + list(fps.values()) + list(fns.values()))
accuracy_global_new = 1. * total_true / total_predictions if total_predictions > 0. else 0.

print("Accuracy (per class), with TNs:", accuracy_local_new)
print("Accuracy (per class), without TNs:", accuracy_local)
print("Accuracy (global), with TNs:", accuracy_global_new)
print("Accuracy (global), without TNs:", accuracy_global)
Accuracy (per class), with TNs: {'ant': 0.8333333333333334, 'bird': 0.8333333333333334, 'cat': 0.6666666666666666}
Accuracy (per class), without TNs: {'ant': 0.6666666666666666, 'bird': 0.0, 'cat': 0.5}
Accuracy (global), with TNs: 0.7777777777777778
Accuracy (global), without TNs: 0.6666666666666666

So, by looking at these results we can see the effect of including TNs for multi-class model accuracy (especially for the bird class).
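
One way to see why this happens: substituting tn_c = N - (tp_c + fp_c + fn_c) into (tp_c + tn_c) / N gives 1 - (fp_c + fn_c) / N, so the per-class accuracy with TNs depends only on the number of errors relative to the whole test set. The sketch below checks this identity on the counts from the example above. The larger test-set sizes in the second part are hypothetical and only show the trend.

```python
# Per-class counts (tp, fp, fn) from the ant/bird/cat example above.
counts = {"ant": (2, 1, 0), "bird": (0, 0, 1), "cat": (2, 1, 1)}
n = 6  # total number of predictions

acc_with_tn = {}
for label, (tp, fp, fn) in counts.items():
    tn = n - (tp + fp + fn)
    acc_with_tn[label] = (tp + tn) / n
    # Substituting tn gives (n - fp - fn) / n = 1 - (fp + fn) / n.
    assert abs(acc_with_tn[label] - (1 - (fp + fn) / n)) < 1e-12

print(acc_with_tn)

# Hypothetical scaling: keep bird's single error (tp = 0) but grow the test set.
bird_acc_by_n = {n_big: 1 - 1 / n_big for n_big in (6, 145, 10000)}
for n_big, acc in bird_acc_by_n.items():
    print(n_big, round(acc, 4))
```

The bird class is never predicted correctly, yet its accuracy with TNs is already 5/6 for N = 6 and moves towards 1 as N grows.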

If we compute the accuracy in scikit-learn, we can confirm that TNs are not being included in the calculation:

from sklearn.metrics import classification_report, accuracy_score
print("Accuracy (global), using sklearn:", accuracy_score(y_true, y_pred))
Accuracy (global), using sklearn: 0.6666666666666666

We can use the classification_report function to confirm the other values that we computed previously (actually, there is a small discrepancy in the value reported for macro avg f1-score):

print(classification_report(y_true, y_pred, digits=4))
              precision    recall  f1-score   support

         ant     0.6667    1.0000    0.8000         2
        bird     0.0000    0.0000    0.0000         1
         cat     0.6667    0.6667    0.6667         3

   micro avg     0.6667    0.6667    0.6667         6
   macro avg     0.4444    0.5556    0.4889         6
weighted avg     0.5556    0.6667    0.6000         6

The following link provides additional information about the confusion matrix and accuracy.

Sorry for such a long answer. If something is not clear, please let me know.

Kind regards

Hello @frtacoa - first of all, thank you for the explanations! This is very helpful and makes a lot of sense. We will change the accuracy calculation as you suggest, so the numbers do not get artificially inflated with true negatives. Also, getting the same numbers as Scikit-Learn is always good :) Generally, perhaps we should encourage people to use F-score instead which seems more common.

Another note: I think the reason that your F-score is different from the one used by Scikit-Learn is that they compute the macro average F-score as discussed in #486 - that is, they compute each F-score and then get the average. We changed our code to also do it the way that Scikit-Learn is doing.

Again, thank you very much for illustrating this and the many examples!

Hi @alanakbik , thanks for your comments above. You are right, I could confirm that scikit-learn calculates the Macro Avg F1 the way you mentioned.
So, with the code from above, you can calculate it like this:

macro_averages["F1-Score"] = np.mean(list(f1_local.values()))
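
For reference, here is a small self-contained sketch comparing the two variants on the per-class precision and recall values from the example above. The helper name `harmonic_mean` is only illustrative.

```python
# Per-class values taken from the local measures printed earlier in the thread.
precision_local = {"ant": 2 / 3, "bird": 0.0, "cat": 2 / 3}
recall_local = {"ant": 1.0, "bird": 0.0, "cat": 2 / 3}

def harmonic_mean(p, r):
    # F1 of a single precision/recall pair, 0 when both are 0.
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

f1_local = {c: harmonic_mean(precision_local[c], recall_local[c]) for c in precision_local}

# Variant 1: average the per-class F1 scores.
macro_f1_mean_of_f1 = sum(f1_local.values()) / len(f1_local)

# Variant 2: F1 of the macro-averaged precision and recall (the earlier code in this thread).
macro_p = sum(precision_local.values()) / len(precision_local)
macro_r = sum(recall_local.values()) / len(recall_local)
macro_f1_of_averages = harmonic_mean(macro_p, macro_r)

print(f"mean of per-class F1: {macro_f1_mean_of_f1:.4f}")
print(f"F1 of macro P and R:  {macro_f1_of_averages:.4f}")
```

Variant 1 gives 0.4889, the macro avg f1-score in the classification_report output above. Variant 2 gives 0.4938, the value computed by the earlier code.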

This is fixed in 0.4.1 so closing this issue - thanks again @frtacoa!
