Skip to content

Commit 2f7f1ab

Browse files
authored
chore: Remove some unnecessary for loops in some unit tests (#1897)
closes #1853
1 parent d5f38d5 commit 2f7f1ab

File tree

4 files changed

+24
-24
lines changed

4 files changed

+24
-24
lines changed

skore/tests/unit/sklearn/plot/prediction_error/test_comparison_cross_validation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,15 @@ def test_constructor(regression_data_no_split):
142142
display = report.metrics.prediction_error()
143143

144144
index_columns = ["estimator_name", "split_index"]
145-
for df in [display.prediction_error]:
146-
assert all(col in df.columns for col in index_columns)
147-
assert df.query("estimator_name == 'estimator_1'")[
148-
"split_index"
149-
].unique().tolist() == list(range(cv))
150-
assert df.query("estimator_name == 'estimator_2'")[
151-
"split_index"
152-
].unique().tolist() == list(range(cv + 1))
153-
assert df["estimator_name"].unique().tolist() == report.report_names_
145+
df = display.prediction_error
146+
assert all(col in df.columns for col in index_columns)
147+
assert df.query("estimator_name == 'estimator_1'")[
148+
"split_index"
149+
].unique().tolist() == list(range(cv))
150+
assert df.query("estimator_name == 'estimator_2'")[
151+
"split_index"
152+
].unique().tolist() == list(range(cv + 1))
153+
assert df["estimator_name"].unique().tolist() == report.report_names_
154154

155155

156156
def test_frame(report):

skore/tests/unit/sklearn/plot/prediction_error/test_comparison_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_constructor(regression_data):
273273
display = report.metrics.prediction_error()
274274

275275
index_columns = ["estimator_name", "split_index"]
276-
for df in [display.prediction_error]:
277-
assert all(col in df.columns for col in index_columns)
278-
assert df["estimator_name"].unique().tolist() == report.report_names_
279-
assert df["split_index"].isnull().all()
276+
df = display.prediction_error
277+
assert all(col in df.columns for col in index_columns)
278+
assert df["estimator_name"].unique().tolist() == report.report_names_
279+
assert df["split_index"].isnull().all()

skore/tests/unit/sklearn/plot/prediction_error/test_cross_validation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_constructor(regression_data_no_split):
172172
display = report.metrics.prediction_error()
173173

174174
index_columns = ["estimator_name", "split_index"]
175-
for df in [display.prediction_error]:
176-
assert all(col in df.columns for col in index_columns)
177-
assert df["estimator_name"].unique() == report.estimator_name_
178-
assert df["split_index"].unique().tolist() == list(range(cv))
175+
df = display.prediction_error
176+
assert all(col in df.columns for col in index_columns)
177+
assert df["estimator_name"].unique() == report.estimator_name_
178+
assert df["split_index"].unique().tolist() == list(range(cv))

skore/tests/unit/sklearn/plot/prediction_error/test_estimator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,10 @@ def test_constructor(regression_data):
285285
display = report.metrics.prediction_error()
286286

287287
index_columns = ["estimator_name", "split_index"]
288-
for df in [display.prediction_error]:
289-
assert all(col in df.columns for col in index_columns)
290-
assert df["estimator_name"].unique() == report.estimator_name_
291-
assert df["split_index"].isnull().all()
292-
np.testing.assert_allclose(df["y_true"], y_test)
293-
np.testing.assert_allclose(df["y_pred"], estimator.predict(X_test))
294-
np.testing.assert_allclose(df["residuals"], y_test - estimator.predict(X_test))
288+
df = display.prediction_error
289+
assert all(col in df.columns for col in index_columns)
290+
assert df["estimator_name"].unique() == report.estimator_name_
291+
assert df["split_index"].isnull().all()
292+
np.testing.assert_allclose(df["y_true"], y_test)
293+
np.testing.assert_allclose(df["y_pred"], estimator.predict(X_test))
294+
np.testing.assert_allclose(df["residuals"], y_test - estimator.predict(X_test))

0 commit comments

Comments
 (0)