Summary
We have new work underway on Shapley values and other related interpretability concepts. This will add new functionality; however, the current feature-importance/Shapley features are exposed through the predict API. I propose extending the Python API with a module (xgboost.interpret) for interpretability, containing stateless functions that expose upcoming features.
These functions accept either a Booster or an sklearn-style XGB* model, plus DMatrix/array-like inputs, and return well-typed results (arrays and/or lightweight result objects).
Motivation
- Minimize disruption to the existing Booster/sklearn APIs while adding interpretability features.
- Improve discoverability and documentation (module-level functions are easy to document and test).
- Allow incremental implementation: start as wrappers over the existing predict(pred_contribs=..., pred_interactions=...), then evolve internals (especially top-k) without changing the public API.
Proposed public API
Add a new module:
xgboost/interpret.py
Functions (accept Booster | XGBModel and DMatrix | array-like | pandas):
- shap_values(model, X, *, output_margin=False, iteration_range=None, approx=False, validate_features=True, feature_names=None, return_bias=False)
- shap_interactions(model, X, *, output_margin=False, iteration_range=None, approx=False, validate_features=True, feature_names=None)
- topk_interactions(model, X, *, k=50, metric="mean_abs", per_row=False, output_margin=False, iteration_range=None, approx=False, validate_features=True, feature_names=None)
- partial_dependence(model, X, *, features, grid_resolution=50, percentiles=(0.05, 0.95), grid=None, sample=None, random_state=0, output="prediction", iteration_range=None)
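To make the proposed shap_values signature concrete, here is a minimal sketch of how it could wrap the existing Booster.predict(pred_contribs=True) path. The xgboost.interpret module does not exist yet; the function body below is illustrative, assumes X is already a DMatrix, and omits the feature_names handling for brevity:

```python
import numpy as np


def shap_values(model, X, *, output_margin=False, iteration_range=None,
                approx=False, validate_features=True, return_bias=False):
    """Sketch of the proposed wrapper (hypothetical; not the final API).

    `model` may be a Booster or an sklearn-style wrapper exposing
    get_booster(); `X` is assumed to already be a DMatrix here.
    """
    # Normalize sklearn-style wrappers down to the underlying Booster.
    booster = model.get_booster() if hasattr(model, "get_booster") else model
    contribs = booster.predict(
        X,
        pred_contribs=True,
        approx_contribs=approx,
        output_margin=output_margin,
        iteration_range=iteration_range or (0, 0),
        validate_features=validate_features,
    )
    # predict(pred_contribs=True) appends the bias term as the last column.
    values, base_values = contribs[..., :-1], contribs[..., -1]
    return (values, base_values) if return_bias else values
```

The return_bias flag lets callers choose between a bare contributions array and a (values, base_values) pair without a second predict call.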
Dispatch/behavior notes
- Internally normalize model to a Booster: either model is a Booster already, or it has get_booster().
- Normalize X to DMatrix if needed; respect feature names where possible.
- Initial SHAP implementations can wrap the existing Booster.predict(..., pred_contribs=True / pred_interactions=True) for compatibility.
- topk_interactions should ideally avoid materializing the full (n, p, p) tensor; target a C++ implementation to compute aggregated top-k pairs efficiently.
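As a reference point for what the efficient C++ path would compute, here is a naive NumPy sketch of the top-k aggregation over an already-materialized (n, p, p) interaction tensor, assuming the "mean_abs" metric from the proposal (the helper name is hypothetical):

```python
import numpy as np


def topk_interactions_reference(interactions, k=5):
    """Naive reference for the proposed top-k aggregation.

    `interactions` is a SHAP interaction tensor of shape (n, p, p);
    a real implementation would avoid materializing it.  Scores are
    the mean absolute interaction strength ("mean_abs").
    """
    score = np.abs(interactions).mean(axis=0)    # (p, p) aggregate
    score = (score + score.T) / 2                # symmetrize i/j vs j/i
    i, j = np.triu_indices(score.shape[0], k=1)  # distinct off-diagonal pairs
    order = np.argsort(score[i, j])[::-1][:k]    # strongest pairs first
    pairs = list(zip(i[order].tolist(), j[order].tolist()))
    return pairs, score[i, j][order]
```

This makes the memory problem explicit: the (n, p, p) tensor exists only to be reduced to p*(p-1)/2 pair scores, which is why aggregating inside the C++ traversal is attractive.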
Return types
Prefer lightweight result objects to keep outputs consistent and extensible:
- ShapValues(values, base_values, feature_names, model_output, ...)
- ShapInteractions(values, feature_names, ...) with helpers for main-effect / pair extraction
- TopKInteractions(pairs, scores, pair_names=None, per_row=None, ...)
- PDP(features, grid_values, averages, ...)
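A lightweight result object could be a plain frozen dataclass. The fields below follow the proposed ShapValues signature; the total() helper is an illustrative example of the kind of convenience method such objects enable, not part of the proposal:

```python
from dataclasses import dataclass
from typing import Optional, Sequence

import numpy as np


@dataclass(frozen=True)
class ShapValues:
    """Sketch of the proposed lightweight result object."""
    values: np.ndarray                           # (n, p) per-feature contributions
    base_values: np.ndarray                      # (n,) bias / expected-value term
    feature_names: Optional[Sequence[str]] = None
    model_output: str = "margin"

    def total(self) -> np.ndarray:
        # Contributions plus bias reconstruct the (margin) prediction,
        # which is the standard SHAP additivity property.
        return self.values.sum(axis=-1) + self.base_values
```

Dataclasses keep the objects cheap to construct, easy to type-check, and extensible with new optional fields without breaking callers.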
Documentation plan (Sphinx)
- Add
docs/python/interpretability.rst- narrative examples + API reference using
.. autofunction::for each function .. autoclass::for result types
- narrative examples + API reference using
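The autodoc stubs for that page might look like the following (module and object names taken from the proposal above; nothing here exists yet):

```rst
Interpretability
================

.. autofunction:: xgboost.interpret.shap_values

.. autofunction:: xgboost.interpret.topk_interactions

.. autoclass:: xgboost.interpret.ShapValues
   :members:
```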