Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
153 commits
Select commit Hold shift + click to select a range
096e3e5
create CrossValidationComparisonReport
auguste-probabl Mar 24, 2025
08f78ef
rename estimator_reports_ to reports_
auguste-probabl Mar 26, 2025
c2d3bf7
Fix docstring examples
auguste-probabl Mar 26, 2025
12b0840
wip
auguste-probabl Mar 26, 2025
b917d6f
wip
auguste-probabl Mar 26, 2025
c9f4282
wip
auguste-probabl Mar 28, 2025
2ec229b
fix mypy, remove X and y from args
auguste-probabl Apr 1, 2025
4242a25
add test
auguste-probabl Apr 1, 2025
48e374a
fix imports in doctests
auguste-probabl Apr 1, 2025
175d1fa
turn off progress bar for internal CVReports
auguste-probabl Apr 1, 2025
d796300
fix doctests
auguste-probabl Apr 1, 2025
651898d
fix doctests
auguste-probabl Apr 2, 2025
5f42cfb
remove whitespace
auguste-probabl Apr 2, 2025
a1de5d0
remove X and y from metrics
auguste-probabl Apr 3, 2025
e6367c3
make default aggregate param ('mean', 'std')
auguste-probabl Apr 3, 2025
adf1461
fix doctest
auguste-probabl Apr 3, 2025
0f41c63
clean
auguste-probabl Apr 3, 2025
f26c7da
Add aggregate param to all metrics
auguste-probabl Apr 3, 2025
106a741
Replace models by Dummies in doctests, for faster testing
auguste-probabl Apr 3, 2025
08c1546
Factorize report in fixtures
auguste-probabl Apr 3, 2025
0b589a7
replace "random_state" with "seed"
auguste-probabl Apr 3, 2025
9653030
add timings
auguste-probabl Apr 3, 2025
f6798b8
Fix doctest
auguste-probabl Apr 4, 2025
4f04a13
Deduplicate report names
auguste-probabl Apr 4, 2025
aad6e12
Remove all doctests results in metrics_accessors
auguste-probabl Apr 4, 2025
59062ad
stop using `m1`, `m2`... as model names
auguste-probabl Apr 4, 2025
0fd33e3
rename
auguste-probabl Apr 4, 2025
f18e784
add `class_sep` in tests
auguste-probabl Apr 4, 2025
6f1d743
add test for regression
auguste-probabl Apr 4, 2025
ca89bc2
Add test
auguste-probabl Apr 4, 2025
135fae9
Add test
auguste-probabl Apr 4, 2025
ef74c24
make favorability work
auguste-probabl Apr 4, 2025
10a4477
replace custom 'timings' with report_metrics
auguste-probabl Apr 4, 2025
4a86bb4
add tests for each metric
auguste-probabl Apr 4, 2025
d23af94
move and rename tests
auguste-probabl Apr 4, 2025
b8fa36d
add docstrings
auguste-probabl Apr 4, 2025
dff1404
fix `timings` docstring
auguste-probabl Apr 4, 2025
327b1be
add test for `custom_metric`
auguste-probabl Apr 4, 2025
df819ef
test report itself
auguste-probabl Apr 4, 2025
1506d42
Remove plots implementation
auguste-probabl Apr 7, 2025
a04de43
test: fix tests
auguste-probabl Apr 7, 2025
df07fd7
avoid warnings about precision by changing DummyClassifier strategy
auguste-probabl Apr 7, 2025
18d33b3
simplify tests
auguste-probabl Apr 7, 2025
e50cac8
Back out "replace custom 'timings' with report_metrics"
auguste-probabl Apr 7, 2025
3d773e3
distinguish timing tests
auguste-probabl Apr 7, 2025
3c30987
Fix progress bar description
auguste-probabl Apr 7, 2025
878a149
ComparisonReport: Rename estimator_reports_ to reports_
auguste-probabl Apr 9, 2025
c5f9145
put validation logic in a static method
auguste-probabl Apr 9, 2025
28d3aaf
Put static methods before init
auguste-probabl Apr 9, 2025
13a3bef
ComparisonReport: put validation logic in static method
auguste-probabl Apr 9, 2025
e856fb8
make nested progress bar work
auguste-probabl Apr 9, 2025
3a018b9
Make it impossible to pass the same CVReport several times
auguste-probabl Apr 9, 2025
a179284
fix doctest
auguste-probabl Apr 9, 2025
5cf0c3b
Fix import
auguste-probabl Apr 10, 2025
3a56636
fix doctest
auguste-probabl Apr 9, 2025
e8e399e
move cross validation validation logic to ComparisonReport
auguste-probabl Apr 10, 2025
021f519
rename function
auguste-probabl Apr 10, 2025
c8286f6
pull out logic from _validate functions
auguste-probabl Apr 10, 2025
822d09b
add code path where reports_ are CrossValidationReports
auguste-probabl Apr 10, 2025
79f6571
inline ml task checking
auguste-probabl Apr 10, 2025
bb756d9
inline report_names deduplication
auguste-probabl Apr 10, 2025
458dc72
inline duplicates check
auguste-probabl Apr 10, 2025
c8a231e
inline _validate_* completely
auguste-probabl Apr 10, 2025
8272d23
encapsulate whole validation logic into static method
auguste-probabl Apr 10, 2025
b3e7fff
remove comments
auguste-probabl Apr 10, 2025
b6b6c8b
Add aggregate argument to ComparisonReport._MetricsAccessor
auguste-probabl Apr 10, 2025
aa2cc2e
Rename variable
auguste-probabl Apr 10, 2025
a48f4fc
Sync all Progresses
auguste-probabl Apr 10, 2025
ffbbff0
refactor results post-processing to static method
auguste-probabl Apr 10, 2025
0b811da
Copy-paste _combine_cross_validation_results
auguste-probabl Apr 10, 2025
39f18d4
introduce _reports_type private attribute
auguste-probabl Apr 10, 2025
7ea3276
combine results differently depending on compared reports type
auguste-probabl Apr 10, 2025
679e6f6
pass different kwargs depending on compared reports type
auguste-probabl Apr 10, 2025
ad5eff4
compute timings differently depending on compared reports type
auguste-probabl Apr 10, 2025
93fba0c
replace CrossValidationComparisonReport with ComparisonReport
auguste-probabl Apr 10, 2025
de94b97
rename dir
auguste-probabl Apr 10, 2025
6b8d537
move ComparisonReport of EstimatorReports tests to own directory
auguste-probabl Apr 10, 2025
c08d278
raise NotImplementedError for ComparisonReport[CVReport]
auguste-probabl Apr 10, 2025
c00bfda
showcase ComparisonReport[CrossValidation] in doctest
auguste-probabl Apr 10, 2025
635332b
remove CrossValidationComparison
auguste-probabl Apr 10, 2025
e5726aa
move _reports_type initialization to _validate_reports
auguste-probabl Apr 10, 2025
cb73925
fmt
auguste-probabl Apr 11, 2025
17d2357
mypy
auguste-probabl Apr 11, 2025
c3cb78b
refactor tests
auguste-probabl Apr 11, 2025
ce985e9
refactor tests
auguste-probabl Apr 11, 2025
560f738
properly catch passing `X_y` to ComparisonReport[CVReport].get_predic…
auguste-probabl Apr 11, 2025
14aa994
properly catch passing `X_y` to ComparisonReport[CVReport].metrics
auguste-probabl Apr 11, 2025
f211e50
fix
auguste-probabl Apr 11, 2025
c904f97
make sure that passing `aggregate` to ComparisonReport[EReport] works
auguste-probabl Apr 11, 2025
1338319
move reports_type initialization inside if block
auguste-probabl Apr 11, 2025
28835cb
remove dead code
auguste-probabl Apr 11, 2025
6865c74
increase coverage
auguste-probabl Apr 11, 2025
64286e9
remove useless type check
auguste-probabl Apr 11, 2025
df5917f
test that aggregate is used in cache
auguste-probabl Apr 11, 2025
af5ed29
make doctest more copy-pasteable
auguste-probabl Apr 11, 2025
d81c8a4
add reproducer for pickling issue
auguste-probabl Apr 14, 2025
573f247
add non-regression for pickling issue for cvreports just to be sure
auguste-probabl Apr 14, 2025
6e624fe
fix bug
auguste-probabl Apr 14, 2025
5b37f44
change the format of report_metrics to be more legible
auguste-probabl Apr 11, 2025
fe6b665
Update skore/src/skore/sklearn/_comparison/report.py
auguste-probabl Apr 18, 2025
8c82921
fix
auguste-probabl Apr 18, 2025
88a29e4
update docstring
auguste-probabl Apr 18, 2025
d23df29
update docstrings
auguste-probabl Apr 18, 2025
c234727
change impl of deduplication
auguste-probabl Apr 18, 2025
ce2d793
add docstring
auguste-probabl Apr 18, 2025
1ffebf7
update error message
auguste-probabl Apr 18, 2025
0c9dbd6
change error message
auguste-probabl Apr 18, 2025
bbe115f
add fixme comment
auguste-probabl Apr 18, 2025
d40adaf
test joblib roundtrip
auguste-probabl Apr 18, 2025
a33bb72
add failing test for checking that `aggregate` is used
auguste-probabl Apr 18, 2025
46eca9a
fix: pass `aggregate` down
auguste-probabl Apr 18, 2025
d7c5d72
refine docstrings
auguste-probabl Apr 18, 2025
a680c93
make all setups as identical as possible
auguste-probabl Apr 18, 2025
2aace72
create report fixture
auguste-probabl Apr 18, 2025
1e8a0dd
create report_classification fixture
auguste-probabl Apr 18, 2025
e53e0db
create report_regression fixture
auguste-probabl Apr 18, 2025
0ea12c6
create estimator_reports fixture
auguste-probabl Apr 18, 2025
abbf420
remove copy()
auguste-probabl Apr 18, 2025
3594b04
remove copy
auguste-probabl Apr 18, 2025
887ebbf
reshape metrics df when aggregate is not None
auguste-probabl Apr 18, 2025
ecc77d8
Update skore/src/skore/sklearn/_comparison/metrics_accessor.py
auguste-probabl Apr 22, 2025
683ffa9
merge main
auguste-probabl Apr 22, 2025
df117f1
Add X to ComparisonReport.get_predictions
auguste-probabl Apr 22, 2025
4ab4b3a
correct test_cache
auguste-probabl Apr 22, 2025
fc280f7
remove useless variables
auguste-probabl Apr 22, 2025
594bdc7
Move staticmethod to function
auguste-probabl Apr 22, 2025
11854fb
Stop casting dict keys to strings, raise instead
auguste-probabl Apr 22, 2025
7de215d
move _combine_* functions to their own utils module
auguste-probabl Apr 22, 2025
610c4dc
lint
auguste-probabl Apr 22, 2025
8232d08
fix sphinx reference
auguste-probabl Apr 22, 2025
6d10e64
complete docstring
auguste-probabl Apr 22, 2025
9d9b906
add docstring
auguste-probabl Apr 22, 2025
8194931
Add examples to combine_* functions
auguste-probabl Apr 22, 2025
70c0683
add docstring
auguste-probabl Apr 22, 2025
59ecfac
remove comments
auguste-probabl Apr 22, 2025
30ca76a
Rename argument
auguste-probabl Apr 23, 2025
08c77c0
rename function
auguste-probabl Apr 23, 2025
0270707
Remove intermediate variables
auguste-probabl Apr 23, 2025
617b680
apply suggestion
auguste-probabl Apr 23, 2025
8553574
remove copy
auguste-probabl Apr 23, 2025
0a2259f
refactor logic to function
auguste-probabl Apr 23, 2025
98f755b
Use CategoricalDType to get metric order
auguste-probabl Apr 23, 2025
3d8b772
lint
auguste-probabl Apr 23, 2025
9ce3298
Make ComparisonReport work for X_y
auguste-probabl Apr 23, 2025
4f8e229
refactor
auguste-probabl Apr 23, 2025
8532e06
move test utils to their own module
auguste-probabl Apr 22, 2025
aee5fd9
fix test_cache
auguste-probabl Apr 22, 2025
c24d953
fix mypy
auguste-probabl Apr 23, 2025
56b81fc
refactor common tests
auguste-probabl Apr 23, 2025
b58e2a8
move cross-validation tests
auguste-probabl Apr 23, 2025
5ad3ec6
fix typing for python 3.9
auguste-probabl Apr 23, 2025
8170b02
merge main
auguste-probabl Apr 29, 2025
19b6926
reapply diff from main
auguste-probabl Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 200 additions & 75 deletions skore/src/skore/sklearn/_comparison/metrics_accessor.py

