From sciagent-skills
Explains ML model predictions using SHAP with Tree/Deep/Linear/Kernel explainers and plots (beeswarm, waterfall, force). Use for feature attribution, debugging, fairness audits, model comparison.
Slash command: /sciagent-skills:shap-model-explainability
SHAP (SHapley Additive exPlanations) is a unified framework for explaining machine learning model predictions using Shapley values from cooperative game theory. It quantifies each feature's contribution to individual predictions and provides both local (per-instance) and global (dataset-level) explanations with theoretical guarantees of consistency and additivity.
pip install shap matplotlib
# Optional: xgboost lightgbm tensorflow torch (depending on model)
import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split
# Load example data
X, y = shap.datasets.adult()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Train model
model = xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)
# Explain: select explainer → compute → visualize
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.beeswarm(shap_values) # Global importance
shap.plots.waterfall(shap_values[0]) # Single prediction
print(f"Base value: {shap_values.base_values[0]:.3f}")
print(f"SHAP values shape: {shap_values.values.shape}") # (n_samples, n_features)
Choose based on model type:
| Model Type | Explainer | Speed | Exactness |
|---|---|---|---|
| Tree-based (XGBoost, LightGBM, RF, CatBoost) | TreeExplainer | Fast | Exact |
| Linear (LogReg, GLM, Ridge) | LinearExplainer | Instant | Exact |
| Deep learning (TensorFlow, PyTorch) | DeepExplainer | Fast | Approximate |
| Deep learning (gradient-based) | GradientExplainer | Fast | Approximate |
| Any model (black-box) | KernelExplainer | Slow | Approximate |
| Any model (permutation-based) | PermutationExplainer | Slow | Approximate (additivity preserved) |
| Unsure? | shap.Explainer | Auto | Auto |
# Tree-based models (most common)
explainer = shap.TreeExplainer(model)
# Linear models
explainer = shap.LinearExplainer(model, X_train)
# Deep learning
explainer = shap.DeepExplainer(model, X_train[:100])
# Any model (model-agnostic, slower)
explainer = shap.KernelExplainer(model.predict, shap.kmeans(X_train, 50))
# Auto-select
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test)
# shap_values object contains:
# .values — SHAP values array (n_samples, n_features)
# .base_values — Expected model output (baseline)
# .data — Original feature values
# Verify additivity: prediction = base_value + sum(SHAP values)
print(f" {shap_values.base_values[0]:.3f} + {shap_values.values[0].sum():.3f} = "
      f"{shap_values.base_values[0] + shap_values.values[0].sum():.3f}")
# Beeswarm: feature importance + value distributions (most informative)
shap.plots.beeswarm(shap_values, max_display=15)
# Bar: clean mean |SHAP| importance
shap.plots.bar(shap_values)
# Waterfall: detailed breakdown of one prediction
shap.plots.waterfall(shap_values[0])
# Force: additive force visualization
shap.plots.force(shap_values[0])
# Scatter: how a feature affects predictions
shap.plots.scatter(shap_values[:, "Age"])
# Colored by interaction feature
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education-Num"])
# Heatmap: multi-sample SHAP grid
shap.plots.heatmap(shap_values[:100])
# Decision plot: cumulative SHAP paths
shap.plots.decision(shap_values.base_values[0], shap_values.values[:10],
                    feature_names=X_test.columns.tolist())
# Cohort comparison
import numpy as np
# A NumPy boolean mask avoids relying on pandas index alignment when slicing
mask_a = (X_test["Age"] < 40).to_numpy()
shap.plots.bar({
    "Under 40": shap_values[mask_a],
    "40+": shap_values[~mask_a]
})
| Parameter | Explainer/Function | Default | Effect |
|---|---|---|---|
| feature_perturbation | TreeExplainer | "tree_path_dependent" when no background data is passed | "interventional" for causal interpretation (requires background data) |
| model_output | TreeExplainer | "raw" | "probability" to explain probabilities instead of log-odds (requires background data) |
| data (background) | KernelExplainer, DeepExplainer | Required | 100-1000 representative samples; use shap.kmeans(X, 50) for efficiency |
| nsamples | KernelExplainer | "auto" | Higher = more accurate but slower; minimum 2×features |
| max_display | beeswarm, bar, waterfall, heatmap | 10 | Number of features shown in plots |
| alpha | scatter/beeswarm | 1.0 | Point transparency for dense datasets |
| show | All plot functions | True | Set False to keep the matplotlib figure open for saving |
| clustering | beeswarm | None | shap.utils.hclust(...) to cluster correlated features |
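The role of the background `data` parameter can be illustrated without the library. The sketch below assumes a hypothetical two-feature linear model with independent features, for which the SHAP value of feature i reduces to `w_i * (x_i - mean_i)` over the background set. The weights, rows, and helper names are made up for illustration and are not part of the SHAP API.

```python
from statistics import mean

# Hypothetical linear model: f(x) = bias + sum(w_i * x_i)
weights = [2.0, -1.0]
bias = 0.5

def predict(row):
    return bias + sum(w * v for w, v in zip(weights, row))

def linear_shap(x, background):
    """Base value and per-feature attributions relative to a background set."""
    feature_means = [mean(col) for col in zip(*background)]
    base_value = mean(predict(row) for row in background)
    phi = [w * (xi - m) for w, xi, m in zip(weights, x, feature_means)]
    return base_value, phi

x = [3.0, 1.0]
background_a = [[0.0, 0.0], [2.0, 2.0]]  # feature means: [1.0, 1.0]
background_b = [[2.0, 0.0], [4.0, 2.0]]  # feature means: [3.0, 1.0]

base_a, phi_a = linear_shap(x, background_a)
base_b, phi_b = linear_shap(x, background_b)

# Additivity holds for both backgrounds, but the split between
# base value and attributions differs.
print(base_a, phi_a, base_a + sum(phi_a))
print(base_b, phi_b, base_b + sum(phi_b))
```

Both backgrounds reproduce the same prediction for `x`, yet they split it differently between the baseline and the features. This is why the background set should be representative of the population you want to compare against.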
SHAP values have three theoretical guarantees (unique among explanation methods):
prediction = base_value + sum(SHAP values): exact decomposition
Interpretation: Positive SHAP → pushes prediction higher; Negative → lower; Magnitude → strength of impact.
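The additivity property can be checked by brute force on a tiny example. The sketch below computes exact Shapley values by enumerating every coalition for a hypothetical three-feature toy model, with absent features replaced by a fixed baseline row. This is the textbook definition, not how the `shap` explainers compute values internally, and it scales exponentially with the number of features.

```python
from itertools import combinations
from math import factorial

def model(x):
    # Hypothetical toy model with one interaction term (illustration only)
    return 2.0 * x[0] + 1.0 * x[1] + 3.0 * x[0] * x[2]

def coalition_value(subset, x, baseline):
    # Features in `subset` keep their real value; the rest use the baseline
    mixed = [x[i] if i in subset else baseline[i] for i in range(len(x))]
    return model(mixed)

def exact_shapley(x, baseline):
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                with_i = coalition_value(set(subset) | {i}, x, baseline)
                without_i = coalition_value(set(subset), x, baseline)
                phi[i] += weight * (with_i - without_i)
    return phi

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(x, baseline)
base_value = model(baseline)

# Additivity: base value plus attributions reproduces the prediction
print(phi, base_value + sum(phi), model(x))
```

The interaction term `3 * x[0] * x[2]` is split equally between features 0 and 2, which is the behavior the Shapley weighting is designed to give.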
Understand what your model outputs — SHAP explains the output space:
model_output="probability" for probability explanations

| Method | Local | Global | Consistent | Model-agnostic |
|---|---|---|---|---|
| SHAP | Yes | Yes | Yes | Yes |
| Permutation importance | No | Yes | No | Yes |
| Gini/split importance | No | Yes | No | Trees only |
| LIME | Yes | No | No | Yes |
| Integrated Gradients | Yes | No | Partial | NN only |
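On the output-space point above: when a classifier is explained in log-odds, the base value and SHAP values add up to the raw margin, and only that sum can be passed through the sigmoid to recover a probability. A minimal sketch with made-up numbers for one sample (the values are hypothetical, not taken from any model):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical log-odds explanation for a single sample
base_value = -1.2
shap_row = [0.8, -0.3, 1.1]

raw_margin = base_value + sum(shap_row)  # additive in log-odds space
probability = sigmoid(raw_margin)        # convert the total, not the parts

# Transforming pieces separately does NOT give the probability
naive = sigmoid(base_value) + sum(shap_row)

print(f"margin={raw_margin:.3f} probability={probability:.3f} naive={naive:.3f}")
```

Because the sigmoid is nonlinear, individual log-odds SHAP values cannot be read as percentage-point changes in probability.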
shap_interaction = explainer.shap_interaction_values(X_test)
# Shape: (n_samples, n_features, n_features)
# Diagonal = main effects; off-diagonal = pairwise interactions
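To read one slice of that array: each row of a sample's interaction matrix sums to the ordinary SHAP value of that feature, the diagonal holds the main effect, and each pairwise interaction is split evenly across the two symmetric off-diagonal cells. The sketch below uses a hypothetical 3×3 matrix for a single sample, written as plain lists rather than real explainer output.

```python
# Hypothetical interaction matrix for ONE sample with three features
interaction = [
    [0.50, 0.10, -0.05],
    [0.10, -0.30, 0.00],
    [-0.05, 0.00, 0.20],
]
n = len(interaction)

# Off-diagonal entries are symmetric: each pair's effect is split in half
assert all(interaction[i][j] == interaction[j][i]
           for i in range(n) for j in range(n))

main_effects = [interaction[i][i] for i in range(n)]
shap_from_rows = [sum(row) for row in interaction]  # per-feature SHAP value
pair_totals = {
    (i, j): interaction[i][j] + interaction[j][i]   # full pairwise effect
    for i in range(n) for j in range(i + 1, n)
}

print(main_effects)
print(shap_from_rows)
print(pair_totals)
```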
Background data establishes the baseline (expected model output). Selection mainly affects SHAP magnitudes; relative importance is usually, though not always, stable across reasonable background choices.
shap.kmeans(X_train, 50) for efficient summarization
tree_path_dependent: no background data needed (uses tree structure)
import numpy as np
# Find misclassified samples
predictions = model.predict(X_test)
errors = predictions != y_test
error_indices = np.where(errors)[0]
# Explain errors
y_true = np.asarray(y_test)  # works whether y_test is a Series or an array
for idx in error_indices[:3]:
    print(f"Sample {idx}: predicted={predictions[idx]}, actual={y_true[idx]}")
    shap.plots.waterfall(shap_values[idx])
# Check for data leakage: unexpected high-importance features
mean_abs_shap = np.abs(shap_values.values).mean(0)
top_features = X_test.columns[mean_abs_shap.argsort()[-5:]]
print(f"Top features (check for leakage): {list(top_features)}")
# Compare SHAP distributions across groups
group_a = shap_values[(X_test["Sex"] == 0).to_numpy()]
group_b = shap_values[(X_test["Sex"] == 1).to_numpy()]
shap.plots.bar({"Female": group_a, "Male": group_b})
# Check protected attribute importance
sex_importance = np.abs(shap_values[:, "Sex"].values).mean()
# Sum of per-feature mean |SHAP|, so the ratio is a share of the total
total_importance = np.abs(shap_values.values).mean(0).sum()
print(f"Sex contribution: {sex_importance/total_importance:.1%} of total importance")
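The share-of-importance calculation can be sanity-checked on a small matrix by hand. The sketch below is library-free and uses a hypothetical SHAP matrix with made-up numbers; `importance_shares` is an illustrative helper, not a SHAP function. It normalizes the per-feature mean |SHAP| by the sum over all features so that the shares add up to 1.

```python
# Hypothetical SHAP matrix: rows are samples, columns are features
feature_names = ["Age", "Sex", "Hours per week"]
shap_matrix = [
    [0.4, -0.1, 0.5],
    [-0.2, 0.1, 0.3],
    [0.6, -0.1, -0.1],
]

def importance_shares(matrix, names):
    """Per-feature mean |SHAP| divided by the total over all features."""
    n_samples = len(matrix)
    mean_abs = [
        sum(abs(row[j]) for row in matrix) / n_samples
        for j in range(len(names))
    ]
    total = sum(mean_abs)
    return {name: value / total for name, value in zip(names, mean_abs)}

shares = importance_shares(shap_matrix, feature_names)
for name, share in shares.items():
    print(f"{name}: {share:.1%}")
```

A small share for a protected attribute does not rule out bias, since correlated features can act as proxies. Treat this number as a first screen, not a verdict.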
import joblib
# Save explainer for reuse
joblib.dump(explainer, 'explainer.pkl')
explainer = joblib.load('explainer.pkl')
# Batch computation for API responses
def explain_batch(X_batch, explainer, top_n=5):
    sv = explainer(X_batch)
    results = []
    for i in range(len(X_batch)):
        # Indices of the top_n features by absolute SHAP value
        top_idx = np.abs(sv.values[i]).argsort()[-top_n:]
        results.append({
            # Raw model output (log-odds for a default tree classifier)
            'prediction': sv.base_values[i] + sv.values[i].sum(),
            'top_features': {X_batch.columns[j]: sv.values[i][j] for j in top_idx}
        })
    return results
import mlflow
import matplotlib.pyplot as plt
with mlflow.start_run():
    model = xgb.XGBClassifier().fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)
    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()
    for feat, imp in zip(X_test.columns, np.abs(shap_values.values).mean(0)):
        mlflow.log_metric(f"shap_{feat}", imp)
| Output | Type | Description |
|---|---|---|
| shap_values | shap.Explanation | Object with .values (n_samples, n_features), .base_values (baseline), .data (input features) |
| Waterfall plot | matplotlib figure | Single-instance explanation showing feature contributions from base value to prediction |
| Beeswarm plot | matplotlib figure | Global summary: feature importance × direction for all samples |
| Bar plot | matplotlib figure | Mean absolute SHAP values per feature (global importance ranking) |
| Force plot | HTML/matplotlib | Interactive or static visualization of a single prediction |
| mean_abs_shap | np.ndarray | Per-feature mean absolute SHAP value for ranking and reporting |
| Problem | Cause | Solution |
|---|---|---|
| Very slow computation | Using KernelExplainer for tree model | Use TreeExplainer for tree-based models |
| Slow on large dataset | Computing all samples at once | Sample subset: explainer(X_test[:1000]) or batch |
| SHAP values don't sum to prediction | Wrong model output type | Check model_output parameter; verify additivity |
| Log-odds vs probability confusion | Tree classifier defaults to log-odds | Use TreeExplainer(model, data=X_train[:100], model_output="probability"); probability output requires background data |
| Plots too cluttered | Too many features shown | Set max_display=10 or use feature clustering |
| DeepExplainer error | Background data too small | Use 100-1000 background samples |
| Memory error | Large dataset + many features | Reduce background data with shap.kmeans(X, 50) |
| Force plot not rendering | Missing JS in notebook | Run shap.initjs() at notebook start |
| Inconsistent importance across runs | KernelExplainer sampling variance | Increase nsamples or use deterministic explainer |
| Negative importance for relevant feature | Feature interactions or correlations | Use feature_perturbation="interventional" or scatter plots |
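The "inconsistent importance across runs" row comes down to sampling variance. The sketch below illustrates the effect with plain permutation sampling on a hypothetical toy model. This is a simplified stand-in chosen for clarity, not the algorithm KernelExplainer uses, and the function and parameter names are illustrative. With few permutations the estimates depend on the seed; with more they settle near the exact values.

```python
import random

def model(x):
    # Hypothetical toy model with one interaction term (illustration only)
    return 2.0 * x[0] + 1.0 * x[1] + 3.0 * x[0] * x[2]

def sampled_shapley(x, baseline, n_permutations, seed):
    """Monte Carlo Shapley estimate from random feature orderings."""
    rng = random.Random(seed)
    n = len(x)
    phi = [0.0] * n
    for _ in range(n_permutations):
        order = list(range(n))
        rng.shuffle(order)
        current = list(baseline)
        previous = model(current)
        for i in order:
            current[i] = x[i]            # switch feature i to its real value
            value = model(current)
            phi[i] += value - previous   # marginal contribution in this order
            previous = value
    return [p / n_permutations for p in phi]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]

few_a = sampled_shapley(x, baseline, n_permutations=5, seed=1)
few_b = sampled_shapley(x, baseline, n_permutations=5, seed=2)
many = sampled_shapley(x, baseline, n_permutations=2000, seed=1)

print("5 permutations, seed 1:", few_a)
print("5 permutations, seed 2:", few_b)
print("2000 permutations:     ", many)
```

Every estimate still sums to `model(x) - model(baseline)`, so additivity alone does not show that the values have converged. Comparing runs with different seeds is a more direct check.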
references/theory.md — Mathematical foundations: Shapley value formula, key properties (additivity, symmetry, dummy, monotonicity), computation algorithms (Tree SHAP, Kernel SHAP, Deep SHAP, Linear SHAP), conditional expectations (interventional vs observational), comparison with LIME/DeepLIFT/LRP/Integrated Gradients, interaction values, theoretical limitations
Not migrated from original: references/explainers.md (340 lines) — detailed constructor parameters, methods, and performance benchmarks for each explainer class. Explainer selection guide and common usage are covered inline in Workflow Step 1 and Key Parameters.
Not migrated from original: references/plots.md (508 lines) — comprehensive parameter reference for all 9 plot types with advanced customization (violin, decision, feature clustering). Main plot types are covered inline in Workflow Steps 3-6.
Not migrated from original: references/workflows.md (606 lines) — detailed step-by-step workflows for feature engineering, model comparison, deep learning explanation, production deployment, and time series. Core patterns are covered in Common Recipes; consult original for extended workflows.
TreeExplainer > LinearExplainer > DeepExplainer > KernelExplainer. Only use model-agnostic explainers when no specialized one exists
shap.kmeans() for efficiency
feature_perturbation="interventional" for causal interpretation or feature clustering for grouped importance
npx claudepluginhub jaechang-hits/sciagent-skills --plugin sciagent-skills