SHAP plotting functions throw exceptions

The sample zipped folder (samples.zip) contains:

  1. model.pkl
  2. x_test.csv

To reproduce the problem, follow these steps:

  1. Load the linear regression model with lin2 = joblib.load('model.pkl')
  2. Load the test data with x_test_2 = pd.read_csv('x_test.csv').drop(['Unnamed: 0'], axis=1)
  3. Run the code below to build the explainer and compute the SHAP values:

explainer_test = shap.Explainer(lin2.predict, x_test_2)
shap_values_test = explainer_test(x_test_2)

  4. Then run partial_dependence_plot to see the error message:

ValueError: x and y can be no greater than 2-D, but have shapes (2,) and (2, 1, 1)

sample_ind = 3
shap.partial_dependence_plot(
    "new_personal_projection_delta",
    lin2.predict,
    x_test_2,
    model_expected_value=True,
    feature_expected_value=True,
    ice=False,
    shap_values=shap_values_test[sample_ind:sample_ind+1, :]
)
  5. Run the waterfall plot to see another error message:

Exception: waterfall_plot requires a scalar base_values of the model output as the first parameter, but you have passed an array as the first parameter! Try shap.waterfall_plot(explainer.base_values[0], values[0], X[0]) or for multi-output models try shap.waterfall_plot(explainer.base_values[0], values[0][0], X[0]).

shap.plots.waterfall(shap_values_test[sample_ind], max_display=14)

Questions:

  1. Why can't I run partial_dependence_plot and shap.plots.waterfall?
  2. What changes do I need to make to my input so that these methods run?


Solution 1:[1]

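You need to construct the Explanation object in the form the new SHAP plotting API expects. In particular, the plots need a single scalar base value rather than an array of base values.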
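The waterfall error says an array was passed where a scalar base value was expected, so printing the shapes of the Explanation's arrays is a quick first check. Below is a minimal sketch. It assumes only that the object has `values`, `base_values` and `data` attributes, as `shap.Explanation` objects do. The helper name `explanation_shapes` and the fake shapes are hypothetical and for illustration only.

```python
import numpy as np
from types import SimpleNamespace


def explanation_shapes(exp):
    """Return the array shapes of an Explanation-like object.

    Hypothetical helper: it only relies on the `values`, `base_values`
    and `data` attributes, which shap.Explanation objects expose.
    """
    return {
        name: np.shape(getattr(exp, name))
        for name in ("values", "base_values", "data")
    }


if __name__ == "__main__":
    # Stand-in object with made-up shapes (5 rows, 3 features) whose
    # base values carry an extra trailing axis.
    fake = SimpleNamespace(
        values=np.zeros((5, 3)),
        base_values=np.zeros((5, 1)),
        data=np.zeros((5, 3)),
    )
    print(explanation_shapes(fake))
```

On the real data, `explanation_shapes(shap_values_test)` would show whether `base_values` carries an extra dimension. After indexing a single row, the base value has to be a scalar for `shap.plots.waterfall` to accept it.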

The following will do:

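import warnings

import joblib
import pandas as pd
import shap

warnings.filterwarnings("ignore")

model = joblib.load('model.pkl')
# drop the index column that was saved into the CSV
data = pd.read_csv('x_test.csv').drop(['Unnamed: 0'], axis=1)
explainer = shap.Explainer(model.predict, data)
sv = explainer(data)

idx = 3
# rebuild the Explanation with one scalar base value; keep the column names for plot labels
exp = shap.Explanation(sv.values, sv.base_values[0][0], sv.data, feature_names=sv.feature_names)
shap.plots.waterfall(exp[idx])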

[Image: the resulting waterfall plot]
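The key step above is `sv.base_values[0][0]`. It takes the first row's base value and uses that one scalar for every row. This is only safe when all rows share the same expected value, which is normally the case when every row is explained against the same background data.

Below is a hedged sketch of a more defensive version of that step. The helper `scalar_base_value` is hypothetical and is not part of SHAP.

```python
import numpy as np


def scalar_base_value(base_values, atol=1e-8):
    """Collapse per-row base values into one float.

    Hypothetical helper mirroring `sv.base_values[0][0]`, except that it
    first checks that every row really has the same base value.
    """
    flat = np.ravel(np.asarray(base_values, dtype=float))
    if flat.size == 0:
        raise ValueError("base_values is empty")
    if not np.allclose(flat, flat[0], atol=atol):
        raise ValueError("rows have different base values; one scalar would be wrong")
    return float(flat[0])
```

With the answer's variables it would be used as `shap.Explanation(sv.values, scalar_base_value(sv.base_values), sv.data)`. If the check fails, the rows do not share one expected value, and a single scalar would misplace the waterfall's starting point.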

# pass the rebuilt Explanation (scalar base value) as shap_values
shap.partial_dependence_plot(
    "x7",
    model.predict,
    data,
    model_expected_value=True,
    feature_expected_value=True,
    ice=False,
    shap_values=exp
)

[Image: the resulting partial dependence plot]
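The answer above does not use the following alternative, and it rests on an assumption about the pickled model. Suppose `model.predict` returns a 2-D column of shape `(n, 1)`. scikit-learn's `LinearRegression` does this when it is fitted on a one-column DataFrame target. SHAP would then plausibly treat the model as having a single output, which could explain the extra dimension in `base_values`. Wrapping the predictor so that it returns a 1-D array would avoid rebuilding the Explanation afterwards.

Below is a minimal sketch. The `flat_predict` wrapper and the `ColumnModel` stand-in are hypothetical.

```python
import numpy as np


def flat_predict(predict):
    """Wrap a predict function so a single-column (n, 1) output becomes (n,).

    Outputs with more than one column are returned unchanged, so genuine
    multi-output models are not mangled.
    """
    def wrapped(X):
        out = np.asarray(predict(X))
        if out.ndim == 2 and out.shape[1] == 1:
            out = out[:, 0]  # drop the single-column axis
        return out
    return wrapped


class ColumnModel:
    """Hypothetical stand-in for a regressor whose predict returns (n, 1)."""

    def __init__(self, coef):
        self.coef = np.asarray(coef, dtype=float)

    def predict(self, X):
        # multiplying by a (n_features, 1) column gives an (n, 1) result
        return np.asarray(X, dtype=float) @ self.coef.reshape(-1, 1)


if __name__ == "__main__":
    model = ColumnModel([1.0, 2.0])
    X = np.array([[1.0, 1.0], [0.0, 3.0]])
    print(model.predict(X).shape)                # (2, 1)
    print(flat_predict(model.predict)(X).shape)  # (2,)
```

With the real model, this would be used as `explainer = shap.Explainer(flat_predict(model.predict), data)`. After that, `sv[idx]` should carry a scalar base value. This is untested against the sample files.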

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1