Hi!
I've been working on a multiclass problem, but I don't know how to identify the classes in the shap_values matrix. For instance, take the following figure:

The plot shows classes 0, 1 and 2, but my y_test contains 1, 2 and 3 instead. How can I identify the real class in the plot?
Thanks in advance!
Hey @franciscorodriguez92! If you create a list of your class names in the order that they're presented in your label/target vectors, then you can pass that list directly to the plotting function as a keyword argument. For example, I did this recently:
class_names = ["Class A", "Class B", "Class C"]
shap.summary_plot(shap_values, features=features, feature_names=feature_names, class_names=class_names)
The plotting function will then add the class names to the plot's legend. It worked quite nicely for me! You just need to make sure the class names are in the same order as their associated SHAP value arrays in your multiclass SHAP values list.
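The 0, 1, 2 in the default legend are positions, not labels: entry i of the multiclass SHAP values list belongs to the i-th class in the model's class ordering. For most scikit-learn classifiers that ordering is `model.classes_`, which holds the unique training labels in sorted order, so with labels 1, 2 and 3, position 0 is label 1. A minimal sketch, assuming that sorted scikit-learn-style ordering; the helpers `infer_classes` and `class_names_from_labels` and the toy `y_train` are hypothetical:

```python
def infer_classes(y):
    """Mimic scikit-learn's classes_: the sorted unique labels seen in training."""
    return sorted(set(y))


def class_names_from_labels(classes, prefix="Class "):
    """Build display names in the same order as the per-class SHAP arrays."""
    return [f"{prefix}{label}" for label in classes]


# Hypothetical training labels, like the 1/2/3 labels in y_test above.
y_train = [3, 1, 2, 2, 1, 3, 1]
classes = infer_classes(y_train)
class_names = class_names_from_labels(classes)

# Position i in the multiclass SHAP values list corresponds to classes[i].
for i, label in enumerate(classes):
    print(f"shap_values[{i}] -> label {label} ({class_names[i]})")
```

The resulting list can then be passed as `class_names=class_names` to `shap.summary_plot`, so the legend shows "Class 1", "Class 2" and "Class 3" instead of positional indices.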
Hi @GarrettCGraham thanks for answering!
Do you think I could use the classes_ attribute to make sure the class names are in the same order as their associated SHAP values? For example:
shap.summary_plot(shap_values, features=features, feature_names=feature_names, class_names=model.classes_)
Thanks for helping!
Sure thang, @franciscorodriguez92! I'm not sure which ML library you're using, but that looks like it'd work just fine!
I'd just check your plot afterwards to make certain that at least a few things intuitively make sense. If you notice that prominent features aren't explaining the classes you'd expect them to explain, that might mean your class labels were out of order.
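One way to make that sanity check concrete is to list, for each class, the features with the largest mean absolute SHAP value and confirm they match what you expect from the domain. A minimal pure-Python sketch, assuming the list-of-arrays multiclass format (one n_samples x n_features array per class, in the same order as `class_names`); the `top_features_per_class` helper, the feature names and the toy values are all hypothetical. It only iterates rows and indexes columns, so it should also accept NumPy arrays.

```python
def top_features_per_class(shap_values, feature_names, class_names, k=2):
    """For each class, return the k features with the largest mean |SHAP|.

    shap_values: list with one (n_samples x n_features) array per class,
    ordered the same way as class_names.
    """
    result = {}
    for name, values in zip(class_names, shap_values):
        n_rows = len(values)
        # Mean absolute SHAP value of each feature for this class.
        importance = [
            sum(abs(row[j]) for row in values) / n_rows
            for j in range(len(feature_names))
        ]
        ranked = sorted(range(len(feature_names)),
                        key=lambda j: importance[j], reverse=True)
        result[name] = [feature_names[j] for j in ranked[:k]]
    return result


# Hypothetical toy SHAP values: 3 classes x 2 samples x 3 features.
shap_values = [
    [[0.9, 0.1, 0.0], [0.7, -0.2, 0.1]],    # label 1
    [[0.0, -0.8, 0.2], [0.1, 0.6, 0.0]],    # label 2
    [[0.1, 0.0, 0.5], [-0.2, 0.1, -0.7]],   # label 3
]
feature_names = ["age", "income", "tenure"]
class_names = ["Class 1", "Class 2", "Class 3"]

for name, feats in top_features_per_class(shap_values, feature_names, class_names).items():
    print(name, feats)
```

If a ranking contradicts your expectations, for example the feature you expect to drive label 1 tops label 2 instead, the class names are likely misaligned. Note that recent SHAP releases may return a single 3-D array of shape (n_samples, n_features, n_classes) for multiclass models rather than a list; in that case the per-class slices would be `values[:, :, k]`.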
Perfect! Thanks so much for your help!