Large diffs are not rendered by default.

258 changes: 200 additions & 58 deletions skore/src/skore/sklearn/_comparison/report.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from __future__ import annotations

import time
from collections import Counter
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

import joblib
import numpy as np
from numpy.typing import ArrayLike

from skore.externals._pandas_accessors import DirNamesMixin
from skore.sklearn._base import _BaseReport
from skore.sklearn._cross_validation.report import CrossValidationReport
from skore.sklearn._estimator.report import EstimatorReport
from skore.utils._progress_bar import progress_decorator

if TYPE_CHECKING:
from skore.sklearn._estimator.metrics_accessor import _MetricsAccessor

ReportType = Literal["EstimatorReport", "CrossValidationReport"]


class ComparisonReport(_BaseReport, DirNamesMixin):
"""Report for comparison of instances of :class:`skore.EstimatorReport`.
"""Report for comparing reports.

This object can be used to compare several :class:`skore.EstimatorReport`s, or
several :class:`~skore.CrossValidationReport`s.

.. caution:: Reports passed to `ComparisonReport` are not copied. If you pass
a report to `ComparisonReport`, and then modify the report outside later, it
Expand All @@ -28,13 +35,9 @@ class ComparisonReport(_BaseReport, DirNamesMixin):

Parameters
----------
reports : list of :class:`~skore.EstimatorReport` instances or dict
Estimator reports to compare.

