Skip to content

Commit a209d9a

Browse files
chore(skore): Improve test time of data accessor (#2162)
Use a simpler model and re-use fixtures. Test time of `tests/unit/displays/table_report/` goes from 45s to 12s.
1 parent 83bdcdc commit a209d9a

File tree

3 files changed

+22
-26
lines changed

3 files changed

+22
-26
lines changed

skore/tests/unit/displays/table_report/test_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,28 @@
1313
)
1414

1515

16-
@pytest.fixture
16+
@pytest.fixture(scope='module')
1717
def X_y():
1818
X, y = make_regression(n_samples=100, n_features=5, random_state=42)
1919
X = pd.DataFrame(X, columns=[f"Feature_{i}" for i in range(5)])
2020
y = pd.Series(y, name="Target_")
2121
return X, y
2222

2323

24-
@pytest.fixture
24+
@pytest.fixture(scope='module')
2525
def estimator_report(X_y):
2626
X, y = X_y
2727
split_data = train_test_split(X, y, random_state=0, as_dict=True)
2828
return EstimatorReport(tabular_pipeline("regressor"), **split_data)
2929

3030

31-
@pytest.fixture
31+
@pytest.fixture(scope='module')
3232
def cross_validation_report(X_y):
3333
X, y = X_y
3434
return CrossValidationReport(tabular_pipeline("regressor"), X=X, y=y)
3535

3636

37-
@pytest.fixture(params=["estimator_report", "cross_validation_report"])
37+
@pytest.fixture(params=["estimator_report", "cross_validation_report"], scope='module')
3838
def display(request):
3939
report = request.getfixturevalue(request.param)
4040
return report.data.analyze()

skore/tests/unit/displays/table_report/test_cross_validation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33
import pytest
44
from sklearn.cluster import KMeans
55
from sklearn.datasets import make_regression
6+
from sklearn.dummy import DummyRegressor
67

78
from skore import CrossValidationReport, Display, TableReportDisplay
89
from skore._externals._skrub_compat import tabular_pipeline
910

1011

11-
@pytest.fixture
12+
@pytest.fixture(scope="module")
1213
def cross_validation_report():
1314
X, y = make_regression(n_samples=100, n_features=5, random_state=42)
1415
X = pd.DataFrame(X, columns=[f"Feature_{i}" for i in range(5)])
1516
y = pd.Series(y, name="Target_")
16-
return CrossValidationReport(tabular_pipeline("regressor"), X=X, y=y)
17+
return CrossValidationReport(tabular_pipeline(DummyRegressor()), X=X, y=y)
1718

1819

19-
@pytest.fixture
20+
@pytest.fixture(scope="module")
2021
def display(cross_validation_report):
2122
return cross_validation_report.data.analyze()
2223

@@ -44,9 +45,8 @@ def test_table_report_display_constructor(display):
4445
)
4546

4647

47-
def test_table_report_display_frame(cross_validation_report):
48+
def test_table_report_display_frame(cross_validation_report, display):
4849
"""Check that we return the expected kind of data when calling `.frame`."""
49-
display = cross_validation_report.data.analyze()
5050
dataset = display.frame(kind="dataset")
5151

5252
pd.testing.assert_frame_equal(
@@ -81,7 +81,7 @@ def test_table_report_display_frame(cross_validation_report):
8181
)
8282
def test_display_creation(X, y):
8383
"""Check that the display can be created with different types of X and y."""
84-
report = CrossValidationReport(tabular_pipeline("regressor"), X=X, y=y)
84+
report = CrossValidationReport(tabular_pipeline(DummyRegressor()), X=X, y=y)
8585
display = report.data.analyze()
8686
assert isinstance(display, TableReportDisplay)
8787

skore/tests/unit/displays/table_report/test_estimator.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import pandas as pd
33
import pytest
44
from matplotlib.collections import QuadMesh
5+
from sklearn.dummy import DummyRegressor
56
from skrub.datasets import fetch_employee_salaries
67

78
from skore import Display, EstimatorReport, train_test_split
89
from skore._externals._skrub_compat import tabular_pipeline
910
from skore._sklearn._plot.data.table_report import TableReportDisplay
1011

1112

12-
@pytest.fixture
13+
@pytest.fixture(scope="module")
1314
def estimator_report():
1415
data = fetch_employee_salaries()
1516
X, y = data.X, data.y
@@ -20,10 +21,10 @@ def estimator_report():
2021
).dt.to_pytimedelta()
2122
X["cents"] = 100 * y
2223
split_data = train_test_split(X, y, random_state=0, as_dict=True)
23-
return EstimatorReport(tabular_pipeline("regressor"), **split_data)
24+
return EstimatorReport(tabular_pipeline(DummyRegressor()), **split_data)
2425

