Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def test_plot_importance(params, breast_cancer_split, train_data):
with pytest.raises(TypeError, match="figsize must be a tuple of 2 elements."):
lgb.plot_importance(gbm0, title=None, xlabel=None, ylabel=None, figsize="not a tuple")

# test max_num_features parameter
total_features = len(gbm0.feature_importance())
assert total_features > 5, "model must have more than 5 features to test max_num_features"
ax7 = lgb.plot_importance(gbm0, max_num_features=5)
assert isinstance(ax7, matplotlib.axes.Axes)
assert len(ax7.patches) == 5
# verify the 5 displayed features are the top 5 by importance
importance = gbm0.feature_importance()
feature_names = gbm0.feature_name()
sorted_pairs = sorted(zip(feature_names, importance), key=lambda x: x[1])
top5_names = [name for name, _ in sorted_pairs[-5:]]
displayed_labels = [label.get_text() for label in ax7.get_yticklabels()]
assert displayed_labels == top5_names

gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, verbose=-1, importance_type="gain")
gbm2.fit(X_train, y_train)

Expand Down Expand Up @@ -187,6 +201,17 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
assert ax2.patches[2].get_facecolor() == (0, 0.5, 0, 1.0) # g
assert ax2.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # b

# test xlim parameter
ax3 = lgb.plot_split_value_histogram(gbm0, 27, xlim=(0, 100), title=None, xlabel=None, ylabel=None)
assert isinstance(ax3, matplotlib.axes.Axes)
assert ax3.get_title() == ""
assert ax3.get_xlabel() == ""
assert ax3.get_ylabel() == ""
assert ax3.get_xlim() == (0, 100)

with pytest.raises(TypeError, match="xlim must be a tuple of 2 elements."):
lgb.plot_split_value_histogram(gbm0, 27, xlim="not a tuple")

with pytest.raises(
ValueError, match="Cannot plot split value histogram, because feature 0 was not used in splitting"
):
Expand Down Expand Up @@ -557,3 +582,14 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
legend_items = ax4.get_legend().get_texts()
assert len(legend_items) == 1
assert legend_items[0].get_text() == "valid_0"

# test xlim parameter
ax5 = lgb.plot_metric(evals_result0, metric="binary_logloss", xlim=(0, 15), title=None, xlabel=None, ylabel=None)
assert isinstance(ax5, matplotlib.axes.Axes)
assert ax5.get_title() == ""
assert ax5.get_xlabel() == ""
assert ax5.get_ylabel() == ""
assert ax5.get_xlim() == (0, 15)

with pytest.raises(TypeError, match="xlim must be a tuple of 2 elements."):
lgb.plot_metric(evals_result0, metric="binary_logloss", xlim="not a tuple")