* If `reports` is a list, the class name of each estimator is used.
* If `reports` is a dict, it is expected to have estimator names as keys
and :class:`~skore.EstimatorReport` instances as values.
If the keys are not strings, they will be converted to strings.
reports : list of reports or dict
Reports to compare. If a dict, keys will be used to label the estimators;
if a list, the labels are computed from the estimator class names.

n_jobs : int, default=None
Number of jobs to run in parallel. Training the estimators and computing
Expand All @@ -46,11 +49,14 @@ class ComparisonReport(_BaseReport, DirNamesMixin):

Attributes
----------
estimator_reports_ : list of :class:`~skore.EstimatorReport`
The compared estimator reports.
reports_ : list of :class:`~skore.EstimatorReport` or list of
:class:`~skore.CrossValidationReport`
The compared reports.

report_names_ : list of str
The names of the compared estimator reports.
The names of the compared estimators. If the names are not customized (i.e. the
class names are used), a de-duplication process is used to make sure that the
names are distinct.

See Also
--------
Expand Down Expand Up @@ -85,80 +91,168 @@ class ComparisonReport(_BaseReport, DirNamesMixin):
... y_test=y_test
... )
>>> report = ComparisonReport([estimator_report_1, estimator_report_2])
>>> report.report_names_
['LogisticRegression_1', 'LogisticRegression_2']
>>> report = ComparisonReport(
... {"model1": estimator_report_1, "model2": estimator_report_2}
... )
>>> report.report_names_
['model1', 'model2']

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import ComparisonReport, CrossValidationReport
>>> X, y = make_classification(random_state=42)
>>> estimator_1 = LogisticRegression()
>>> estimator_2 = LogisticRegression(C=2) # Different regularization
>>> report_1 = CrossValidationReport(estimator_1, X, y)
>>> report_2 = CrossValidationReport(estimator_2, X, y)
>>> report = ComparisonReport([report_1, report_2])
>>> report = ComparisonReport({"model1": report_1, "model2": report_2})
Comment on lines +101 to +111
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only docstring where we showcase ComparisonReport[CVReport]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As docstring it is fine. I will check in which example it is natural to use it.

