Skip to content

Commit eb4dbfd

Browse files
[python-package] Add tests for plotting ylim, xlim, max_num_features, and pre-allocated ax
1 parent 80ab6d3 commit eb4dbfd

File tree

1 file changed

+112
-22
lines changed

1 file changed

+112
-22
lines changed

tests/python_package_test/test_plotting.py

Lines changed: 112 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,12 @@
55
from sklearn.model_selection import train_test_split
66

77
import lightgbm as lgb
8-
from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, PANDAS_INSTALLED, pd_DataFrame
8+
from lightgbm.compat import (
9+
GRAPHVIZ_INSTALLED,
10+
MATPLOTLIB_INSTALLED,
11+
PANDAS_INSTALLED,
12+
pd_DataFrame,
13+
)
914

1015
if MATPLOTLIB_INSTALLED:
1116
import matplotlib
@@ -19,17 +24,23 @@
1924

2025
@pytest.fixture(scope="module")
2126
def breast_cancer_split():
22-
return train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=1)
27+
return train_test_split(
28+
*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=1
29+
)
2330

2431

2532
def _categorical_data(category_values_lower_bound, category_values_upper_bound):
2633
X, y = load_breast_cancer(return_X_y=True)
2734
X_df = pd.DataFrame()
2835
rnd = np.random.RandomState(0)
29-
n_cat_values = rnd.randint(category_values_lower_bound, category_values_upper_bound, size=X.shape[1])
36+
n_cat_values = rnd.randint(
37+
category_values_lower_bound, category_values_upper_bound, size=X.shape[1]
38+
)
3039
for i in range(X.shape[1]):
3140
bins = np.linspace(0, 1, num=n_cat_values[i] + 1)
32-
X_df[f"cat_col_{i}"] = pd.qcut(X[:, i], q=bins, labels=range(n_cat_values[i])).as_unordered()
41+
X_df[f"cat_col_{i}"] = pd.qcut(
42+
X[:, i], q=bins, labels=range(n_cat_values[i])
43+
).as_unordered()
3344
return X_df, y
3445

3546

@@ -68,7 +79,9 @@ def test_plot_importance(params, breast_cancer_split, train_data):
6879
for patch in ax1.patches:
6980
assert patch.get_facecolor() == (1.0, 0, 0, 1.0) # red
7081

71-
ax2 = lgb.plot_importance(gbm0, color=["r", "y", "g", "b"], title=None, xlabel=None, ylabel=None)
82+
ax2 = lgb.plot_importance(
83+
gbm0, color=["r", "y", "g", "b"], title=None, xlabel=None, ylabel=None
84+
)
7285
assert isinstance(ax2, matplotlib.axes.Axes)
7386
assert ax2.get_title() == ""
7487
assert ax2.get_xlabel() == ""
@@ -80,7 +93,10 @@ def test_plot_importance(params, breast_cancer_split, train_data):
8093
assert ax2.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # b
8194

8295
ax3 = lgb.plot_importance(
83-
gbm0, title="t @importance_type@", xlabel="x @importance_type@", ylabel="y @importance_type@"
96+
gbm0,
97+
title="t @importance_type@",
98+
xlabel="x @importance_type@",
99+
ylabel="y @importance_type@",
84100
)
85101
assert isinstance(ax3, matplotlib.axes.Axes)
86102
assert ax3.get_title() == "t @importance_type@"
@@ -97,20 +113,53 @@ def test_plot_importance(params, breast_cancer_split, train_data):
97113
assert len(ax4.patches) <= 30
98114

99115
with pytest.raises(TypeError, match="xlim must be a tuple of 2 elements."):
100-
lgb.plot_importance(gbm0, title=None, xlabel=None, ylabel=None, xlim="not a tuple")
116+
lgb.plot_importance(
117+
gbm0, title=None, xlabel=None, ylabel=None, xlim="not a tuple"
118+
)
101119

