@@ -232,3 +232,79 @@ def test_multiclass_classification_kwargs(pyplot, multiclass_classification_repo
display.plot(despine=False)
assert display.ax_[0].spines["top"].get_visible()
assert display.ax_[0].spines["right"].get_visible()


def test_data_source_binary_classification(pyplot, binary_classification_data_no_split):
Contributor:
Can you check that the results actually change when `data_source` changes, please?
For instance, you can pass a subsample of `X` and `y` as the `X_y` data source (so you don't have to create a new dataset), and then check with `assert not` and the `equals` function that the DataFrames output by `display.frame()` differ; see the sketch after this test.

"""
Test passing data_source to ROC plot in ComparisonReport with CrossValidationReport
"""
estimator, X, y = binary_classification_data_no_split
estimator_1 = LogisticRegression()
estimator_2 = LogisticRegression(C=10)

report = ComparisonReport(
reports={
"estimator_1": CrossValidationReport(estimator_1, X, y),
"estimator_2": CrossValidationReport(estimator_2, X, y),
}
)

display = report.metrics.roc(data_source="X_y", X=X, y=y)
assert display.data_source == "X_y"
display.plot()

display = report.metrics.roc(data_source="train")
assert display.data_source == "train"
display.plot()

display = report.metrics.roc(data_source="test")
assert display.data_source == "test"
display.plot()

n_reports = len(report.reports_)
n_splits = report.reports_[0]._cv_splitter.n_splits
expected_auc_entries = n_reports * n_splits

assert len(display.roc_auc) == expected_auc_entries
auc_values = display.roc_auc["roc_auc"].values
assert all(0 <= auc <= 1 for auc in auc_values)
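A minimal sketch of the check suggested in the review comment above, reusing the `report`, `X`, and `y` defined in this test; it assumes `display.frame()` returns a pandas DataFrame as the comment implies, and the subsample size of 50 is arbitrary:

# Sketch only: results computed from different "X_y" data should differ.
display_full = report.metrics.roc(data_source="X_y", X=X, y=y)
display_subsample = report.metrics.roc(data_source="X_y", X=X[:50], y=y[:50])

# The two frames are built from different data, so they should not be equal.
assert not display_full.frame().equals(display_subsample.frame())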


def test_data_source_multiclass_classification(
Contributor:
Same as above, can you check that the outputs are different please?

pyplot, multiclass_classification_data_no_split
):
"Test data_source in ROC plot for ComparisonReport with multiclass and CV report"
estimator, X, y = multiclass_classification_data_no_split
estimator_1 = LogisticRegression()
estimator_2 = LogisticRegression(C=10)

report = ComparisonReport(
reports={
"estimator_1": CrossValidationReport(estimator_1, X, y),
"estimator_2": CrossValidationReport(estimator_2, X, y),
}
)

class_labels = np.unique(y)

display = report.metrics.roc(data_source="X_y", X=X, y=y)
assert display.data_source == "X_y"
display.plot()

display = report.metrics.roc(data_source="train")
assert display.data_source == "train"
display.plot()

display = report.metrics.roc(data_source="test")
assert display.data_source == "test"
display.plot()

n_reports = len(report.reports_)
n_splits = report.reports_[0]._cv_splitter.n_splits
n_classes = len(class_labels)
expected_combinations = n_reports * n_classes * n_splits

assert len(display.roc_auc) == expected_combinations

auc_values = display.roc_auc["roc_auc"].values
assert all(0 <= auc <= 1 for auc in auc_values)
@@ -141,6 +141,92 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
assert display.ax_.get_title() == "ROC Curve"


def test_data_source_binary_classification(pyplot, binary_classification_data):
Contributor:
Same as in the previous file: please check that the outputs are different; see the sketch after this test.

"""Test data_source in ROC plot for ComparisonReport."""
estimator, X_train, X_test, y_train, y_test = binary_classification_data
estimator_2 = clone(estimator).set_params(C=10).fit(X_train, y_train)

report = ComparisonReport(
reports={
"estimator_1": EstimatorReport(
estimator,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
),
"estimator_2": EstimatorReport(
estimator_2,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
),
}
)

display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
assert display.data_source == "X_y"
display.plot()

display = report.metrics.roc(data_source="train")
assert display.data_source == "train"
display.plot()

display = report.metrics.roc(data_source="test")
assert display.data_source == "test"
display.plot()

auc_values = display.roc_auc["roc_auc"].values
assert len(auc_values) == 2
assert all(0 <= auc <= 1 for auc in auc_values)
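Following the same review suggestion, a brief sketch for this EstimatorReport-based comparison: the frames obtained from the "train" and "test" data sources should differ (again assuming `display.frame()` returns a pandas DataFrame):

# Sketch only: different data sources should yield different result frames.
display_train = report.metrics.roc(data_source="train")
display_test = report.metrics.roc(data_source="test")
assert not display_train.frame().equals(display_test.frame())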


def test_data_source_multiclass_classification(pyplot, multiclass_classification_data):
"""Test data_source in ROC plot for ComparisonReport with multiclass data"""
estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
estimator_2 = clone(estimator).set_params(C=10).fit(X_train, y_train)

report = ComparisonReport(
reports={
"estimator_1": EstimatorReport(
estimator,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
),
"estimator_2": EstimatorReport(
estimator_2,
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
),
}
)

class_labels = report.reports_[0].estimator_.classes_

display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
assert display.data_source == "X_y"
display.plot()

display = report.metrics.roc(data_source="train")
assert display.data_source == "train"
display.plot()

display = report.metrics.roc(data_source="test")
assert display.data_source == "test"
display.plot()

expected_combinations = len(report.report_names_) * len(class_labels)
assert len(display.roc_auc) == expected_combinations

auc_values = display.roc_auc["roc_auc"].values
assert all(0 <= auc <= 1 for auc in auc_values)


def test_binary_classification_kwargs(pyplot, binary_classification_data):
"""Check that we can pass keyword arguments to the ROC curve plot for
cross-validation."""
10 changes: 10 additions & 0 deletions skore/tests/unit/sklearn/plot/roc_curve/test_estimator.py
@@ -57,6 +57,7 @@ def test_binary_classification(pyplot, binary_classification_data):
assert display.ax_.get_aspect() in ("equal", 1.0)
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
assert display.ax_.get_title() == "ROC Curve for LogisticRegression"
assert display.data_source == "test"


def test_multiclass_classification(pyplot, multiclass_classification_data):
@@ -108,6 +109,7 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
assert display.ax_.get_aspect() in ("equal", 1.0)
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
assert display.ax_.get_title() == "ROC Curve for LogisticRegression"
assert display.data_source == "test"


def test_data_source_binary_classification(pyplot, binary_classification_data):
@@ -116,6 +118,14 @@ def test_data_source_binary_classification(pyplot, binary_classification_data):
report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
assert display.data_source == "X_y"
display.plot()
assert (
display.lines_[0].get_label()
== f"AUC = {display.roc_auc['roc_auc'].item():0.2f}"
)

display = report.metrics.roc(data_source="train")
display.plot()
assert (