Shap: support custom colors in plots

Created on 20 Apr 2018 · 21 Comments · Source: slundberg/shap

Red and blue are nice, but in business work it's generally a requirement to be able to change the colors (including to match a specific client's palette).

This definitely applies to the summary_plot, and may apply to others. Easy fixes, if anyone's interested.

todo

All 21 comments

Custom colors are already available in force plots (#37). For summary plots there's a color argument, but it doesn't appear to do much. In the code it looks like it colors the scatter points in summary plots, but I tinkered with some values and the plot didn't change. This could be a great addition, since in business work plots are often colored with the business's colors in mind.

Great point. The summary plots need to be updated to support other color scales.

Aside from custom colors/maps, it was brought up in #58 that supporting these per feature might be a requirement. Easy to do - just nailing down the API is important.

I was going to open a new issue related to this, but I'll just add it in a comment. I see that force_plot accepts the arg plot_cmap, which can be used to change the colors. My suggestion is to put that in the shap docs as an option, or since it's implemented in iml, document it there, then put a link to the details in the shap readme.

I've shown the shap plots to a few people and they were a bit confused because the default magenta in summary_plot (meaning high value for the actual feature) is the same as the magenta bars in the single-instance force_plot (meaning positive shap value for that feature). They definitely look pretty but maybe the force plot should use a different default. :)

@vaughnkoch Thanks for the suggestion! I have been slowly working on getting read-the-docs working for shap, so I'll try and work this into that.

Here is another example where ability to change color map is useful: summary-plot for multi-class classification. Currently, it is hard to differentiate the different shades of blue.

[image: summary-plot-multiclass]

Yeah it would be good to have a plot_cmap option for the summary plot as well. I can't promise when I can get to that though.

I hope it helps

import matplotlib.pyplot as pl

shap.summary_plot(shap_values, X, plot_type="dot",color=pl.get_cmap("tab10"))
[image]

From here you can choose the color palette you want:
https://matplotlib.org/examples/color/colormaps_reference.html
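
When picking from that list, it can help to check how many distinct colors a colormap actually holds. A minimal sketch, assuming only matplotlib is installed; nothing here is shap-specific, and the two colormap names are just examples from the matplotlib reference:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the snippet runs without a display
import matplotlib.pyplot as plt

# Qualitative maps such as tab10 hold a fixed list of distinct colors,
# which suits one-color-per-class plots.
tab10 = plt.get_cmap("tab10")
n_tab10 = tab10.N

# Sequential maps such as viridis are sampled into a smooth gradient.
viridis = plt.get_cmap("viridis")
n_viridis = viridis.N

print(n_tab10, n_viridis)
```

A qualitative map with few entries is probably the better fit for the multi-class bar plot, while a sequential map fits the dot plot's feature-value scale.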

@ferdous150439 Yes, this was added in the 0.26.0 release, I believe.

Here's an example if you want to specify a color per class:

from matplotlib import colors as plt_colors
import numpy as np
import shap

# class names
classes = ['r', 'g', 'b']

# set RGB tuple per class
colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

# get class ordering from shap values
class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])

# create listed colormap
cmap = plt_colors.ListedColormap(np.array(colors)[class_inds])

# plot
shap.summary_plot(shap_values, features, feature_names, color=cmap, class_names=classes)

I suggest allowing users to use a dictionary like {class: color} or correctly ordered list of colors as inputs to the color parameter.
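
Until something like that exists, the dictionary idea can be approximated on the user side. This is only a sketch: cmap_from_class_colors is a hypothetical helper, not part of shap, and it assumes (as the example above does) that the summary plot assigns colormap entries in order of decreasing mean |SHAP| per class. The SHAP arrays below are synthetic stand-ins.

```python
import numpy as np
from matplotlib.colors import ListedColormap, to_rgb


def cmap_from_class_colors(shap_values, class_names, class_colors):
    """Hypothetical helper: build a ListedColormap from a {class: color} dict.

    Classes are ordered by mean |SHAP| (largest first), mirroring the
    argsort in the example above, and the colors are reordered to match.
    """
    order = np.argsort([-np.abs(sv).mean() for sv in shap_values])
    return ListedColormap([to_rgb(class_colors[class_names[i]]) for i in order])


# Synthetic stand-in for a list of per-class SHAP arrays (assumption).
fake_shap_values = [np.full((4, 2), 0.1), np.full((4, 2), 0.5), np.full((4, 2), 0.3)]

cmap = cmap_from_class_colors(
    fake_shap_values,
    class_names=["r", "g", "b"],
    class_colors={"r": "red", "g": "green", "b": "blue"},
)
```

The resulting cmap could then be passed as color=cmap together with class_names, as in the example above.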

@JamesTownend sounds like it would be a nice feature! PRs are welcome :)

Hi guys, this is still not working for me. Any ideas?

This is the dataset I am playing with: https://www.kaggle.com/dansbecker/hospital-readmissions
I would love to change the colors of the plot.

from matplotlib import colors as plt_colors
# class names
classes = ['a', 'b']

# set RGB tuple per class
colors = [(200, 200, 0), (200, 200, 0)]

# get class ordering from shap values
class_inds = np.argsort([-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])

# create listed colormap
cmap = plt_colors.ListedColormap(np.array(colors)[class_inds])
shap.summary_plot(shap_values[1], small_val_X,color=cmap, class_names=classes)

[image]

@ra161 I see that your list of two colors uses the same RGB tuple twice:
colors = [(200, 200, 0), (200, 200, 0)]
Maybe this is the problem? Try using two different colors.
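
Separately from the duplicated tuple, matplotlib expects RGB components as floats between 0 and 1, so values like 200 need rescaling before they go into a ListedColormap. A small sketch; the second color is made up for illustration and is not from the thread:

```python
from matplotlib.colors import ListedColormap

# Colors given on the 0-255 scale; the second one is an illustrative assumption.
rgb_255 = [(200, 200, 0), (0, 120, 200)]

# Rescale each channel to the 0-1 range that matplotlib expects.
rgb_unit = [tuple(channel / 255 for channel in rgb) for rgb in rgb_255]

cmap = ListedColormap(rgb_unit)
first_rgba = cmap(0)  # RGBA tuple for the first listed color
```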

@ibuda Yeah, I just saw that too. The issue is that the colors don't seem to change no matter which cmap I use:

cmap = plt.get_cmap('hot')
shap.summary_plot(shap_values[1], small_val_X,color=cmap)

I'm only trying to plot this for one class, hence slicing into shap_values[1]. I would love a way to get the colors changing with a cmap.

I have been trying to create my own two-color gradient colormap as follows, but nothing has changed:

from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import colors as plt_colors
import matplotlib

RGB_val = 255

color01= (0/RGB_val,150/RGB_val,200/RGB_val)  # Blue wanted
color04= (220/RGB_val,60/RGB_val,60/RGB_val)  # red wanted
Colors = [color01, color04]
CustomCmap = matplotlib.colors.ListedColormap(Colors, name="MyColors")
# Trials
shap.summary_plot(shap_values_XGB_train, X_train, color=pl.get_cmap("MyColors"))
shap.summary_plot(shap_values_XGB_train, X_train, color=CustomCmap)

Any ideas?
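
Two things may be worth checking in the snippet above. pl.get_cmap("MyColors") looks a name up among registered colormaps, and creating a ListedColormap does not register it, so passing the object itself is the safer route. Also, a ListedColormap with two entries gives two flat colors rather than a gradient; LinearSegmentedColormap.from_list interpolates between them. A sketch of the latter using the colors from the post (whether summary_plot honors the result depends on the shap version, which is not verified here):

```python
from matplotlib.colors import LinearSegmentedColormap

RGB_val = 255
color01 = (0 / RGB_val, 150 / RGB_val, 200 / RGB_val)  # blue wanted
color04 = (220 / RGB_val, 60 / RGB_val, 60 / RGB_val)  # red wanted

# Interpolate linearly from blue (low end) to red (high end).
custom_cmap = LinearSegmentedColormap.from_list("MyColors", [color01, color04])

low = custom_cmap(0.0)   # RGBA at the bottom of the scale
mid = custom_cmap(0.5)   # blended color in the middle
high = custom_cmap(1.0)  # RGBA at the top of the scale
```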

I've been needing a custom colormap for shap.summary_plot() for a while now and came up with this workaround solution using the set_cmap() function of figure's artists:

import shap
import numpy as np
import matplotlib.pyplot as plt

# Define colormap
my_cmap = plt.get_cmap('viridis')

# Plot the summary without showing it
plt.figure()
shap.summary_plot(np.array([[-1., 0., 1.]]).T,
                  features=np.array([[-1., 0., 1.]]).T,
                  feature_names=['Feature1'],
                  show=False
                  )

# Change the colormap of the artists
for fc in plt.gcf().get_children():
    for fcc in fc.get_children():
        if hasattr(fcc, "set_cmap"):
            fcc.set_cmap(my_cmap)

[image]
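
The same loop can be wrapped in a small function so it is easy to reuse. This is a sketch only: recolor_figure is a hypothetical name, it walks fig.axes rather than fig.get_children(), and a plain matplotlib scatter stands in for the summary plot so the snippet runs without shap.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the snippet runs without a display
import matplotlib.pyplot as plt
import numpy as np


def recolor_figure(fig, cmap):
    """Hypothetical helper: set `cmap` on every color-mapped artist in `fig`.

    Returns the number of artists that were changed.
    """
    changed = 0
    for ax in fig.axes:
        for artist in ax.get_children():
            if hasattr(artist, "set_cmap"):
                artist.set_cmap(cmap)
                changed += 1
    return changed


# Stand-in for a summary plot: a scatter colored by value (assumption).
fig, ax = plt.subplots()
values = np.linspace(-1.0, 1.0, 20)
points = ax.scatter(values, np.zeros_like(values), c=values, cmap="coolwarm")

n_changed = recolor_figure(fig, plt.get_cmap("viridis"))
```

With shap, the call would presumably be recolor_figure(plt.gcf(), my_cmap) right after shap.summary_plot(..., show=False).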

Still not working for me:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
from sklearn.model_selection import train_test_split
import xgboost as xgb
import shap

# import some data to play with
iris = datasets.load_iris()
Y = pd.DataFrame(iris.target, columns = ["Species"])
X = pd.DataFrame(iris.data, columns = iris.feature_names)


X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0, stratify=Y)

params = { # General Parameters
            'booster': 'gbtree',
            # Param for boosting
             'eta': 0.2, 
             'gamma': 1,
             'max_depth': 5,
             'min_child_weight': 5,
             'subsample': 0.5,
             'colsample_bynode': 0.5,             
             'lambda': 0,  #default = 0                                        
             'alpha': 1,    #default = 1            
            # Command line parameters
             'num_rounds': 10000,
            # Learning Task Parameters
             'objective': 'multi:softprob' #'multi:softprob'
             }


model = xgb.XGBClassifier(**params, verbose=0, cv=5 , )
# fitting the model
model.fit(X_train,np.ravel(Y_train), eval_set=[(X_test, np.ravel(Y_test))], early_stopping_rounds=20)
# Tree on XGBoost
explainerXGB = shap.TreeExplainer(model, data=X, model_output ="margin")
# Recall that one can pass "probability" instead; then we explain the output of the model transformed
# into probability space (note that this means the SHAP values now sum to the probability output of the model).
shap_values_XGB_test = explainerXGB.shap_values(X_test)
shap_values_XGB_train = explainerXGB.shap_values(X_train)

import matplotlib.pyplot as plt
shap.summary_plot(shap_values_XGB_train, X_train, show = False)#color=cmap

plt.figure()
my_cmap = plt.get_cmap('viridis')

# Change the colormap of the artists
for fc in plt.gcf().get_children():
    for fcc in fc.get_children():
        if hasattr(fcc, "set_cmap"):
            fcc.set_cmap(my_cmap)

Hm, by placing plt.figure() BEFORE the shap.summary_plot(), I get this:
[image]

Any recent update on this?

Placing plt.figure() before shap.summary_plot(), as suggested above, worked for me. Thank you!
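
A likely reason the order matters: plt.figure() opens a new, empty figure and makes it current, so a plt.gcf() call afterwards no longer returns the figure that was drawn on. A sketch with a plain matplotlib scatter standing in for shap.summary_plot(..., show=False), assuming the summary plot draws on the current figure:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the snippet runs without a display
import matplotlib.pyplot as plt

# Order that fails: draw first, then call plt.figure().
first = plt.figure()
plt.scatter([0, 1, 2], [0, 1, 2], c=[0, 1, 2])  # stand-in for the summary plot
second = plt.figure()  # new, empty figure becomes the current one

current_is_second = plt.gcf() is second
axes_in_current = len(plt.gcf().axes)  # nothing to recolor here
axes_in_first = len(first.axes)        # the plot lives on the earlier figure
```

Looping over plt.gcf().get_children() in this state finds no color-mapped artists, which would match the unchanged colors reported above.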
