Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 68 additions & 25 deletions eli5/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division
from collections import defaultdict
from typing import DefaultDict
from typing import DefaultDict, Any, Tuple

import numpy as np # type: ignore
import lightgbm # type: ignore
Expand All @@ -17,7 +17,7 @@
all values sum to 1.
"""


@explain_weights.register(lightgbm.Booster)
@explain_weights.register(lightgbm.LGBMClassifier)
@explain_weights.register(lightgbm.LGBMRegressor)
def explain_weights_lightgbm(lgb,
Expand All @@ -32,7 +32,7 @@ def explain_weights_lightgbm(lgb,
):
"""
Return an explanation of an LightGBM estimator (via scikit-learn wrapper
LGBMClassifier or LGBMRegressor) as feature importances.
LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature importances.

See :func:`eli5.explain_weights` for description of
``top``, ``feature_names``,
Expand All @@ -51,8 +51,9 @@ def explain_weights_lightgbm(lgb,
across all trees
- 'weight' - the same as 'split', for compatibility with xgboost
"""
coef = _get_lgb_feature_importances(lgb, importance_type)
lgb_feature_names = lgb.booster_.feature_name()
booster, is_regression = _check_booster_args(lgb)
coef = _get_lgb_feature_importances(booster, importance_type)
lgb_feature_names = booster.feature_name()
return get_feature_importance_explanation(lgb, vec, coef,
feature_names=feature_names,
estimator_feature_names=lgb_feature_names,
Expand All @@ -64,7 +65,7 @@ def explain_weights_lightgbm(lgb,
is_regression=isinstance(lgb, lightgbm.LGBMRegressor),
)


@explain_prediction.register(lightgbm.Booster)
@explain_prediction.register(lightgbm.LGBMClassifier)
@explain_prediction.register(lightgbm.LGBMRegressor)
def explain_prediction_lightgbm(
Expand All @@ -80,7 +81,7 @@ def explain_prediction_lightgbm(
vectorized=False,
):
""" Return an explanation of LightGBM prediction (via scikit-learn wrapper
LGBMClassifier or LGBMRegressor) as feature weights.
LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature weights.

See :func:`eli5.explain_prediction` for description of
``top``, ``top_targets``, ``target_names``, ``targets``,
Expand Down Expand Up @@ -108,20 +109,48 @@ def explain_prediction_lightgbm(
Weights of all features sum to the output score of the estimator.
"""

vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names)
booster, is_regression = _check_booster_args(lgb)
lgb_feature_names = booster.feature_name()
vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names,
num_features=len(lgb_feature_names))
if feature_names.bias_name is None:
# LightGBM estimators do not have an intercept, but here we interpret
# them as having an intercept
feature_names.bias_name = '<BIAS>'
X = get_X(doc, vec, vectorized=vectorized)

if isinstance(lgb, lightgbm.Booster):
prediction = lgb.predict(X)
n_targets = prediction.shape[-1]
if is_regression is None and target_names is None:
# When n_targets is 1, this can be classification too,
# but it's safer to assume regression.
# If n_targets > 1, it must be classification.
is_regression = n_targets == 1
elif is_regression is None:
is_regression = len(target_names) == 1

if is_regression:
proba = None
else:
if n_targets == 1:
p, = prediction
proba = np.array([1 - p, p])
else:
proba, = prediction
else:
proba = predict_proba(lgb, X)
n_targets = _lgb_n_targets(lgb)

proba = predict_proba(lgb, X)
weight_dicts = _get_prediction_feature_weights(lgb, X, _lgb_n_targets(lgb))
x = get_X0(add_intercept(X))
if is_regression:
names = ['y']
elif isinstance(lgb, lightgbm.Booster):
names = np.arange(max(2, n_targets))
else:
names = lgb.classes_

is_regression = isinstance(lgb, lightgbm.LGBMRegressor)
is_multiclass = _lgb_n_targets(lgb) > 2
names = lgb.classes_ if not is_regression else ['y']
weight_dicts = _get_prediction_feature_weights(booster, X, n_targets)
x = get_X0(add_intercept(X))

def get_score_weights(_label_id):
_weights = _target_feature_weights(
Expand All @@ -145,22 +174,38 @@ def get_score_weights(_label_id):
targets=targets,
top_targets=top_targets,
is_regression=is_regression,
is_multiclass=is_multiclass,
is_multiclass=n_targets > 1,
proba=proba,
get_score_weights=get_score_weights,
)


def _check_booster_args(lgb, is_regression=None):
    # type: (Any, bool) -> Tuple[lightgbm.Booster, bool]
    """
    Resolve ``lgb`` into a ``(booster, is_regression)`` pair.

    When ``lgb`` is already a ``lightgbm.Booster`` it is returned as-is
    together with the ``is_regression`` value given by the caller (which
    stays ``None`` if the caller did not pass one).

    Otherwise the booster is read from ``lgb.booster_`` and ``is_regression``
    is derived from whether ``lgb`` is a ``lightgbm.LGBMRegressor``. An
    explicit ``is_regression`` that disagrees with the derived value raises
    ``ValueError``.
    """
    if isinstance(lgb, lightgbm.Booster):
        # A bare Booster carries no wrapper type to infer the task from,
        # so the caller's value (possibly None) is passed through untouched.
        return lgb, is_regression

    booster = lgb.booster_
    inferred = isinstance(lgb, lightgbm.LGBMRegressor)
    caller_disagrees = (
        is_regression is not None and is_regression != inferred)
    if caller_disagrees:
        raise ValueError(
            'Inconsistent is_regression={} passed. '
            'You don\'t have to pass it when using scikit-learn API'
            .format(is_regression))
    return booster, inferred

def _lgb_n_targets(lgb):
    """
    Return the number of targets for a scikit-learn style LightGBM estimator.

    - ``lightgbm.LGBMClassifier``: ``1`` when ``lgb.n_classes_ == 2``,
      otherwise ``lgb.n_classes_``.
    - ``lightgbm.LGBMRegressor``: always ``1``.

    Raises ``TypeError`` for any other object.
    """
    if isinstance(lgb, lightgbm.LGBMClassifier):
        # Two classes collapse to a single target; presumably because the
        # binary model emits one score per sample - confirm against callers.
        return 1 if lgb.n_classes_ == 2 else lgb.n_classes_
    elif isinstance(lgb, lightgbm.LGBMRegressor):
        return 1
    else:
        raise TypeError(
            'Expected LGBMClassifier or LGBMRegressor, got {}'
            .format(type(lgb).__name__))


def _get_lgb_feature_importances(lgb, importance_type):
def _get_lgb_feature_importances(booster, importance_type):
aliases = {'weight': 'split'}
coef = lgb.booster_.feature_importance(
coef = booster.feature_importance(
importance_type=aliases.get(importance_type, importance_type)
)
norm = coef.sum()
Expand Down Expand Up @@ -237,17 +282,15 @@ def walk(tree, parent_id=-1):
return leaf_index, split_index


def _get_prediction_feature_weights(lgb, X, n_targets):
def _get_prediction_feature_weights(booster, X, n_targets):
"""
Return a list of {feat_id: value} dicts with feature weights,
following ideas from http://blog.datadive.net/interpreting-random-forests/
"""
if n_targets == 2:
n_targets = 1
dump = lgb.booster_.dump_model()
dump = booster.dump_model()
tree_info = dump['tree_info']
_compute_node_values(tree_info)
pred_leafs = lgb.booster_.predict(X, pred_leaf=True).reshape(-1, n_targets)
pred_leafs = booster.predict(X, pred_leaf=True).reshape(-1, n_targets)
tree_info = np.array(tree_info).reshape(-1, n_targets)
assert pred_leafs.shape == tree_info.shape

Expand Down
130 changes: 130 additions & 0 deletions tests/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
import lightgbm
from lightgbm import LGBMClassifier, LGBMRegressor

from eli5 import explain_weights, explain_prediction
from eli5.lightgbm import _check_booster_args, _lgb_n_targets
from .test_sklearn_explain_weights import (
test_explain_tree_classifier as _check_rf_classifier,
test_explain_random_forest_and_tree_feature_filter as _check_rf_feature_filter,
Expand All @@ -18,6 +20,7 @@
)
from .test_sklearn_explain_prediction import (
assert_linear_regression_explained,
assert_trained_linear_regression_explained,
test_explain_prediction_pandas as _check_explain_prediction_pandas,
test_explain_clf_binary_iris as _check_binary_classifier,
)
Expand Down Expand Up @@ -144,3 +147,130 @@ def test_explain_weights_feature_names_pandas(boston_train):
res = explain_weights(reg, feature_names=numeric_feature_names)
for expl in format_as_all(res, reg):
assert 'zz12' in expl


def test_check_booster_args():
    """
    Check ``_check_booster_args`` for scikit-learn wrappers and raw boosters.

    Wrappers must yield their underlying booster plus an inferred
    ``is_regression`` flag, and must reject a contradicting explicit flag.
    Raw boosters must be returned unchanged with the flag passed through
    (``None`` when not supplied).
    """
    # Seeded features and alternating labels keep the fit deterministic and
    # guarantee that both classes are present for the classifier.
    rng = np.random.RandomState(0)
    x = rng.random_sample((10, 2))
    y = np.arange(10) % 2
    regressor = LGBMRegressor(min_data=1).fit(x, y)
    classifier = LGBMClassifier(min_data=1).fit(x, y)

    # scikit-learn wrappers: the flag is inferred from the estimator type.
    booster, is_regression = _check_booster_args(regressor)
    assert is_regression is True
    assert isinstance(booster, lightgbm.Booster)
    _, is_regression = _check_booster_args(regressor, is_regression=True)
    assert is_regression is True
    _, is_regression = _check_booster_args(classifier)
    assert is_regression is False
    _, is_regression = _check_booster_args(classifier, is_regression=False)
    assert is_regression is False

    # An explicit flag that contradicts the estimator type is an error.
    with pytest.raises(ValueError):
        _check_booster_args(classifier, is_regression=True)
    with pytest.raises(ValueError):
        _check_booster_args(regressor, is_regression=False)

    # Raw boosters: returned as the same object, flag passed through.
    booster = regressor.booster_
    _booster, is_regression = _check_booster_args(booster)
    assert _booster is booster
    assert is_regression is None
    _, is_regression = _check_booster_args(booster, is_regression=True)
    assert is_regression is True

    booster = classifier.booster_
    _booster, is_regression = _check_booster_args(booster)
    assert _booster is booster
    assert is_regression is None
    _, is_regression = _check_booster_args(booster, is_regression=False)
    assert is_regression is False

def test_explain_lightgbm_booster(boston_train):
    """
    ``explain_weights`` must accept a raw booster trained via
    ``lightgbm.train``, with and without user-supplied feature names.
    """
    features, labels, boston_names = boston_train
    train_data = lightgbm.Dataset(features, label=labels)
    model = lightgbm.train(
        params={'objective': 'regression', 'verbose_eval': -1},
        train_set=train_data,
    )

    # (extra explain_weights kwargs, feature name expected in every format)
    cases = [
        ({}, 'Column_12'),
        ({'feature_names': boston_names}, 'LSTAT'),
    ]
    for extra_kwargs, expected_name in cases:
        explanation = explain_weights(model, **extra_kwargs)
        for rendered in format_as_all(explanation, model):
            assert expected_name in rendered

def test_explain_prediction_reg_booster(boston_train):
    """
    ``explain_prediction`` must explain a single row for a raw regression
    booster trained via ``lightgbm.train``.
    """
    features, labels, feature_names = boston_train
    train_data = lightgbm.Dataset(features, label=labels)
    model = lightgbm.train(
        params={'objective': 'regression', 'verbose_eval': -1},
        train_set=train_data,
    )
    first_row = features[0]
    assert_trained_linear_regression_explained(
        first_row, feature_names, model, explain_prediction,
        reg_has_intercept=True)

def test_explain_prediction_booster_multitarget(newsgroups_train):
    """
    ``explain_prediction`` must handle a raw multiclass booster: per-target
    feature weights are checked, and ``top_targets`` must keep the targets
    with the highest probabilities.
    """
    docs, ys, target_names = newsgroups_train
    vec = CountVectorizer(stop_words='english', dtype=np.float64)
    term_counts = vec.fit_transform(docs)
    train_params = {
        'objective': 'multiclass',
        'verbose_eval': -1,
        'max_depth': 2,
        'n_estimators': 100,
        'min_child_samples': 1,
        'min_child_weight': 1,
        'num_class': len(target_names),
    }
    train_data = lightgbm.Dataset(term_counts.toarray(), label=ys)
    clf = lightgbm.train(params=train_params, train_set=train_data)

    doc = 'computer graphics in space: a new religion'
    res = explain_prediction(clf, doc, vec=vec, target_names=target_names)
    format_as_all(res, clf)
    check_targets_scores(res)

    # Targets 1 and 3 are presumably the graphics and religion newsgroups
    # in the fixture's ordering - confirm against newsgroups_train.
    for target_index, expected_word in ((1, 'computer'), (3, 'religion')):
        target_weights = res.targets[target_index].feature_weights
        assert expected_word in get_all_features(target_weights.pos)

    top_target_res = explain_prediction(clf, doc, vec=vec, top_targets=2)
    assert len(top_target_res.targets) == 2
    top_probas = sorted(t.proba for t in top_target_res.targets)
    all_probas = sorted(t.proba for t in res.targets)
    assert top_probas == all_probas[-2:]

def test_explain_prediction_booster_binary(
        newsgroups_train_binary_big):
    """
    ``explain_prediction`` must handle a raw binary booster: the expected
    words show up with the expected sign for the first target, and
    ``feature_re`` filters the reported features.
    """
    docs, ys, target_names = newsgroups_train_binary_big
    vec = CountVectorizer(stop_words='english', dtype=np.float64)
    xs = vec.fit_transform(docs)
    clf = lightgbm.train(
        params={'objective': 'binary', 'verbose_eval': -1, 'max_depth': 2,
                'n_estimators': 100,
                'min_child_samples': 1, 'min_child_weight': 1},
        train_set=lightgbm.Dataset(xs.toarray(), label=ys))

    def get_res(**kwargs):
        # Explain the same document every time; kwargs add extra options.
        return explain_prediction(
            clf, 'computer graphics in space: a sign of atheism',
            vec=vec, target_names=target_names, **kwargs)

    res = get_res()
    for expl in format_as_all(res, clf, show_feature_values=True):
        assert 'graphics' in expl
    check_targets_scores(res)
    weights = res.targets[0].feature_weights
    pos_features = get_all_features(weights.pos)
    neg_features = get_all_features(weights.neg)
    assert 'graphics' in pos_features
    assert 'computer' in pos_features
    assert 'atheism' in neg_features

    # With feature_re='gra' only matching features may remain.
    flt_res = get_res(feature_re='gra')
    flt_pos_features = get_all_features(flt_res.targets[0].feature_weights.pos)
    assert 'graphics' in flt_pos_features
    assert 'computer' not in flt_pos_features

def test_lgb_n_targets():
    """
    ``_lgb_n_targets`` must report 1 for a binary classifier, the class
    count for a multiclass classifier, 1 for a regressor, and raise
    ``TypeError`` for anything else.
    """
    two_class_clf = LGBMClassifier(min_data=1)
    two_class_clf.fit(np.array([[0], [1]]), np.array([0, 1]))
    assert _lgb_n_targets(two_class_clf) == 1

    three_rows = np.array([[0], [1], [2]])
    three_labels = np.array([0, 1, 2])

    three_class_clf = LGBMClassifier(min_data=1)
    three_class_clf.fit(three_rows, three_labels)
    assert _lgb_n_targets(three_class_clf) == 3

    regressor = LGBMRegressor(min_data=1)
    regressor.fit(three_rows, three_labels)
    assert _lgb_n_targets(regressor) == 1

    # Objects that are not LightGBM scikit-learn estimators are rejected.
    with pytest.raises(TypeError):
        _lgb_n_targets(object())