Skip to content

Commit b0b2c9f

Browse files
feat(skore): Adapt RocCurveDisplay.frame() for data_source="both" (#2168)
For EstimatorReports Closes #2147
1 parent 868e8b6 commit b0b2c9f

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

โ€Žskore/src/skore/_sklearn/_plot/metrics/roc_curve.pyโ€Ž

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,9 @@ def frame(self, with_roc_auc: bool = False) -> DataFrame:
11271127
else: # self.report_type == "comparison-cross-validation"
11281128
indexing_columns = ["estimator_name", "split"]
11291129

1130+
if self.data_source == "both":
1131+
indexing_columns += ["data_source"]
1132+
11301133
if self.ml_task == "binary-classification":
11311134
columns = indexing_columns + statistical_columns
11321135
else:

โ€Žskore/tests/unit/displays/roc_curve/test_estimator.pyโ€Ž

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,37 @@ def test_frame_multiclass_classification(
344344
assert df["label"].nunique() == len(estimator.classes_)
345345

346346
if with_roc_auc:
347-
for (_), group in df.groupby(["label"], observed=True):
347+
for _, group in df.groupby(["label"], observed=True):
348348
assert group["roc_auc"].nunique() == 1
349349

350350

351+
@pytest.mark.parametrize("with_roc_auc", [False, True])
352+
def test_frame_multiclass_classification_data_source_both(
353+
logistic_multiclass_classification_with_train_test, with_roc_auc
354+
):
355+
"""
356+
Test the frame method with multiclass classification data and data_source="both".
357+
"""
358+
estimator, X_train, X_test, y_train, y_test = (
359+
logistic_multiclass_classification_with_train_test
360+
)
361+
report = EstimatorReport(
362+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
363+
)
364+
df = report.metrics.roc(data_source="both").frame(with_roc_auc=with_roc_auc)
365+
expected_index = ["data_source", "label"]
366+
expected_columns = ["threshold", "fpr", "tpr"]
367+
if with_roc_auc:
368+
expected_columns.append("roc_auc")
369+
370+
check_frame_structure(df, expected_index, expected_columns)
371+
assert df["label"].nunique() == len(estimator.classes_)
372+
373+
if with_roc_auc:
374+
for _, group in df.groupby(["label"], observed=True):
375+
assert group["roc_auc"].nunique() == 2
376+
377+
351378
def test_legend(
352379
pyplot,
353380
logistic_binary_classification_with_train_test,

0 commit comments

Comments
ย (0)