102-
gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, verbose=-1, importance_type="gain")
120+
# test ylim parameter
121+
ax5 = lgb.plot_importance(gbm0, title=None, xlabel=None, ylabel=None, ylim=(-1, 30))
122+
assert isinstance(ax5, matplotlib.axes.Axes)
123+
assert ax5.get_ylim() == (-1, 30)
124+
125+
with pytest.raises(TypeError, match="ylim must be a tuple of 2 elements."):
126+
lgb.plot_importance(gbm0, ylim="not a tuple")
127+
128+
# test max_num_features parameter
129+
ax6 = lgb.plot_importance(gbm0, max_num_features=5)
130+
assert isinstance(ax6, matplotlib.axes.Axes)
131+
assert len(ax6.patches) == 5
132+
133+
# test providing pre-allocated ax with figsize
134+
fig, ax_prealloc = matplotlib.pyplot.subplots(1, 1, figsize=(12, 8))
135+
ax7 = lgb.plot_importance(gbm0, ax=ax_prealloc, figsize=(6, 4))
136+
assert ax7 is ax_prealloc
137+
# when ax is provided, figsize should be ignored, so figure size remains (12, 8)
138+
assert ax7.get_figure().get_figwidth() == 12
139+
assert ax7.get_figure().get_figheight() == 8
140+
141+
gbm2 = lgb.LGBMClassifier(
142+
n_estimators=10, num_leaves=3, verbose=-1, importance_type="gain"
143+
)
103144
gbm2.fit(X_train, y_train)
104145

105146
def get_bounds_of_first_patch(axes):
106147
return axes.patches[0].get_extents().bounds
107148

108149
first_bar1 = get_bounds_of_first_patch(lgb.plot_importance(gbm1))
109-
first_bar2 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="split"))
110-
first_bar3 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="gain"))
150+
first_bar2 = get_bounds_of_first_patch(
151+
lgb.plot_importance(gbm1, importance_type="split")
152+
)
153+
first_bar3 = get_bounds_of_first_patch(
154+
lgb.plot_importance(gbm1, importance_type="gain")
155+
)
111156
first_bar4 = get_bounds_of_first_patch(lgb.plot_importance(gbm2))
112-
first_bar5 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="split"))
113-
first_bar6 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="gain"))
157+
first_bar5 = get_bounds_of_first_patch(
158+
lgb.plot_importance(gbm2, importance_type="split")
159+
)
160+
first_bar6 = get_bounds_of_first_patch(
161+
lgb.plot_importance(gbm2, importance_type="gain")
162+
)
114163

115164
assert first_bar1 == first_bar2
116165
assert first_bar1 == first_bar5
@@ -153,7 +202,13 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
153202
assert patch.get_facecolor() == (1.0, 0, 0, 1.0) # red
154203

155204
ax2 = lgb.plot_split_value_histogram(
156-
gbm0, 27, bins=10, color=["r", "y", "g", "b"], title=None, xlabel=None, ylabel=None
205+
gbm0,
206+
27,
207+
bins=10,
208+
color=["r", "y", "g", "b"],
209+
title=None,
210+
xlabel=None,
211+
ylabel=None,
157212
)
158213
assert isinstance(ax2, matplotlib.axes.Axes)
159214
assert ax2.get_title() == ""
@@ -165,14 +220,22 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
165220
assert ax2.patches[2].get_facecolor() == (0, 0.5, 0, 1.0) # g
166221
assert ax2.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # b
167222

223+
# test xlim and ylim parameters
224+
ax3 = lgb.plot_split_value_histogram(gbm0, 27, xlim=(0, 100), ylim=(0, 50))
225+
assert isinstance(ax3, matplotlib.axes.Axes)
226+
assert ax3.get_xlim() == (0, 100)
227+
assert ax3.get_ylim() == (0, 50)
228+
168229
with pytest.raises(
169-
ValueError, match="Cannot plot split value histogram, because feature 0 was not used in splitting"
230+
ValueError,
231+
match="Cannot plot split value histogram, because feature 0 was not used in splitting",
170232
):
171233
lgb.plot_split_value_histogram(gbm0, 0) # was not used in splitting
172234

173235

174236
@pytest.mark.skipif(
175-
not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, reason="matplotlib or graphviz is not installed"
237+
not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED,
238+
reason="matplotlib or graphviz is not installed",
176239
)
177240
def test_plot_tree(breast_cancer_split):
178241
X_train, _, y_train, _ = breast_cancer_split
@@ -194,7 +257,9 @@ def test_create_tree_digraph(tmp_path, breast_cancer_split):
194257
X_train, _, y_train, _ = breast_cancer_split
195258