"""

_ACCESSOR_CONFIG: dict[str, dict[str, str]] = {
"metrics": {"name": "metrics"},
}
metrics: _MetricsAccessor

def __init__(
self,
reports: Union[list[EstimatorReport], dict[str, EstimatorReport]],
*,
n_jobs: Optional[int] = None,
) -> None:
"""
ComparisonReport instance initializer.
_reports_type: ReportType

@staticmethod
def _validate_reports(
reports: Union[
list[EstimatorReport],
dict[str, EstimatorReport],
list[CrossValidationReport],
dict[str, CrossValidationReport],
],
) -> tuple[
Union[list[EstimatorReport], list[CrossValidationReport]],
list[str],
ReportType,
]:
"""Validate that reports are in the right format for comparison.

Notes
-----
We check that the estimator reports can be compared:
- all reports are estimator reports,
- all estimators are in the same ML use case,
- all estimators have non-empty X_test and y_test,
- all estimators have the same X_test and y_test.
Parameters
----------
reports : list of reports or dict
The reports to be validated.

Returns
-------
list of EstimatorReport or list of CrossValidationReport
The validated reports.
list of str
The report names, either taken from dict keys or computed from the estimator
class names.
{"EstimatorReport", "CrossValidationReport"}
The inferred type of the reports that will be compared.
"""
if not isinstance(reports, Iterable):
raise TypeError(f"Expected reports to be an iterable; got {type(reports)}")
raise TypeError(
f"Expected reports to be a list or dict; got {type(reports)}"
)

