shap解释CNN解读

background = x_train[np.random.choice(x_train.shape[0], 1000, replace=False)]
e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(x_test[test_index:test_index+1])
shap.image_plot(shap_values, -x_test[test_index:test_index+1])

注:x_train是形状为(27301,12,3)的array,x_train[np.random.choice(x_train.shape[0], 1000, replace=False)]是随机抽sample size = 1000的样本,pd.DataFrame不能按照row_index进行select,df[[]]的方式是按列进行select

background是通过抽样的方式确定base_value。
shap.image_plot画出来的图如下:
左边是原图,原始值取负号,所以越白的地方值越小,越黑的地方值越大。

右边是解释图,每个像素点的shap value,图像上shap value的分布是case by case的,但是也有一些规律可循,可以把每个像素点的shap value绝对值求平均,得到每个像素点的total贡献程度,然后看哪个区域的值高,那么对应区域就比较重要。












留言

熱門文章