22import pandas as pd
33import pytest
44from matplotlib .collections import QuadMesh
5+ from sklearn .dummy import DummyRegressor
56from skrub .datasets import fetch_employee_salaries
67
78from skore import Display , EstimatorReport , train_test_split
89from skore ._externals ._skrub_compat import tabular_pipeline
910from skore ._sklearn ._plot .data .table_report import TableReportDisplay
1011
1112
12- @pytest .fixture
13+ @pytest .fixture ( scope = "module" )
1314def estimator_report ():
1415 data = fetch_employee_salaries ()
1516 X , y = data .X , data .y
@@ -20,10 +21,10 @@ def estimator_report():
2021 ).dt .to_pytimedelta ()
2122 X ["cents" ] = 100 * y
2223 split_data = train_test_split (X , y , random_state = 0 , as_dict = True )
23- return EstimatorReport (tabular_pipeline ("regressor" ), ** split_data )
24+ return EstimatorReport (tabular_pipeline (DummyRegressor () ), ** split_data )
2425
2526
26- @pytest .fixture
27+ @pytest .fixture ( scope = "module" )
2728def display (estimator_report ):
2829 return estimator_report .data .analyze ()
2930
@@ -86,7 +87,7 @@ def test_constructor(display):
8687)
8788def test_X_y (X , y ):
8889 split_data = train_test_split (X , y , random_state = 0 , as_dict = True )
89- report = EstimatorReport (tabular_pipeline ("regressor" ), ** split_data )
90+ report = EstimatorReport (tabular_pipeline (DummyRegressor () ), ** split_data )
9091 display = report .data .analyze ()
9192 assert isinstance (display , TableReportDisplay )
9293
@@ -114,17 +115,16 @@ def test_frame(estimator_report, data_source):
114115 )
115116
116117
117- def test_categorical_plots_1d (pyplot , estimator_report ):
118+ def test_categorical_plots_1d (pyplot , display ):
118119 """Check the plot output with categorical data in 1-d."""
119- display = estimator_report .data .analyze (data_source = "train" )
120120 display .plot (x = "gender" )
121121 assert hasattr (display , "ax_" )
122122 assert hasattr (display , "figure_" )
123123 assert display .ax_ .get_xlabel () == "gender"
124124 assert [label .get_text () for label in display .ax_ .get_xticklabels ()] == ["M" , "F" ]
125125 labels = display .ax_ .get_yticklabels ()
126126 assert labels [0 ].get_text () == "0"
127- assert labels [- 1 ].get_text () == "4000 "
127+ assert labels [- 1 ].get_text () == "5000 "
128128 assert display .ax_ .get_ylabel () == "Count"
129129 # orange
130130 assert display .ax_ .containers [0 ].patches [0 ].get_facecolor () == (
@@ -162,18 +162,16 @@ def test_numeric_plots_1d(pyplot, estimator_report):
162162 assert display .ax_ .get_ylabel () == "year_first_hired"
163163
164164
165- def test_top_k_categorical_plots_1d (pyplot , estimator_report ):
165+ def test_top_k_categorical_plots_1d (pyplot , display ):
166166 """Check the plot output with categorical data in 1-d and top k categories."""
167- display = estimator_report .data .analyze (data_source = "train" )
168167 display .plot (x = "division" )
169168 assert len (display .ax_ .get_xticklabels ()) == 20
170169 display .plot (x = "division" , top_k_categories = 30 )
171170 assert len (display .ax_ .get_xticklabels ()) == 30
172171
173172
174- def test_hue_plots_1d (pyplot , estimator_report ):
173+ def test_hue_plots_1d (pyplot , display ):
175174 """Check the plot output with hue in 1-d."""
176- display = estimator_report .data .analyze (data_source = "train" )
177175 display .plot (x = "gender" , hue = "current_annual_salary" )
178176 assert "BoxPlotContainer" in display .ax_ .containers [0 ].__class__ .__name__
179177 legend_labels = display .ax_ .legend_ .texts
@@ -205,9 +203,8 @@ def test_plot_duration_data_1d(pyplot, display):
205203 assert display .ax_ .get_ylabel () == "Years"
206204
207205
208- def test_plots_2d (pyplot , estimator_report ):
206+ def test_plots_2d (pyplot , display ):
209207 """Check the general behaviour of the 2-d plots."""
210- display = estimator_report .data .analyze (data_source = "train" )
211208 # scatter plot
212209 display .plot (y = "current_annual_salary" , x = "year_first_hired" )
213210 assert display .ax_ .get_xlabel () == "year_first_hired"
@@ -233,7 +230,7 @@ def test_plots_2d(pyplot, estimator_report):
233230
234231 # heatmap
235232 display .plot (x = "gender" , y = "division" )
236- assert len (display .ax_ .get_yticklabels ()) == 19
233+ assert len (display .ax_ .get_yticklabels ()) == 20
237234 assert display .ax_ .get_ylabel () == "division"
238235 assert display .ax_ .get_xlabel () == "gender"
239236 assert isinstance (display .ax_ .collections [0 ], QuadMesh )
@@ -247,9 +244,8 @@ def test_plots_2d(pyplot, estimator_report):
247244 assert any ("e+" in annotation for annotation in annotations )
248245
249246
250- def test_hue_plots_2d (pyplot , estimator_report ):
247+ def test_hue_plots_2d (pyplot , display ):
251248 """Check the plot output with hue parameter in 2-d."""
252- display = estimator_report .data .analyze (data_source = "train" )
253249 display .plot (x = "year_first_hired" , y = "current_annual_salary" , hue = "division" )
254250 assert len (display .ax_ .legend_ .texts ) == 21
255251 assert display .ax_ .legend_ .get_title ().get_text () == "division"
0 commit comments