Merged · 32 commits
0f835d6  don't pass metric just to get n_curves (auguste-probabl, May 6, 2025)
c3c9af1  Move NotImplementedError statement (auguste-probabl, May 6, 2025)
7943eda  Turn internal roc curve data into DataFrames (auguste-probabl, May 6, 2025)
00ec307  add tests (auguste-probabl, May 7, 2025)
110c76a  refactor: Add ReportType type (auguste-probabl, May 13, 2025)
c01befa  Add "comparison-cross-validation" report type (auguste-probabl, May 7, 2025)
b386587  ComparisonReport[CVReport]: Make it possible to create RocCurveDisplay (auguste-probabl, May 7, 2025)
a2b2705  deal with multiclass (auguste-probabl, May 7, 2025)
5133f74  move legend in multiclass case so that it doesn't overlap (auguste-probabl, May 12, 2025)
2a90021  test(roc-curve): Split tests by report type (auguste-probabl, May 13, 2025)
e67dadc  test(roc-curve): Simplify test names (auguste-probabl, May 14, 2025)
04b8ed4  test(roc-curve): Move comparison-cross-validation tests with the others (auguste-probabl, May 13, 2025)
2c76973  Fix test (auguste-probabl, May 14, 2025)
80faa1b  roc-curve: Change LineCollection to Line2D (auguste-probabl, May 15, 2025)
e691ffa  refactor: Create _filter_by helper (auguste-probabl, May 14, 2025)
f98f423  refactor(roc-curve): Use _filter_by (auguste-probabl, May 15, 2025)
f17e64b  fix docstring (auguste-probabl, May 23, 2025)
0dd2463  replace _filter_by by .query (auguste-probabl, May 23, 2025)
ddbf71a  fix docstring (auguste-probabl, May 26, 2025)
0135f5f  use fixtures (auguste-probabl, May 26, 2025)
65fc418  rename test (auguste-probabl, May 26, 2025)
af3e9c0  refactor checks of display attributes (auguste-probabl, May 26, 2025)
17cf734  add fixtures (auguste-probabl, May 26, 2025)
e41c4c5  use .query (auguste-probabl, May 26, 2025)
0e37cca  test passing kwargs in comparison[cv] case; fix bugs (auguste-probabl, May 26, 2025)
768b9cd  use self.pos_label rather than defining a variable (auguste-probabl, May 26, 2025)
af4508d  fix bug (auguste-probabl, May 26, 2025)
0cb804c  replace .iloc[0] with .item() (auguste-probabl, May 26, 2025)
4b3530c  convert some columns to categories (auguste-probabl, May 26, 2025)
b660ee2  write query once and use it for roc_curve and roc_auc (auguste-probabl, May 26, 2025)
2ef044f  scale figsize (auguste-probabl, May 26, 2025)
660fc59  fix: Use `!r` to deal with string labels (auguste-probabl, May 26, 2025)
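
Several of these commits ("Turn internal roc curve data into DataFrames", "replace _filter_by by .query", "fix: Use `!r` to deal with string labels") share one technique: keep the per-curve data in a pandas DataFrame and select rows with `.query`. Below is a minimal sketch of why the `!r` fix matters; the column names are illustrative, not taken from the skore source.

```python
import pandas as pd

# Illustrative stand-in for the internal ROC-curve frame; the real
# column names in skore may differ.
roc_data = pd.DataFrame(
    {
        "estimator_name": ["log_reg", "log_reg", "tree", "tree"],
        "label": ["yes", "no", "yes", "no"],
        "threshold": [0.1, 0.5, 0.2, 0.6],
    }
)

pos_label = "yes"
# Without !r the query renders as `label == yes` and pandas parses
# `yes` as a column name, raising UndefinedVariableError. With !r the
# label is quoted properly: `label == 'yes'`.
query = f"label == {pos_label!r} & estimator_name == {'log_reg'!r}"
print(roc_data.query(query))
```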
153 changes: 110 additions & 43 deletions skore/src/skore/sklearn/_comparison/metrics_accessor.py
@@ -1263,9 +1263,6 @@ def _get_display(
         display : display_class
             The display.
         """
-        if self._parent._reports_type == "CrossValidationReport":
-            raise NotImplementedError()
-
         if "seed" in display_kwargs and display_kwargs["seed"] is None:
             cache_key = None
         else:
@@ -1288,55 +1285,121 @@ def _get_display(
         y_true: list[YPlotData] = []
         y_pred: list[YPlotData] = []
 
-        for report, report_name in zip(
-            self._parent.reports_, self._parent.report_names_
-        ):
-            report_X, report_y, _ = report.metrics._get_X_y_and_data_source_hash(
-                data_source=data_source,
-                X=X,
-                y=y,
-            )
-
-            y_true.append(
-                YPlotData(
-                    estimator_name=report_name,
-                    split_index=None,
-                    y=report_y,
-                )
-            )
-            results = _get_cached_response_values(
-                cache=report._cache,
-                estimator_hash=report._hash,
-                estimator=report._estimator,
-                X=report_X,
-                response_method=response_method,
-                data_source=data_source,
-                data_source_hash=None,
-                pos_label=display_kwargs.get("pos_label"),
-            )
-            for key, value, is_cached in results:
-                if not is_cached:
-                    report._cache[key] = value
-                if key[-1] != "predict_time":
-                    y_pred.append(
-                        YPlotData(
-                            estimator_name=report_name,
-                            split_index=None,
-                            y=value,
-                        )
-                    )
-            progress.update(main_task, advance=1, refresh=True)
-
-        display = display_class._compute_data_for_display(
-            y_true=y_true,
-            y_pred=y_pred,
-            report_type="comparison-estimator",
-            estimators=[report.estimator_ for report in self._parent.reports_],
-            estimator_names=self._parent.report_names_,
-            ml_task=self._parent._ml_task,
-            data_source=data_source,
-            **display_kwargs,
-        )
+        if self._parent._reports_type == "EstimatorReport":
+            for report, report_name in zip(
+                self._parent.reports_, self._parent.report_names_
+            ):
+                report_X, report_y, _ = (
+                    report.metrics._get_X_y_and_data_source_hash(
+                        data_source=data_source,
+                        X=X,
+                        y=y,
+                    )
+                )
+
+                y_true.append(
+                    YPlotData(
+                        estimator_name=report_name,
+                        split_index=None,
+                        y=report_y,
+                    )
+                )
+                results = _get_cached_response_values(
+                    cache=report._cache,
+                    estimator_hash=report._hash,
+                    estimator=report._estimator,
+                    X=report_X,
+                    response_method=response_method,
+                    data_source=data_source,
+                    data_source_hash=None,
+                    pos_label=display_kwargs.get("pos_label"),
+                )
+                for key, value, is_cached in results:
+                    if not is_cached:
+                        report._cache[key] = value
+                    if key[-1] != "predict_time":
+                        y_pred.append(
+                            YPlotData(
+                                estimator_name=report_name,
+                                split_index=None,
+                                y=value,
+                            )
+                        )
+
+                progress.update(main_task, advance=1, refresh=True)
+
+            display = display_class._compute_data_for_display(
+                y_true=y_true,
+                y_pred=y_pred,
+                report_type="comparison-estimator",
+                estimators=[report.estimator_ for report in self._parent.reports_],
+                estimator_names=self._parent.report_names_,
+                ml_task=self._parent._ml_task,
+                data_source=data_source,
+                **display_kwargs,
+            )
+
+        else:
+            for report, report_name in zip(
+                self._parent.reports_, self._parent.report_names_
+            ):
+                for split_index, estimator_report in enumerate(
+                    report.estimator_reports_
+                ):
+                    report_X, report_y, _ = (
+                        estimator_report.metrics._get_X_y_and_data_source_hash(
+                            data_source=data_source,
+                            X=X,
+                            y=y,
+                        )
+                    )
+
+                    y_true.append(
+                        YPlotData(
+                            estimator_name=report_name,
+                            split_index=split_index,
+                            y=report_y,
+                        )
+                    )
+                    results = _get_cached_response_values(
+                        cache=estimator_report._cache,
+                        estimator_hash=estimator_report._hash,
+                        estimator=estimator_report.estimator_,
+                        X=report_X,
+                        response_method=response_method,
+                        data_source=data_source,
+                        data_source_hash=None,
+                        pos_label=display_kwargs.get("pos_label"),
+                    )
+                    for key, value, is_cached in results:
+                        if not is_cached:
+                            report._cache[key] = value
+                        if key[-1] != "predict_time":
+                            y_pred.append(
+                                YPlotData(
+                                    estimator_name=report_name,
+                                    split_index=split_index,
+                                    y=value,
+                                )
+                            )
+
+                progress.update(main_task, advance=1, refresh=True)
+
+            display = display_class._compute_data_for_display(
+                y_true=y_true,
+                y_pred=y_pred,
+                report_type="comparison-cross-validation",
+                estimators=[
+                    estimator_report.estimator_
+                    for report in self._parent.reports_
+                    for estimator_report in report.estimator_reports_
+                ],
+                estimator_names=self._parent.report_names_,
+                ml_task=self._parent._ml_task,
+                data_source=data_source,
+                **display_kwargs,
+            )
 
         if cache_key is not None:
             # Unless seed is an int (i.e. the call is deterministic),
@@ -1476,6 +1539,8 @@ def precision_recall(
         >>> display = comparison_report.metrics.precision_recall()
         >>> display.plot()
         """
+        if self._parent._reports_type == "CrossValidationReport":
+            raise NotImplementedError()
         response_method = ("predict_proba", "decision_function")
         display_kwargs = {"pos_label": pos_label}
         display = cast(
@@ -1560,6 +1625,8 @@ def prediction_error(
         >>> display = comparison_report.metrics.prediction_error()
         >>> display.plot(kind="actual_vs_predicted")
         """
+        if self._parent._reports_type == "CrossValidationReport":
+            raise NotImplementedError()
         display_kwargs = {"subsample": subsample, "seed": seed}
         display = cast(
             PredictionErrorDisplay,
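
The effect of this file's changes, sketched as user-facing code. It assumes the public API shown in the docstrings above, plus a `metrics.roc()` accessor returning the `RocCurveDisplay` named in the commit messages; treat the exact constructor signatures as assumptions, not the library's documented API.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from skore import ComparisonReport, CrossValidationReport

X, y = make_classification(random_state=0)

# A comparison of cross-validation reports: _reports_type is
# "CrossValidationReport", so _get_display now takes the `else` branch.
comparison_report = ComparisonReport(
    [
        CrossValidationReport(LogisticRegression(), X=X, y=y),
        CrossValidationReport(DecisionTreeClassifier(), X=X, y=y),
    ]
)

# Newly supported: one ROC curve per (estimator, CV split) pair,
# computed with report_type="comparison-cross-validation".
comparison_report.metrics.roc().plot()

# Still gated: precision_recall (and prediction_error) now raise
# NotImplementedError up front for this report type.
try:
    comparison_report.metrics.precision_recall()
except NotImplementedError:
    print("not yet supported for CrossValidationReport comparisons")
```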
26 changes: 18 additions & 8 deletions skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py
@@ -21,7 +21,7 @@
     _validate_style_kwargs,
     sample_mpl_colormap,
 )
-from skore.sklearn.types import MLTask, PositiveLabel, YPlotData
+from skore.sklearn.types import MLTask, PositiveLabel, ReportType, YPlotData
 
 
 class PrecisionRecallCurveDisplay(
@@ -81,7 +81,8 @@ class PrecisionRecallCurveDisplay(
     ml_task : {"binary-classification", "multiclass-classification"}
         The machine learning task.
 
-    report_type : {"comparison-estimator", "cross-validation", "estimator"}
+    report_type : {"comparison-cross-validation", "comparison-estimator", \
+        "cross-validation", "estimator"}
         The type of report.
 
     Attributes
@@ -121,7 +122,7 @@ def __init__(
         pos_label: Optional[PositiveLabel],
         data_source: Literal["train", "test", "X_y"],
         ml_task: MLTask,
-        report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
+        report_type: ReportType,
     ) -> None:
         self.precision = precision
         self.recall = recall
@@ -480,10 +481,15 @@ def plot(
         if pr_curve_kwargs is None:
             pr_curve_kwargs = self._default_pr_curve_kwargs
 
+        if self.ml_task == "binary-classification":
+            n_curves = len(self.average_precision[self.pos_label])
+        else:
+            n_curves = len(self.average_precision)
+
         pr_curve_kwargs = self._validate_curve_kwargs(
             curve_param_name="pr_curve_kwargs",
             curve_kwargs=pr_curve_kwargs,
-            metric=self.average_precision,
+            n_curves=n_curves,
             report_type=self.report_type,
         )
 
@@ -512,10 +518,13 @@ def plot(
                 estimator_names=self.estimator_names,
                 pr_curve_kwargs=pr_curve_kwargs,
             )
+        elif self.report_type == "comparison-cross-validation":
+            raise NotImplementedError()
         else:
             raise ValueError(
-                f"`report_type` should be one of 'estimator', 'cross-validation', "
-                f"or 'comparison-estimator'. Got '{self.report_type}' instead."
+                "`report_type` should be one of 'estimator', 'cross-validation', "
+                "'comparison-cross-validation' or 'comparison-estimator'. "
+                f"Got '{self.report_type}' instead."
             )
 
         xlabel = "Recall"
@@ -541,7 +550,7 @@ def _compute_data_for_display(
         y_true: Sequence[YPlotData],
         y_pred: Sequence[YPlotData],
         *,
-        report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
+        report_type: ReportType,
         estimators: Sequence[BaseEstimator],
         estimator_names: list[str],
         ml_task: MLTask,
@@ -561,7 +570,8 @@ def _compute_data_for_display(
             confidence values, or non-thresholded measure of decisions (as returned by
             "decision_function" on some classifiers).
 
-        report_type : {"comparison-estimator", "cross-validation", "estimator"}
+        report_type : {"comparison-cross-validation", "comparison-estimator", \
+            "cross-validation", "estimator"}
             The type of report.
 
         estimators : list of estimator instances
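
Two pieces of context for the hunks above, written as hedged sketches rather than excerpts: the `ReportType` alias is inferred from the four literals the docstrings and the `ValueError` message enumerate (not copied from `skore/sklearn/types.py`), and the `n_curves` demo assumes `average_precision` maps each class label to one value per plotted curve.

```python
from typing import Literal

# Presumed shape of the alias added by "refactor: Add ReportType type".
ReportType = Literal[
    "comparison-cross-validation",
    "comparison-estimator",
    "cross-validation",
    "estimator",
]

# The new n_curves logic from plot(), on toy data: in the binary case
# only the positive label's curves are drawn, so count those; in the
# multiclass case, one entry per class.
average_precision = {"yes": [0.81, 0.79, 0.84], "no": [0.88, 0.90, 0.86]}
pos_label, ml_task = "yes", "binary-classification"

if ml_task == "binary-classification":
    n_curves = len(average_precision[pos_label])  # 3: one per report/split
else:
    n_curves = len(average_precision)  # one per class
```

A single shared alias also removes the drift between the slightly different `Literal[...]` annotations the display classes had accumulated, visible in the hunks above and below.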
13 changes: 9 additions & 4 deletions skore/src/skore/sklearn/_plot/metrics/prediction_error.py
@@ -17,7 +17,7 @@
     _validate_style_kwargs,
     sample_mpl_colormap,
 )
-from skore.sklearn.types import MLTask, YPlotData
+from skore.sklearn.types import MLTask, ReportType, YPlotData
 
 RangeData = namedtuple("RangeData", ["min", "max"])
 
@@ -62,7 +62,8 @@ class PredictionErrorDisplay(StyleDisplayMixin, HelpDisplayMixin):
     ml_task : {"regression", "multioutput-regression"}
         The machine learning task.
 
-    report_type : {"cross-validation", "estimator", "comparison-estimator"}
+    report_type : {"comparison-cross-validation", "comparison-estimator", \
+        "cross-validation", "estimator"}
         The type of report.
 
     Attributes
@@ -113,7 +114,7 @@ def __init__(
         estimator_names: list[str],
         data_source: Literal["train", "test", "X_y"],
         ml_task: MLTask,
-        report_type: Literal["cross-validation", "estimator", "comparison-estimator"],
+        report_type: ReportType,
     ) -> None:
         self.y_true = y_true
         self.y_pred = y_pred
@@ -557,7 +558,7 @@ def _compute_data_for_display(
         y_true: list[YPlotData],
         y_pred: list[YPlotData],
         *,
-        report_type: Literal["cross-validation", "estimator", "comparison-estimator"],
+        report_type: ReportType,
         estimator_names: list[str],
         ml_task: MLTask,
         data_source: Literal["train", "test", "X_y"],
@@ -575,6 +576,10 @@ def _compute_data_for_display(
         y_pred : list of array-like of shape (n_samples,)
             Predicted target values.
 
+        report_type : {"comparison-cross-validation", "comparison-estimator", \
+            "cross-validation", "estimator"}
+            The type of report.
+
         estimators : list of estimator instances
             The estimators from which `y_pred` is obtained.
 
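
For reference, the `YPlotData` records threaded through both files, as an illustrative stand-in: the fields below are the ones the diffs pass (`estimator_name`, `split_index`, `y`); the real class in `skore.sklearn.types` may carry more.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


# Illustrative stand-in for skore.sklearn.types.YPlotData.
@dataclass
class YPlotData:
    estimator_name: str
    split_index: Optional[int]  # None outside cross-validation
    y: np.ndarray


# The comparison-estimator branch leaves split_index as None; the new
# comparison-cross-validation branch sets it to the fold number, so a
# display can group curves first by estimator, then by split.
curves = [
    YPlotData("log_reg", 0, np.array([0, 1, 1])),
    YPlotData("log_reg", 1, np.array([1, 0, 1])),
]
```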