I have a task where I first compute SHAP values, then plot subsets of rows with summary_plot after clustering. What's happening is that one of the subsets has features with all NaNs and I get this error:
<ipython-input-65-b77b1b60a575> in visualize_clusters_shap(shap_vals, X, feat_cols, cluster_labels, file_name_label, top_k)
79 x = shap_vals_sub[cluster_labels == i, :]
80 y = Xsub[cluster_labels == i, :]
---> 81 shap.summary_plot(x, y, feature_names=feat_cols_sub, max_display=50, show=False, color_bar=False, sort=False)
82 plt.title('Cluster #%s' % i)
83 plt.xlim([-v, v])
~/anaconda3/lib/python3.6/site-packages/shap/plots/summary.py in summary_plot(shap_values, features, feature_names, max_display, plot_type, color, axis_color, title, alpha, show, sort, color_bar, auto_size_plot, layered_violin_max_num_bins, class_names)
192 cmap=colors.red_blue, vmin=vmin, vmax=vmax, s=16,
193 c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
--> 194 zorder=3, rasterized=len(shaps) > 500)
195 else:
196
~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, data, **kwargs)
2862 vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
2863 verts=verts, edgecolors=edgecolors, **({"data": data} if data
-> 2864 is not None else {}), **kwargs)
2865 sci(__ret)
2866 return __ret
~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
1803 "the Matplotlib list!)" % (label_namer, func.__name__),
1804 RuntimeWarning, stacklevel=2)
-> 1805 return func(ax, *args, **kwargs)
1806
1807 inner.__doc__ = _add_data_doc(inner.__doc__,
~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
4193 isinstance(c, str) or
4194 (isinstance(c, collections.Iterable) and
-> 4195 isinstance(c[0], str))):
4196 c_array = None
4197 else:
IndexError: index 0 is out of bounds for axis 0 with size 0
Any ideas on whether this can be fixed? I realize that a column of all NaNs is not a standard use case, but I hope it can still be accommodated.
By the way, thanks for the excellent package and all your hard work! It's very much appreciated.
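Until the underlying error is resolved, one possible workaround is to drop any feature column that is NaN in every row of a subset before calling summary_plot. Below is a minimal numpy-only sketch; the helper name drop_all_nan_features and the feature names are hypothetical, not part of shap. The data is the two-row example from later in this thread.

```python
import numpy as np


def drop_all_nan_features(shap_values, features, feature_names):
    """Hypothetical helper (not part of shap): drop feature columns that
    are NaN in every row.

    Returns filtered (shap_values, features, feature_names) so that every
    remaining column has at least one non-NaN value to color the plot with.
    """
    shap_values = np.asarray(shap_values)
    features = np.asarray(features, dtype=float)
    # keep[j] is True when column j has at least one non-NaN value
    keep = ~np.all(np.isnan(features), axis=0)
    kept_names = [name for name, k in zip(feature_names, keep) if k]
    return shap_values[:, keep], features[:, keep], kept_names


# The two-row example: the second feature is entirely NaN.
shap_vals = np.array([[0.46430412, 0.1721865],
                      [0.4520666, 0.17486936]])
X = np.array([[9.0, np.nan],
              [np.nan, np.nan]])
sv, Xf, names = drop_all_nan_features(shap_vals, X, ["feature_0", "feature_1"])
print(names)  # ['feature_0']
```

The filtered arrays would then be passed on as shap.summary_plot(sv, Xf, feature_names=names). In the per-cluster loop from the traceback, that means filtering x, y and feat_cols_sub for each cluster before plotting. The trade-off is that the all-NaN feature disappears from that cluster's plot, so with sort=False the rows may no longer line up across clusters.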
Thanks! This seems like a good thing to fix, but I can't reproduce it. Could you give a small example? The following works for me (with a warning):
import xgboost
import shap
import numpy as np
# Train a small XGBoost model and compute its SHAP values
X, y = shap.datasets.boston()
model = xgboost.XGBRegressor(n_estimators=20, max_depth=10)
model.fit(X, y)
e = shap.TreeExplainer(model, X)
shap_values = e.shap_values(X)
# Make the first feature NaN in every row, then plot
X.iloc[:, 0] = np.nan
shap.summary_plot(shap_values, X)
Thanks for getting back to me on this. The following breaks for me:
import shap
import numpy as np
# SHAP values for two rows and two features
shap_vals = np.array([[0.46430412, 0.1721865],
                      [0.4520666, 0.17486936]])
# The second feature is NaN in every row
X = np.array([[9., np.nan],
              [np.nan, np.nan]])
shap.summary_plot(shap_vals, X)
Interesting. Could you try the master version? It does not break for me; it just gives warnings.
I just cloned the repository and installed it via python setup.py install. It installed v0.26.0 without issue, but I still get the error when I run my short snippet above. Maybe it has to do with the matplotlib version? The trace below suggests so, since the error is raised inside matplotlib's scatter.
I have matplotlib v3.0.1, which conda won't upgrade any further.
Here is the full trace:
IndexError Traceback (most recent call last)
<ipython-input-1-6602b60aad66> in <module>
8 [np.nan, np.nan]])
9
---> 10 shap.summary_plot(shap_vals, X)
~/anaconda3/lib/python3.6/site-packages/shap-0.26.0-py3.6-linux-x86_64.egg/shap/plots/summary.py in summary_plot(shap_values, features, feature_names, max_display, plot_type, color, axis_color, title, alpha, show, sort, color_bar, auto_size_plot, layered_violin_max_num_bins, class_names)
192 cmap=colors.red_blue, vmin=vmin, vmax=vmax, s=16,
193 c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
--> 194 zorder=3, rasterized=len(shaps) > 500)
195 else:
196
~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, data, **kwargs)
2862 vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
2863 verts=verts, edgecolors=edgecolors, **({"data": data} if data
-> 2864 is not None else {}), **kwargs)
2865 sci(__ret)
2866 return __ret
~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
1803 "the Matplotlib list!)" % (label_namer, func.__name__),
1804 RuntimeWarning, stacklevel=2)
-> 1805 return func(ax, *args, **kwargs)
1806
1807 inner.__doc__ = _add_data_doc(inner.__doc__,
~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
4193 isinstance(c, str) or
4194 (isinstance(c, collections.Iterable) and
-> 4195 isinstance(c[0], str))):
4196 c_array = None
4197 else:
IndexError: index 0 is out of bounds for axis 0 with size 0
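Both tracebacks end at the same place, and a small numpy-only sketch shows why the all-NaN feature reaches it. The color array passed to scatter is built as values[np.invert(nan_mask)] (line 193 of summary.py in the trace). Assuming values is that feature's column of X, the array is empty for an all-NaN column, and the matplotlib 3.0.1 code in the last frame evaluates c[0] on it. The helper non_nan_color_values is hypothetical and only mimics that one expression, not the rest of summary_plot.

```python
import numpy as np

# The example feature matrix: feature 1 is NaN in every row.
X = np.array([[9.0, np.nan],
              [np.nan, np.nan]])


def non_nan_color_values(features, i):
    """Hypothetical stand-in for the color array built for feature i, i.e.
    the `c=values[np.invert(nan_mask)]` argument seen in the traceback.
    Assumes `values` is column i of the feature matrix."""
    values = features[:, i]
    nan_mask = np.isnan(values)
    return values[np.invert(nan_mask)]


for i in range(X.shape[1]):
    c = non_nan_color_values(X, i)
    try:
        # matplotlib 3.0.1 evaluated c[0] here (last frame of the trace)
        c[0]
        print(f"feature {i}: {c.size} colored point(s), c[0] is fine")
    except IndexError as err:
        print(f"feature {i}: IndexError: {err}")
```

For feature 1 this raises the same IndexError as in the trace. Newer matplotlib releases apparently avoid this check on an empty array, which would be consistent with the reports below that upgrading matplotlib made the error go away.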
My matplotlib was 2.2.2 so I upgraded to 3.0.2 but still I don't see the error. I would be surprised if 3.0.1 to 3.0.2 fixed it.
Oh boy, not only did upgrading not fix it, I now get Segmentation fault (core dumped)!
In any case, this looks like a problem with my environment rather than with your package. Thanks for helping me debug. I'll close this.
I would be surprised if 3.0.1 to 3.0.2 fixed it.
In my case, this is exactly what did the job.
I had the same problem and fixed it by upgrading matplotlib from 3.0.1 to 3.1.1.
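Since whether the error appears seems to depend on the installed matplotlib (3.0.1 failed for the reporter, while 3.0.2 and 3.1.1 worked for others), it helps to report exact package versions when comparing environments. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); the helper name installed_versions is hypothetical. On Python 3.6, as in the traces above, matplotlib.__version__ and shap.__version__ give the same information.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("shap", "matplotlib", "numpy")):
    """Hypothetical helper: map each package name to its installed
    version string, or None if it is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for pkg, ver in installed_versions().items():
    print(f"{pkg}: {ver or 'not installed'}")
```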
Segmentation fault (core dumped)!
I was getting a similar error when importing shap.
I fixed it by rolling shap back to a previous version:
pip install shap==0.34.0