if len(reports) < 2:
raise ValueError("At least 2 instances of EstimatorReport are needed")
raise ValueError(
f"Expected at least 2 reports to compare; got {len(reports)}"
)

report_names = (
list(map(str, reports.keys())) if isinstance(reports, dict) else None
)
reports = list(reports.values()) if isinstance(reports, dict) else reports
if isinstance(reports, list):
report_names = None
reports_list = reports
else: # dict
report_names = list(reports.keys())
for key in report_names:
if not isinstance(key, str):
raise TypeError(
f"Expected all report names to be strings; got {type(key)}"
)
reports_list = cast(
Union[list[EstimatorReport], list[CrossValidationReport]],
list(reports.values()),
)

if not all(isinstance(report, EstimatorReport) for report in reports):
raise TypeError("Expected instances of EstimatorReport")
reports_type: ReportType
if all(isinstance(report, EstimatorReport) for report in reports_list):
reports_list = cast(list[EstimatorReport], reports_list)
reports_type = "EstimatorReport"

# FIXME: We should only check y_test since it is all we need to tell us
# whether we have a distinct ML task at hand.
test_dataset_hashes = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a

FIXME: We should only check y_test, since it is the variable that tells us whether
we have a distinct ML task at hand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, this constraint of "same test data" is not completely necessary to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to just do the change?

