Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 12 additions & 39 deletions skore/src/skore/_sklearn/_comparison/metrics_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
ScoringName,
YPlotData,
)
from skore._utils._accessor import _check_supported_ml_task
from skore._utils._accessor import (
_check_any_sub_report_has_metric,
_check_supported_ml_task,
)
from skore._utils._fixes import _validate_joblib_parallel_params
from skore._utils._index import flatten_multi_index
from skore._utils._progress_bar import progress_decorator
Expand Down Expand Up @@ -357,11 +360,7 @@ def timings(
)
return timings

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
@available_if(_check_any_sub_report_has_metric("accuracy"))
def accuracy(
self,
*,
Expand Down Expand Up @@ -427,11 +426,7 @@ def accuracy(
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
@available_if(_check_any_sub_report_has_metric("precision"))
def precision(
self,
*,
Expand Down Expand Up @@ -535,11 +530,7 @@ class is set to the one provided when creating the report. If `None`,
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
@available_if(_check_any_sub_report_has_metric("recall"))
def recall(
self,
*,
Expand Down Expand Up @@ -644,9 +635,7 @@ class is set to the one provided when creating the report. If `None`,
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(supported_ml_tasks=["binary-classification"])
)
@available_if(_check_any_sub_report_has_metric("brier_score"))
def brier_score(
self,
*,
Expand Down Expand Up @@ -712,11 +701,7 @@ def brier_score(
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
@available_if(_check_any_sub_report_has_metric("roc_auc"))
def roc_auc(
self,
*,
Expand Down Expand Up @@ -819,11 +804,7 @@ def roc_auc(
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
@available_if(_check_any_sub_report_has_metric("log_loss"))
def log_loss(
self,
*,
Expand Down Expand Up @@ -889,11 +870,7 @@ def log_loss(
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["regression", "multioutput-regression"]
)
)
@available_if(_check_any_sub_report_has_metric("r2"))
def r2(
self,
*,
Expand Down Expand Up @@ -971,11 +948,7 @@ def r2(
aggregate=aggregate,
).frame()

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["regression", "multioutput-regression"]
)
)
@available_if(_check_any_sub_report_has_metric("rmse"))
def rmse(
self,
*,
Expand Down
16 changes: 16 additions & 0 deletions skore/src/skore/_utils/_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,19 @@ def check(accessor: Any) -> bool:
return all(_check_has_coef(e) for e in parent_estimators)

return check


########################################################################################
# Accessor related to `ComparisonReport`
########################################################################################


def _check_any_sub_report_has_metric(metric: str) -> Callable[[Any], bool]:
"""Check whether any sub-report of the ComparisonReport supports `metric`."""

def check(accessor: Any) -> bool:
return any(
hasattr(report.metrics, metric) for report in accessor._parent.reports_
Copy link
Collaborator

@thomass-dev thomass-dev Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: you can factorize your function to check_any_sub_report_has_accessor with the following :

Suggested change
hasattr(report.metrics, metric) for report in accessor._parent.reports_
reduce(getattr, accessor.split("."), report)
for report in accessor._parent.reports_
@available_if(_check_any_sub_report_has_accessor("metrics.rmse"))

)

return check
Loading