55from sklearn .model_selection import train_test_split
66
77import lightgbm as lgb
8- from lightgbm .compat import GRAPHVIZ_INSTALLED , MATPLOTLIB_INSTALLED , PANDAS_INSTALLED , pd_DataFrame
8+ from lightgbm .compat import (
9+ GRAPHVIZ_INSTALLED ,
10+ MATPLOTLIB_INSTALLED ,
11+ PANDAS_INSTALLED ,
12+ pd_DataFrame ,
13+ )
914
1015if MATPLOTLIB_INSTALLED :
1116 import matplotlib
1924
@pytest.fixture(scope="module")
def breast_cancer_split():
    """Provide a module-scoped train/test split of the breast cancer dataset.

    The dataset is loaded with ``load_breast_cancer(return_X_y=True)`` and
    split with a fixed ``random_state`` so every test in the module sees the
    same partition.

    Returns:
        The result of ``train_test_split``, which callers unpack as
        ``X_train, X_test, y_train, y_test``; 10% of the rows are held out
        as the test portion.
    """
    features, target = load_breast_cancer(return_X_y=True)
    return train_test_split(features, target, test_size=0.1, random_state=1)
2330
2431
def _categorical_data(category_values_lower_bound, category_values_upper_bound):
    """Build an all-categorical version of the breast cancer features.

    Every numeric feature column is quantile-binned with ``pd.qcut`` into a
    per-column number of categories. That number is drawn by
    ``RandomState(0).randint`` using the two bounds given as arguments.

    Args:
        category_values_lower_bound: ``low`` argument passed to ``randint``.
        category_values_upper_bound: ``high`` argument passed to ``randint``.

    Returns:
        A tuple ``(X_df, y)`` where ``X_df`` is a DataFrame with one
        unordered categorical column ``cat_col_<i>`` per original feature and
        ``y`` is the unchanged target returned by ``load_breast_cancer``.
    """
    features, y = load_breast_cancer(return_X_y=True)
    n_features = features.shape[1]

    # Fixed seed keeps the per-column category counts reproducible.
    rng = np.random.RandomState(0)
    categories_per_column = rng.randint(
        category_values_lower_bound, category_values_upper_bound, size=n_features
    )

    X_df = pd.DataFrame()
    for col_idx, n_categories in enumerate(categories_per_column):
        # n_categories + 1 evenly spaced quantile edges give n_categories bins.
        quantile_edges = np.linspace(0, 1, num=n_categories + 1)
        binned = pd.qcut(
            features[:, col_idx], q=quantile_edges, labels=range(n_categories)
        )
        X_df[f"cat_col_{col_idx}"] = binned.as_unordered()
    return X_df, y
