Skip to content

Commit 7c08f39

Browse files
committed
removed unnecessary comments, fixed lint errors
1 parent 2dba727 commit 7c08f39

File tree

3 files changed

+165
-1
lines changed

3 files changed

+165
-1
lines changed

skore/tests/unit/sklearn/plot/roc_curve/test_comparison_cross_validation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,79 @@ def test_multiclass_classification_kwargs(pyplot, multiclass_classification_repo
232232
display.plot(despine=False)
233233
assert display.ax_[0].spines["top"].get_visible()
234234
assert display.ax_[0].spines["right"].get_visible()
235+
236+
237+
def test_data_source_binary_classification(pyplot, binary_classification_data_no_split):
238+
"""
239+
Test passing data_source to ROC plot in ComparisonReport with CrossValidationReport
240+
"""
241+
estimator, X, y = binary_classification_data_no_split
242+
estimator_1 = LogisticRegression()
243+
estimator_2 = LogisticRegression(C=10)
244+
245+
report = ComparisonReport(
246+
reports={
247+
"estimator_1": CrossValidationReport(estimator_1, X, y),
248+
"estimator_2": CrossValidationReport(estimator_2, X, y),
249+
}
250+
)
251+
252+
display = report.metrics.roc(data_source="X_y", X=X, y=y)
253+
assert display.data_source == "X_y"
254+
display.plot()
255+
256+
display = report.metrics.roc(data_source="train")
257+
assert display.data_source == "train"
258+
display.plot()
259+
260+
display = report.metrics.roc(data_source="test")
261+
assert display.data_source == "test"
262+
display.plot()
263+
264+
n_reports = len(report.reports_)
265+
n_splits = report.reports_[0]._cv_splitter.n_splits
266+
expected_auc_entries = n_reports * n_splits
267+
268+
assert len(display.roc_auc) == expected_auc_entries
269+
auc_values = display.roc_auc["roc_auc"].values
270+
assert all(0 <= auc <= 1 for auc in auc_values)
271+
272+
273+
def test_data_source_multiclass_classification(
274+
pyplot, multiclass_classification_data_no_split
275+
):
276+
"Test data_source in ROC plot for ComparisonReport with multiclass and CV report"
277+
estimator, X, y = multiclass_classification_data_no_split
278+
estimator_1 = LogisticRegression()
279+
estimator_2 = LogisticRegression(C=10)
280+
281+
report = ComparisonReport(
282+
reports={
283+
"estimator_1": CrossValidationReport(estimator_1, X, y),
284+
"estimator_2": CrossValidationReport(estimator_2, X, y),
285+
}
286+
)
287+
288+
class_labels = np.unique(y)
289+
290+
display = report.metrics.roc(data_source="X_y", X=X, y=y)
291+
assert display.data_source == "X_y"
292+
display.plot()
293+
294+
display = report.metrics.roc(data_source="train")
295+
assert display.data_source == "train"
296+
display.plot()
297+
298+
display = report.metrics.roc(data_source="test")
299+
assert display.data_source == "test"
300+
display.plot()
301+
302+
n_reports = len(report.reports_)
303+
n_splits = report.reports_[0]._cv_splitter.n_splits
304+
n_classes = len(class_labels)
305+
expected_combinations = n_reports * n_classes * n_splits
306+
307+
assert len(display.roc_auc) == expected_combinations
308+
309+
auc_values = display.roc_auc["roc_auc"].values
310+
assert all(0 <= auc <= 1 for auc in auc_values)

skore/tests/unit/sklearn/plot/roc_curve/test_comparison_estimator.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,92 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
141141
assert display.ax_.get_title() == "ROC Curve"
142142

143143

144+
def test_data_source_binary_classification(pyplot, binary_classification_data):
145+
"""Test data_source in ROC plot for ComparisonReport."""
146+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
147+
estimator_2 = clone(estimator).set_params(C=10).fit(X_train, y_train)
148+
149+
report = ComparisonReport(
150+
reports={
151+
"estimator_1": EstimatorReport(
152+
estimator,
153+
X_train=X_train,
154+
y_train=y_train,
155+
X_test=X_test,
156+
y_test=y_test,
157+
),
158+
"estimator_2": EstimatorReport(
159+
estimator_2,
160+
X_train=X_train,
161+
y_train=y_train,
162+
X_test=X_test,
163+
y_test=y_test,
164+
),
165+
}
166+
)
167+
168+
display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
169+
assert display.data_source == "X_y"
170+
display.plot()
171+
172+
display = report.metrics.roc(data_source="train")
173+
assert display.data_source == "train"
174+
display.plot()
175+
176+
display = report.metrics.roc(data_source="test")
177+
assert display.data_source == "test"
178+
display.plot()
179+
180+
train_auc = display.roc_auc["roc_auc"].values
181+
assert len(train_auc) == 2
182+
assert all(0 <= auc <= 1 for auc in train_auc)
183+
184+
185+
def test_data_source_multiclass_classification(pyplot, multiclass_classification_data):
186+
"""Test data_source in ROC plot for ComparisonReport with multiclass data"""
187+
estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
188+
estimator_2 = clone(estimator).set_params(C=10).fit(X_train, y_train)
189+
190+
report = ComparisonReport(
191+
reports={
192+
"estimator_1": EstimatorReport(
193+
estimator,
194+
X_train=X_train,
195+
y_train=y_train,
196+
X_test=X_test,
197+
y_test=y_test,
198+
),
199+
"estimator_2": EstimatorReport(
200+
estimator_2,
201+
X_train=X_train,
202+
y_train=y_train,
203+
X_test=X_test,
204+
y_test=y_test,
205+
),
206+
}
207+
)
208+
209+
class_labels = report.reports_[0].estimator_.classes_
210+
211+
display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
212+
assert display.data_source == "X_y"
213+
display.plot()
214+
215+
display = report.metrics.roc(data_source="train")
216+
assert display.data_source == "train"
217+
display.plot()
218+
219+
display = report.metrics.roc(data_source="test")
220+
assert display.data_source == "test"
221+
display.plot()
222+
223+
expected_combinations = len(report.report_names_) * len(class_labels)
224+
assert len(display.roc_auc) == expected_combinations
225+
226+
auc_values = display.roc_auc["roc_auc"].values
227+
assert all(0 <= auc <= 1 for auc in auc_values)
228+
229+
144230
def test_binary_classification_kwargs(pyplot, binary_classification_data):
145231
"""Check that we can pass keyword arguments to the ROC curve plot for
146232
cross-validation."""

skore/tests/unit/sklearn/plot/roc_curve/test_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_binary_classification(pyplot, binary_classification_data):
5959
assert display.ax_.get_title() == "ROC Curve for LogisticRegression"
6060
assert display.data_source == "test"
6161

62+
6263
def test_multiclass_classification(pyplot, multiclass_classification_data):
6364
"""Check the attributes and default plotting behaviour of the ROC curve plot with
6465
multiclass data."""
@@ -110,6 +111,7 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
110111
assert display.ax_.get_title() == "ROC Curve for LogisticRegression"
111112
assert display.data_source == "test"
112113

114+
113115
def test_data_source_binary_classification(pyplot, binary_classification_data):
114116
"""Check that we can pass the `data_source` argument to the ROC curve plot."""
115117
estimator, X_train, X_test, y_train, y_test = binary_classification_data
@@ -123,7 +125,7 @@ def test_data_source_binary_classification(pyplot, binary_classification_data):
123125
display.lines_[0].get_label()
124126
== f"AUC = {display.roc_auc['roc_auc'].item():0.2f}"
125127
)
126-
128+
127129
display = report.metrics.roc(data_source="train")
128130
display.plot()
129131
assert (

0 commit comments

Comments
 (0)