2526

26-
@pytest.fixture
27+
@pytest.fixture(scope="module")
2728
def display(estimator_report):
2829
return estimator_report.data.analyze()
2930

@@ -86,7 +87,7 @@ def test_constructor(display):
8687
)
8788
def test_X_y(X, y):
8889
split_data = train_test_split(X, y, random_state=0, as_dict=True)
89-
report = EstimatorReport(tabular_pipeline("regressor"), **split_data)
90+
report = EstimatorReport(tabular_pipeline(DummyRegressor()), **split_data)
9091
display = report.data.analyze()
9192
assert isinstance(display, TableReportDisplay)
9293

@@ -114,17 +115,16 @@ def test_frame(estimator_report, data_source):
114115
)
115116

116117

117-
def test_categorical_plots_1d(pyplot, estimator_report):
118+
def test_categorical_plots_1d(pyplot, display):
118119
"""Check the plot output with categorical data in 1-d."""
119-
display = estimator_report.data.analyze(data_source="train")
120120
display.plot(x="gender")
121121
assert hasattr(display, "ax_")
122122
assert hasattr(display, "figure_")
123123
assert display.ax_.get_xlabel() == "gender"
124124
assert [label.get_text() for label in display.ax_.get_xticklabels()] == ["M", "F"]
125125
labels = display.ax_.get_yticklabels()
126126
assert labels[0].get_text() == "0"
127-
assert labels[-1].get_text() == "4000"
127+
assert labels[-1].get_text() == "5000"
128128
assert display.ax_.get_ylabel() == "Count"
129129
# orange
130130
assert display.ax_.containers[0].patches[0].get_facecolor() == (
@@ -162,18 +162,16 @@ def test_numeric_plots_1d(pyplot, estimator_report):
162162
assert display.ax_.get_ylabel() == "year_first_hired"
163163

164164

165-
def test_top_k_categorical_plots_1d(pyplot, estimator_report):
165+
def test_top_k_categorical_plots_1d(pyplot, display):
166166
"""Check the plot output with categorical data in 1-d and top k categories."""
167-
display = estimator_report.data.analyze(data_source="train")
168167
display.plot(x="division")
169168
assert len(display.ax_.get_xticklabels()) == 20
170169
display.plot(x="division", top_k_categories=30)
171170
assert len(display.ax_.get_xticklabels()) == 30
172171

173172

174-
def test_hue_plots_1d(pyplot, estimator_report):
173+
def test_hue_plots_1d(pyplot, display):
175174
"""Check the plot output with hue in 1-d."""
176-
display = estimator_report.data.analyze(data_source="train")
177175
display.plot(x="gender", hue="current_annual_salary")
178176
assert "BoxPlotContainer" in display.ax_.containers[0].__class__.__name__
179177
legend_labels = display.ax_.legend_.texts
@@ -205,9 +203,8 @@ def test_plot_duration_data_1d(pyplot, display):
205203
assert display.ax_.get_ylabel() == "Years"
206204

207205

208-
def test_plots_2d(pyplot, estimator_report):
206+
def test_plots_2d(pyplot, display):
209207
"""Check the general behaviour of the 2-d plots."""
210-
display = estimator_report.data.analyze(data_source="train")
211208
# scatter plot
212209
display.plot(y="current_annual_salary", x="year_first_hired")
213210
assert display.ax_.get_xlabel() == "year_first_hired"
@@ -233,7 +230,7 @@ def test_plots_2d(pyplot, estimator_report):
233230

234231
# heatmap
235232
display.plot(x="gender", y="division")
236-
assert len(display.ax_.get_yticklabels()) == 19
233+
assert len(display.ax_.get_yticklabels()) == 20
237234
assert display.ax_.get_ylabel() == "division"
238235
assert display.ax_.get_xlabel() == "gender"
239236
assert isinstance(display.ax_.collections[0], QuadMesh)
@@ -247,9 +244,8 @@ def test_plots_2d(pyplot, estimator_report):
247244
assert any("e+" in annotation for annotation in annotations)
248245

249246

250-
def test_hue_plots_2d(pyplot, estimator_report):
247+
def test_hue_plots_2d(pyplot, display):
251248
"""Check the plot output with hue parameter in 2-d."""
252-
display = estimator_report.data.analyze(data_source="train")
253249
display.plot(x="year_first_hired", y="current_annual_salary", hue="division")
254250
assert len(display.ax_.legend_.texts) == 21
255251
assert display.ax_.legend_.get_title().get_text() == "division"

0 commit comments

Comments
 (0)