196259
constraints = [-1, 1] * int(X_train.shape[1] / 2)
197-
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, verbose=-1, monotone_constraints=constraints)
260+
gbm = lgb.LGBMClassifier(
261+
n_estimators=10, num_leaves=3, verbose=-1, monotone_constraints=constraints
262+
)
198263
gbm.fit(X_train, y_train)
199264

200265
with pytest.raises(IndexError, match="tree_index is out of range."):
@@ -389,7 +454,9 @@ def test_example_case_in_tree_digraph():
389454
while "decision_type" in node: # iterate through the splits
390455
split_index = node["split_index"]
391456

392-
node_in_graph = [n for n in gbody if f"split{split_index}" in n and "->" not in n]
457+
node_in_graph = [
458+
n for n in gbody if f"split{split_index}" in n and "->" not in n
459+
]
393460
assert len(node_in_graph) == 1
394461
seen_indices.add(gbody.index(node_in_graph[0]))
395462

@@ -420,14 +487,22 @@ def test_example_case_in_tree_digraph():
420487
assert "color=blue" in leaf_in_graph[0]
421488
assert len(edge_to_leaf) == 1
422489
assert "color=blue" in edge_to_leaf[0]
423-
seen_indices.update([gbody.index(leaf_in_graph[0]), gbody.index(edge_to_leaf[0])])
490+
seen_indices.update(
491+
[gbody.index(leaf_in_graph[0]), gbody.index(edge_to_leaf[0])]
492+
)
424493

425494
# check that the rest of the elements have black color
426-
remaining_elements = [e for i, e in enumerate(graph.body) if i not in seen_indices and "graph" not in e]
495+
remaining_elements = [
496+
e
497+
for i, e in enumerate(graph.body)
498+
if i not in seen_indices and "graph" not in e
499+
]
427500
assert all("color=black" in e for e in remaining_elements)
428501

429502
# check that we got to the expected leaf
430-
expected_leaf = bst.predict(example_case, start_iteration=i, num_iteration=1, pred_leaf=True)[0]
503+
expected_leaf = bst.predict(
504+
example_case, start_iteration=i, num_iteration=1, pred_leaf=True
505+
)[0]
431506
assert leaf_index == expected_leaf
432507
assert makes_categorical_splits
433508

@@ -464,7 +539,9 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
464539
num_boost_round=10,
465540
callbacks=[lgb.record_evaluation(evals_result0)],
466541
)
467-
with pytest.warns(UserWarning, match="More than one metric available, picking one to plot."):
542+
with pytest.warns(
543+
UserWarning, match="More than one metric available, picking one to plot."
544+
):
468545
ax0 = lgb.plot_metric(evals_result0)
469546
assert isinstance(ax0, matplotlib.axes.Axes)
470547
assert ax0.get_title() == "Metric during training"
@@ -521,7 +598,12 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
521598
assert not grid_line.get_visible()
522599

523600
evals_result1 = {}
524-
lgb.train(params, train_data, num_boost_round=10, callbacks=[lgb.record_evaluation(evals_result1)])
601+
lgb.train(
602+
params,
603+
train_data,
604+
num_boost_round=10,
605+
callbacks=[lgb.record_evaluation(evals_result1)],
606+
)
525607
with pytest.raises(ValueError, match="eval results cannot be empty."):
526608
lgb.plot_metric(evals_result1)
527609

@@ -535,3 +617,11 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
535617
legend_items = ax4.get_legend().get_texts()
536618
assert len(legend_items) == 1
537619
assert legend_items[0].get_text() == "valid_0"
620+
621+
# test xlim and ylim parameters
622+
ax5 = lgb.plot_metric(
623+
evals_result0, metric="binary_logloss", xlim=(0, 15), ylim=(0, 1)
624+
)
625+
assert isinstance(ax5, matplotlib.axes.Axes)
626+
assert ax5.get_xlim() == (0, 15)
627+
assert ax5.get_ylim() == (0, 1)

0 commit comments

Comments
 (0)