3445
3546
@@ -68,7 +79,9 @@ def test_plot_importance(params, breast_cancer_split, train_data):
6879 for patch in ax1 .patches :
6980 assert patch .get_facecolor () == (1.0 , 0 , 0 , 1.0 ) # red
7081
71- ax2 = lgb .plot_importance (gbm0 , color = ["r" , "y" , "g" , "b" ], title = None , xlabel = None , ylabel = None )
82+ ax2 = lgb .plot_importance (
83+ gbm0 , color = ["r" , "y" , "g" , "b" ], title = None , xlabel = None , ylabel = None
84+ )
7285 assert isinstance (ax2 , matplotlib .axes .Axes )
7386 assert ax2 .get_title () == ""
7487 assert ax2 .get_xlabel () == ""
@@ -80,7 +93,10 @@ def test_plot_importance(params, breast_cancer_split, train_data):
8093 assert ax2 .patches [3 ].get_facecolor () == (0 , 0 , 1.0 , 1.0 ) # b
8194
8295 ax3 = lgb .plot_importance (
83- gbm0 , title = "t @importance_type@" , xlabel = "x @importance_type@" , ylabel = "y @importance_type@"
96+ gbm0 ,
97+ title = "t @importance_type@" ,
98+ xlabel = "x @importance_type@" ,
99+ ylabel = "y @importance_type@" ,
84100 )
85101 assert isinstance (ax3 , matplotlib .axes .Axes )
86102 assert ax3 .get_title () == "t @importance_type@"
@@ -97,20 +113,53 @@ def test_plot_importance(params, breast_cancer_split, train_data):
97113 assert len (ax4 .patches ) <= 30
98114
99115 with pytest .raises (TypeError , match = "xlim must be a tuple of 2 elements." ):
100- lgb .plot_importance (gbm0 , title = None , xlabel = None , ylabel = None , xlim = "not a tuple" )
116+ lgb .plot_importance (
117+ gbm0 , title = None , xlabel = None , ylabel = None , xlim = "not a tuple"
118+ )
101119
102- gbm2 = lgb .LGBMClassifier (n_estimators = 10 , num_leaves = 3 , verbose = - 1 , importance_type = "gain" )
120+ # test ylim parameter
121+ ax5 = lgb .plot_importance (gbm0 , title = None , xlabel = None , ylabel = None , ylim = (- 1 , 30 ))
122+ assert isinstance (ax5 , matplotlib .axes .Axes )
123+ assert ax5 .get_ylim () == (- 1 , 30 )
124+
125+ with pytest .raises (TypeError , match = "ylim must be a tuple of 2 elements." ):
126+ lgb .plot_importance (gbm0 , ylim = "not a tuple" )
127+
128+ # test max_num_features parameter
129+ ax6 = lgb .plot_importance (gbm0 , max_num_features = 5 )
130+ assert isinstance (ax6 , matplotlib .axes .Axes )
131+ assert len (ax6 .patches ) == 5
132+
133+ # test providing pre-allocated ax with figsize
134+ fig , ax_prealloc = matplotlib .pyplot .subplots (1 , 1 , figsize = (12 , 8 ))
135+ ax7 = lgb .plot_importance (gbm0 , ax = ax_prealloc , figsize = (6 , 4 ))
136+ assert ax7 is ax_prealloc
137+ # when ax is provided, figsize should be ignored, so figure size remains (12, 8)
138+ assert ax7 .get_figure ().get_figwidth () == 12
139+ assert ax7 .get_figure ().get_figheight () == 8
140+
141+ gbm2 = lgb .LGBMClassifier (
142+ n_estimators = 10 , num_leaves = 3 , verbose = - 1 , importance_type = "gain"
143+ )
103144 gbm2 .fit (X_train , y_train )
104145
105146 def get_bounds_of_first_patch (axes ):
106147 return axes .patches [0 ].get_extents ().bounds
107148
108149 first_bar1 = get_bounds_of_first_patch (lgb .plot_importance (gbm1 ))
109- first_bar2 = get_bounds_of_first_patch (lgb .plot_importance (gbm1 , importance_type = "split" ))
110- first_bar3 = get_bounds_of_first_patch (lgb .plot_importance (gbm1 , importance_type = "gain" ))
150+ first_bar2 = get_bounds_of_first_patch (
151+ lgb .plot_importance (gbm1 , importance_type = "split" )
152+ )
153+ first_bar3 = get_bounds_of_first_patch (
154+ lgb .plot_importance (gbm1 , importance_type = "gain" )
155+ )
111156 first_bar4 = get_bounds_of_first_patch (lgb .plot_importance (gbm2 ))
112- first_bar5 = get_bounds_of_first_patch (lgb .plot_importance (gbm2 , importance_type = "split" ))
113- first_bar6 = get_bounds_of_first_patch (lgb .plot_importance (gbm2 , importance_type = "gain" ))
157+ first_bar5 = get_bounds_of_first_patch (
158+ lgb .plot_importance (gbm2 , importance_type = "split" )
159+ )
160+ first_bar6 = get_bounds_of_first_patch (
161+ lgb .plot_importance (gbm2 , importance_type = "gain" )
162+ )
114163
115164 assert first_bar1 == first_bar2
116165 assert first_bar1 == first_bar5
@@ -153,7 +202,13 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
153202 assert patch .get_facecolor () == (1.0 , 0 , 0 , 1.0 ) # red
154203
155204 ax2 = lgb .plot_split_value_histogram (
156- gbm0 , 27 , bins = 10 , color = ["r" , "y" , "g" , "b" ], title = None , xlabel = None , ylabel = None
205+ gbm0 ,
206+ 27 ,
207+ bins = 10 ,
208+ color = ["r" , "y" , "g" , "b" ],
209+ title = None ,
210+ xlabel = None ,
211+ ylabel = None ,
157212 )
158213 assert isinstance (ax2 , matplotlib .axes .Axes )
159214 assert ax2 .get_title () == ""
@@ -165,14 +220,22 @@ def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
165220 assert ax2 .patches [2 ].get_facecolor () == (0 , 0.5 , 0 , 1.0 ) # g
166221 assert ax2 .patches [3 ].get_facecolor () == (0 , 0 , 1.0 , 1.0 ) # b
167222
223+ # test xlim and ylim parameters
224+ ax3 = lgb .plot_split_value_histogram (gbm0 , 27 , xlim = (0 , 100 ), ylim = (0 , 50 ))
225+ assert isinstance (ax3 , matplotlib .axes .Axes )
226+ assert ax3 .get_xlim () == (0 , 100 )
227+ assert ax3 .get_ylim () == (0 , 50 )
228+
168229 with pytest .raises (
169- ValueError , match = "Cannot plot split value histogram, because feature 0 was not used in splitting"
230+ ValueError ,
231+ match = "Cannot plot split value histogram, because feature 0 was not used in splitting" ,
170232 ):
171233 lgb .plot_split_value_histogram (gbm0 , 0 ) # was not used in splitting
172234
173235
174236@pytest .mark .skipif (
175- not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED , reason = "matplotlib or graphviz is not installed"
237+ not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED ,
238+ reason = "matplotlib or graphviz is not installed" ,
176239)
177240def test_plot_tree (breast_cancer_split ):
178241 X_train , _ , y_train , _ = breast_cancer_split
@@ -194,7 +257,9 @@ def test_create_tree_digraph(tmp_path, breast_cancer_split):
194257 X_train , _ , y_train , _ = breast_cancer_split
195258
196259 constraints = [- 1 , 1 ] * int (X_train .shape [1 ] / 2 )
197- gbm = lgb .LGBMClassifier (n_estimators = 10 , num_leaves = 3 , verbose = - 1 , monotone_constraints = constraints )
260+ gbm = lgb .LGBMClassifier (
261+ n_estimators = 10 , num_leaves = 3 , verbose = - 1 , monotone_constraints = constraints
262+ )
198263 gbm .fit (X_train , y_train )
199264
200265 with pytest .raises (IndexError , match = "tree_index is out of range." ):
@@ -389,7 +454,9 @@ def test_example_case_in_tree_digraph():
389454 while "decision_type" in node : # iterate through the splits
390455 split_index = node ["split_index" ]
391456
392- node_in_graph = [n for n in gbody if f"split{ split_index } " in n and "->" not in n ]
457+ node_in_graph = [
458+ n for n in gbody if f"split{ split_index } " in n and "->" not in n
459+ ]
393460 assert len (node_in_graph ) == 1
394461 seen_indices .add (gbody .index (node_in_graph [0 ]))
395462
@@ -420,14 +487,22 @@ def test_example_case_in_tree_digraph():
420487 assert "color=blue" in leaf_in_graph [0 ]
421488 assert len (edge_to_leaf ) == 1
422489 assert "color=blue" in edge_to_leaf [0 ]
423- seen_indices .update ([gbody .index (leaf_in_graph [0 ]), gbody .index (edge_to_leaf [0 ])])
490+ seen_indices .update (
491+ [gbody .index (leaf_in_graph [0 ]), gbody .index (edge_to_leaf [0 ])]
492+ )
424493
425494 # check that the rest of the elements have black color
426- remaining_elements = [e for i , e in enumerate (graph .body ) if i not in seen_indices and "graph" not in e ]
495+ remaining_elements = [
496+ e
497+ for i , e in enumerate (graph .body )
498+ if i not in seen_indices and "graph" not in e
499+ ]
427500 assert all ("color=black" in e for e in remaining_elements )
428501
429502 # check that we got to the expected leaf
430- expected_leaf = bst .predict (example_case , start_iteration = i , num_iteration = 1 , pred_leaf = True )[0 ]
503+ expected_leaf = bst .predict (
504+ example_case , start_iteration = i , num_iteration = 1 , pred_leaf = True
505+ )[0 ]
431506 assert leaf_index == expected_leaf
432507 assert makes_categorical_splits
433508
@@ -464,7 +539,9 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
464539 num_boost_round = 10 ,
465540 callbacks = [lgb .record_evaluation (evals_result0 )],
466541 )
467- with pytest .warns (UserWarning , match = "More than one metric available, picking one to plot." ):
542+ with pytest .warns (
543+ UserWarning , match = "More than one metric available, picking one to plot."
544+ ):
468545 ax0 = lgb .plot_metric (evals_result0 )
469546 assert isinstance (ax0 , matplotlib .axes .Axes )
470547 assert ax0 .get_title () == "Metric during training"
@@ -521,7 +598,12 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
521598 assert not grid_line .get_visible ()
522599
523600 evals_result1 = {}
524- lgb .train (params , train_data , num_boost_round = 10 , callbacks = [lgb .record_evaluation (evals_result1 )])
601+ lgb .train (
602+ params ,
603+ train_data ,
604+ num_boost_round = 10 ,
605+ callbacks = [lgb .record_evaluation (evals_result1 )],
606+ )
525607 with pytest .raises (ValueError , match = "eval results cannot be empty." ):
526608 lgb .plot_metric (evals_result1 )
527609
@@ -535,3 +617,11 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
535617 legend_items = ax4 .get_legend ().get_texts ()
536618 assert len (legend_items ) == 1
537619 assert legend_items [0 ].get_text () == "valid_0"
620+
621+ # test xlim and ylim parameters
622+ ax5 = lgb .plot_metric (
623+ evals_result0 , metric = "binary_logloss" , xlim = (0 , 15 ), ylim = (0 , 1 )
624+ )
625+ assert isinstance (ax5 , matplotlib .axes .Axes )
626+ assert ax5 .get_xlim () == (0 , 15 )
627+ assert ax5 .get_ylim () == (0 , 1 )
0 commit comments