Suggested change
test_dataset_hashes = {
test_dataset_hashes = {
joblib.hash(report.y_test) for report in reports_list
if report.y_test is not None
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly. However, I would potentially make it in a subsequent PR for sure.

joblib.hash((report.X_test, report.y_test))
for report in reports_list
if not ((report.X_test is None) and (report.y_test is None))
}
if len(test_dataset_hashes) > 1:
raise ValueError(
"Expected all estimators to have the same testing data."
)

elif all(isinstance(report, CrossValidationReport) for report in reports_list):
reports_list = cast(list[CrossValidationReport], reports_list)
reports_type = "CrossValidationReport"
else:
raise TypeError(
f"Expected list or dict of {EstimatorReport.__name__} "
f"or list or dict of {CrossValidationReport.__name__}"
)

test_dataset_hashes = {
joblib.hash((report.X_test, report.y_test))
for report in reports
if not ((report.X_test is None) and (report.y_test is None))
}
if len(test_dataset_hashes) > 1:
raise ValueError("Expected all estimators to have the same testing data.")
if len(set(id(report) for report in reports_list)) < len(reports_list):
raise ValueError("Expected reports to be distinct objects")
Comment on lines +202 to +203
Copy link
Contributor Author

@auguste-probabl auguste-probabl Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new constraint results from #1536; indeed, if the same report is passed twice, when we reset the progress bar at the end of the first CVReport computation (see below), when the second CVReport computation starts, the progress object is set to None, whereas it should be set to the ComparisonReport's progress object.

self_obj._parent_progress = None
self_obj._progress_info = None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that right now this constraint is only really useful when comparing CVReports, not when comparing EstimatorReports. I didn't put it in the if-block because I can imagine that one day EstimatorReports might have more progress bars of their own.


ml_tasks = {report: report._ml_task for report in reports}
ml_tasks = {report: report._ml_task for report in reports_list}
if len(set(ml_tasks.values())) > 1:
raise ValueError(
f"Expected all estimators to have the same ML usecase; got {ml_tasks}"
)

if report_names is None:
self.report_names_ = [report.estimator_name_ for report in reports]
deduped_report_names = _deduplicate_report_names(
[report.estimator_name_ for report in reports_list]
)
else:
self.report_names_ = report_names
deduped_report_names = report_names

return reports_list, deduped_report_names, reports_type

self.estimator_reports_ = reports
def __init__(
self,
reports: Union[
list[EstimatorReport],
dict[str, EstimatorReport],
list[CrossValidationReport],
dict[str, CrossValidationReport],
],
*,
n_jobs: Optional[int] = None,
) -> None:
"""
ComparisonReport instance initializer.

Notes
-----
We check that the estimator reports can be compared:
- all reports are estimator reports,
- all estimators are in the same ML use case,
- all estimators have non-empty X_test and y_test,
- all estimators have the same X_test and y_test.
"""
self.reports_, self.report_names_, self._reports_type = (
ComparisonReport._validate_reports(reports)
)

# used to know if a parent launches a progress bar manager
self._progress_info: Optional[dict[str, Any]] = None
self._parent_progress = None

# NEEDED FOR METRICS ACCESSOR
self.n_jobs = n_jobs
self._rng = np.random.default_rng(time.time_ns())
self._hash = self._rng.integers(
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
)
self._cache: dict[tuple[Any, ...], Any] = {}
self._ml_task = self.estimator_reports_[0]._ml_task
self._ml_task = self.reports_[0]._ml_task

def clear_cache(self) -> None:
"""Clear the cache.
Expand Down Expand Up @@ -193,7 +287,7 @@ def clear_cache(self) -> None:
>>> report._cache
{}
"""
for report in self.estimator_reports_:
for report in self.reports_:
report.clear_cache()
self._cache = {}

Expand Down Expand Up @@ -222,7 +316,7 @@ def cache_predictions(
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import ComparisonReport
>>> from skore import ComparisonReport, EstimatorReport
>>> X, y = make_classification(random_state=42)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
>>> estimator_1 = LogisticRegression()
Expand Down Expand Up @@ -255,22 +349,21 @@ def cache_predictions(
progress = self._progress_info["current_progress"]
main_task = self._progress_info["current_task"]

total_estimators = len(self.estimator_reports_)
total_estimators = len(self.reports_)
progress.update(main_task, total=total_estimators)

for estimator_report in self.estimator_reports_:
for report in self.reports_:
# Pass the progress manager to child tasks
estimator_report._parent_progress = progress
estimator_report.cache_predictions(
response_methods=response_methods, n_jobs=n_jobs
)
report._parent_progress = progress
report.cache_predictions(response_methods=response_methods, n_jobs=n_jobs)
progress.update(main_task, advance=1, refresh=True)

def get_predictions(
self,
*,
data_source: Literal["train", "test", "X_y"],
response_method: Literal["predict", "predict_proba", "decision_function"],
X: Optional[ArrayLike] = None,
pos_label: Optional[Any] = None,
) -> ArrayLike:
"""Get estimator's predictions.
Expand All @@ -290,6 +383,10 @@ def get_predictions(
response_method : {"predict", "predict_proba", "decision_function"}
The response method to use.

X : array-like of shape (n_samples, n_features), optional
When `data_source` is "X_y", the input features on which to compute the
response method.

pos_label : int, float, bool or str, default=None
The positive class when it comes to binary classification. When
`response_method="predict_proba"`, it will select the column corresponding
Expand Down Expand Up @@ -343,9 +440,10 @@ def get_predictions(
report.get_predictions(
data_source=data_source,
response_method=response_method,
X=X,
pos_label=pos_label,
)
for report in self.estimator_reports_
for report in self.reports_
]

####################################################################################
Expand All @@ -363,3 +461,47 @@ def _get_help_legend(self) -> str:
def __repr__(self) -> str:
"""Return a string representation."""
return f"{self.__class__.__name__}(...)"


def _deduplicate_report_names(report_names: list[str]) -> list[str]:
"""De-duplicate report names that appear several times.

Leave the other report names alone.

Parameters
----------
report_names : list of str
The list of report names to be checked.

Returns
-------
list of str
The de-duplicated list of report names.

Examples
--------
>>> _deduplicate_report_names(['a', 'b'])
['a', 'b']
>>> _deduplicate_report_names(['a', 'a'])
['a_1', 'a_2']
>>> _deduplicate_report_names(['a', 'b', 'a'])
['a_1', 'b', 'a_2']
>>> _deduplicate_report_names(['a', 'b', 'a', 'b'])
['a_1', 'b_1', 'a_2', 'b_2']
>>> _deduplicate_report_names([])
[]
>>> _deduplicate_report_names(['a'])
['a']
"""
counts = Counter(report_names)
if len(report_names) == len(counts):
return report_names

names = report_names.copy()
seen: Counter = Counter()
for i in range(len(names)):
name = names[i]
seen[name] += 1
if counts[name] > 1:
names[i] = f"{name}_{seen[name]}"
return names
Loading