86 changes: 86 additions & 0 deletions catboost_shap_test.py
@@ -0,0 +1,86 @@
"""Test script to reproduce some issues with CatBoost."""

from __future__ import annotations

import numpy as np
import shap
from catboost import CatBoostClassifier, CatBoostRegressor, Pool

from shapiq.explainer.tree import TreeExplainer


def do_classification() -> None:
"""Test script to reproduce some issues with CatBoost."""
# data example from the CatBoost documentation
train_data_cl = np.random.randint(0, 100, size=(100, 10)) # noqa: NPY002
train_labels_cl = np.random.randint(0, 3, size=(100)) # noqa: NPY002
test_data_cl = Pool(train_data_cl, train_labels_cl)

model_cl = CatBoostClassifier(iterations=2, depth=2, learning_rate=1)

# train the model
model_cl.fit(train_data_cl, train_labels_cl)

# evaluate using shap
explainer_shap_cl = shap.TreeExplainer(model_cl)
original_shap_values_cl = explainer_shap_cl(train_data_cl)
print(original_shap_values_cl) # noqa: T201

# evaluate using built-in shap values
built_in_shap_values_cl = model_cl.get_feature_importance(data=test_data_cl, type="ShapValues")
print(built_in_shap_values_cl) # noqa: T201

# evaluate using shapiq
explainer_shapiq_cl = TreeExplainer(
model=model_cl,
max_order=1,
index="SV",
class_index=train_labels_cl,
)
# raises divide by zero warning!!
sv_shapiq_cl = explainer_shapiq_cl.explain(x=train_data_cl[0])
sv_shapiq_values_cl = sv_shapiq_cl.get_n_order_values(1)
print(sv_shapiq_values_cl) # noqa: T201


def do_regression() -> None:
"""Test script to reproduce some issues with CatBoost."""
print("Testing CatBoost regression...") # noqa: T201
# data example from the CatBoost documentation
train_data_reg = np.random.randint(0, 100, size=(100, 10)) # noqa: NPY002
train_labels_reg = np.random.randint(0, 1000, size=(100)) # noqa: NPY002
test_data_reg = np.random.randint(0, 100, size=(50, 10)) # noqa: NPY002

# initialize Pool
train_pool_reg = Pool(train_data_reg, train_labels_reg, cat_features=[0, 2, 5])
test_pool_reg = Pool(test_data_reg, cat_features=[0, 2, 5])

# train the model
model_reg = CatBoostRegressor(iterations=2, depth=2, learning_rate=1)
model_reg.fit(train_pool_reg)

# evaluate using shap
explainer_shap_reg = shap.TreeExplainer(model_reg)
original_shap_values_reg = explainer_shap_reg(train_data_reg[0:1])
print("Original SHAP values:", original_shap_values_reg) # noqa: T201

# evaluate using built-in shap values
built_in_shap_values_reg = model_reg.get_feature_importance(
data=test_pool_reg, type="ShapValues"
)
print("Built-in SHAP values:", built_in_shap_values_reg) # noqa: T201

# evaluate using shapiq
explainer_shapiq_reg = TreeExplainer(
model=model_reg,
max_order=1,
index="SV",
)
# raises invalid value in divide warning!!
sv_shapiq_reg = explainer_shapiq_reg.explain(x=train_data_reg[0])
print("SHAPIQ SV values:", sv_shapiq_reg.get_n_order_values(1)) # noqa: T201


if __name__ == "__main__":
do_classification()
do_regression()
295,247 changes: 47 additions & 295,200 deletions docs/source/notebooks/basics_notebooks/parallel_computation.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions pyproject.toml
@@ -84,9 +84,6 @@ games = [
"lightgbm",
"transformers",
"scikit-image",
# tf only for python < 3.13 and not windows
"tensorflow; python_version < '3.13' and platform_system != 'Windows'",
"tf-keras; python_version < '3.13' and platform_system != 'Windows'",
]

[tool.pytest.ini_options]
@@ -225,11 +222,23 @@ exclude = ["tests", "docs", "benchmark", "scripts", "src/shapiq/plot", "src/shap
pythonVersion = "3.10"

[dependency-groups]
ml = [
"catboost",
"tabpfn>=2.0.4",
"torchvision",
"torch",
"xgboost",
"lightgbm",
"transformers",
"tensorflow; python_version < '3.13' and platform_system != 'Windows'",
"tf-keras; python_version < '3.13' and platform_system != 'Windows'",
]
test = [
"pytest>=8.3.5",
"pytest-cov>=6.0.0",
"pytest-xdist>=3.6.1",
"packaging>=24.2",
{include-group = "ml"},
]
lint = [
"ruff>=0.11.2",
140 changes: 140 additions & 0 deletions src/shapiq/explainer/tree/conversion/catboost.py
@@ -0,0 +1,140 @@
"""Functions for converting catboost decision trees to the format used by shapiq."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from shapiq.explainer.tree.base import TreeModel
from shapiq.utils import safe_isinstance

if TYPE_CHECKING:
from shapiq.typing import Model

SUPPORTED_CATBOOST_MODELS = {"catboost.CatBoostClassifier", "catboost.CatBoostRegressor"}


def convert_catboost(
tree_model: Model,
class_label: int | None = None,
) -> list[TreeModel]:
"""Transforms models from the catboost package to the format used by shapiq.

Note: part of this implementation is taken and adapted from the shap package, where it can be found in shap/explainers/_tree.py.

Args:
tree_model: The catboost model to convert.
class_label: The class label of the model to explain. Only used for classification models.

Returns:
The converted catboost model.

"""
output_type = "raw"
model_type = "undefined"

if safe_isinstance(tree_model, "catboost.CatBoostClassifier"):
model_type = "classifier"
elif safe_isinstance(tree_model, "catboost.CatBoostRegressor"):
model_type = "regressor"
else:
msg = f"Unsupported model type. Supported mode types are {SUPPORTED_CATBOOST_MODELS}"
raise ValueError(msg)

# workaround to get the single trees in the ensemble
import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_file = Path(tmp_dir) / "model.json"
tree_model.save_model(tmp_file, format="json")
with tmp_file.open(encoding="utf-8") as fh:
loaded_cb_model = json.load(fh)

num_trees = len(loaded_cb_model["oblivious_trees"])

# determine number of classes or set to 1 for regression
if model_type == "classifier":
num_classes = len(loaded_cb_model["model_info"]["class_params"]["class_names"])
elif model_type == "regressor":
num_classes = 1

trees = []
for tree_index in range(num_trees):
# get the values for all nodes
leaf_values_json = loaded_cb_model["oblivious_trees"][tree_index]["leaf_values"]
node_values = [0] * (len(leaf_values_json) - num_classes) + leaf_values_json
node_values = np.array(node_values)
total_nodes = int(len(node_values) / num_classes)
node_values = node_values.reshape((-1, num_classes)) # reshape to match number of classes

# get the children
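# oblivious trees are complete binary trees stored level by level: node i has
# children 2 * i + 1 and 2 * i + 2, so left children sit at the odd indices,
# right children at the even indices, and the padding value -1 marks leaves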
children_left = list(range(1, total_nodes, 2))
children_left += [-1] * (total_nodes - len(children_left))
children_right = list(range(2, total_nodes + 1, 2))
children_right += [-1] * (total_nodes - len(children_right))

# add a weight to each node
leaf_weights_json = loaded_cb_model["oblivious_trees"][tree_index]["leaf_weights"]
leaf_weights = [0] * (len(leaf_weights_json) - 1) + leaf_weights_json
leaf_weights[0] = sum(leaf_weights_json)
for index in range(len(leaf_weights_json) - 2, 0, -1):
# each node weight is the sum of its children
leaf_weights[index] = leaf_weights[2 * index + 1] + leaf_weights[2 * index + 2]

# split features and borders from leafs to the root
split_features_index_json = []
borders_json = []

for elem in loaded_cb_model["oblivious_trees"][tree_index]["splits"]:
split_type = elem.get("split_type")
if split_type == "FloatFeature":
split_feature_index = elem.get("float_feature_index")
borders_json.append(elem["border"])
elif split_type == "OneHotFeature":
split_feature_index = elem.get("cat_feature_index")
borders_json.append(elem["value"])
else:
split_feature_index = elem.get("ctr_target_border_idx")
borders_json.append(elem["border"])
split_features_index_json.append(split_feature_index)

split_features_index = []
for counter, feature_index in enumerate(split_features_index_json[::-1]):
split_features_index += [feature_index] * (2**counter)
split_features_index += [-2] * (total_nodes - len(split_features_index))

borders = []
for counter, border in enumerate(borders_json[::-1]):
borders += [border] * (2**counter)
borders += [0] * (total_nodes - len(borders))

if model_type == "classifier" and class_label is not None:
class_label = 0

# make probabilities
if class_label is not None:
row_sums = np.sum(node_values, axis=1, keepdims=True)
zero_mask = row_sums == 0
normalized = np.divide(node_values, row_sums, where=~zero_mask) # avoid division by 0
node_values = normalized[:, class_label]
node_values[zero_mask.flatten()] = 0

output_type = "probability"

trees.append(
TreeModel(
children_left=np.array(children_left),
children_right=np.array(children_right),
features=np.array(split_features_index),
thresholds=np.array(borders),
values=node_values,
node_sample_weight=np.array(leaf_weights),
empty_prediction=None, # compute empty prediction later
original_output_type=output_type,
)
)

return trees
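
For reference, a minimal usage sketch of the new converter (not part of this diff): it assumes a fitted CatBoostRegressor and the import path introduced above, and the attributes read at the end are the TreeModel fields populated by convert_catboost.

import numpy as np
from catboost import CatBoostRegressor

from shapiq.explainer.tree.conversion.catboost import convert_catboost

# small synthetic regression problem
X = np.random.randint(0, 100, size=(50, 10))
y = np.random.rand(50)

model = CatBoostRegressor(iterations=3, depth=2, verbose=False)
model.fit(X, y)

trees = convert_catboost(model)  # one TreeModel per boosting iteration
print(len(trees))                # number of oblivious trees in the ensemble
print(trees[0].children_left)    # flattened complete-binary-tree layout
print(trees[0].values.shape)     # node values, one column per class (1 for regression)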
7 changes: 7 additions & 0 deletions src/shapiq/explainer/tree/validation.py
@@ -7,6 +7,7 @@
from shapiq.utils.modules import safe_isinstance

from .base import TreeModel
from .conversion.catboost import convert_catboost
from .conversion.lightgbm import convert_lightgbm_booster
from .conversion.sklearn import (
convert_sklearn_forest,
@@ -19,6 +20,8 @@
from shapiq.typing import Model

SUPPORTED_MODELS = {
"catboost.CatBoostClassifier",
"catboost.CatBoostRegressor",
"sklearn.tree.DecisionTreeRegressor",
"sklearn.tree._classes.DecisionTreeRegressor",
"sklearn.tree.DecisionTreeClassifier",
@@ -105,6 +108,10 @@ def validate_tree_model(
"xgboost.sklearn.XGBClassifier",
):
tree_model = convert_xgboost_booster(model, class_label=class_label)
elif safe_isinstance(model, "catboost.CatBoostRegressor") or safe_isinstance(
model, "catboost.CatBoostClassifier"
):
tree_model = convert_catboost(model, class_label=class_label)
# unsupported model
else:
msg = f"Unsupported model type.Supported models are: {SUPPORTED_MODELS}"
31 changes: 29 additions & 2 deletions tests/shapiq/fixtures/models.py
@@ -309,7 +309,9 @@ def if_clf_model(if_clf_dataset) -> IsolationForest:
def et_clf_model(background_clf_dataset) -> Model:
"""Return a simple (classification) extra trees model."""
X, y = background_clf_dataset
model = ExtraTreesClassifier(random_state=RANDOM_SEED_MODELS, max_depth=3, n_estimators=3)
model = ExtraTreesClassifier(
random_state=RANDOM_SEED_MODELS, max_depth=3, n_estimators=3, verbose=False
)
model.fit(X, y)
return model

@@ -318,7 +320,32 @@ def et_clf_model(background_clf_dataset) -> Model:
def et_reg_model(background_reg_dataset) -> Model:
"""Return a simple (regression) extra trees model."""
X, y = background_reg_dataset
model = ExtraTreesRegressor(random_state=RANDOM_SEED_MODELS, max_depth=3, n_estimators=3)
model = ExtraTreesRegressor(
random_state=RANDOM_SEED_MODELS, max_depth=3, n_estimators=3, verbose=False
)
model.fit(X, y)
return model


# CatBoost model
@pytest.fixture
def cb_clf_model(background_clf_dataset) -> Model:
"""Return a simple CatBoost classification model."""
catboost = pytest.importorskip("catboost")

X, y = background_clf_dataset
model = catboost.CatBoostClassifier(depth=3, random_state=42, n_estimators=3)
model.fit(X, y)
return model


@pytest.fixture
def cb_reg_model(background_reg_dataset) -> Model:
"""Return a simple CatBoost regression model."""
catboost = pytest.importorskip("catboost")

X, y = background_reg_dataset
model = catboost.CatBoostRegressor(depth=3, random_state=42, n_estimators=3)
model.fit(X, y)
return model

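A hypothetical smoke test built on the new fixtures might look like the following (not part of this diff; the class_index value and the shape assertion are assumptions based on the repro script above):

def test_catboost_tree_explainer_smoke(cb_clf_model, background_clf_dataset):
    """TreeExplainer should return one order-1 value per feature for a CatBoost classifier."""
    from shapiq.explainer.tree import TreeExplainer

    X, _ = background_clf_dataset
    explainer = TreeExplainer(model=cb_clf_model, max_order=1, index="SV", class_index=0)
    order_1_values = explainer.explain(x=X[0]).get_n_order_values(1)
    assert order_1_values.shape[0] == X.shape[1]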