Skip to content

Commit 7cbf6ce

Browse files
Pranav Choudhary
authored and committed
[ENH] Add return_raw parameter to Benchmarking.run() for sktime Evaluator compatibility (#125)
1 parent d2b2777 commit 7cbf6ce

1 file changed

Lines changed: 85 additions & 17 deletions

File tree

pyaptamer/benchmarking/_base.py

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Benchmarking:
3636
Attributes
3737
----------
3838
results : pd.DataFrame
39-
DataFrame produced by :meth:`run`.
39+
Summary DataFrame produced by :meth:`run`.
4040
4141
- Index: pandas.MultiIndex with two levels (names shown in parentheses)
4242
- level 0 "estimator": estimator name
@@ -46,6 +46,15 @@ class Benchmarking:
4646
- "train" = mean of cross_validate(...)[f"train_{metric}"]
4747
- "test" = mean of cross_validate(...)[f"test_{metric}"]
4848
49+
raw_results_ : pd.DataFrame or None
50+
Per-fold scores, populated after every :meth:`run` call.
51+
52+
- Index: pandas.MultiIndex with three levels
53+
- level 0 "estimator": estimator name
54+
- level 1 "metric": evaluator name
55+
- level 2 "fold": fold index (0-based)
56+
- Columns: ["train", "test"] (both floats)
57+
4958
Example
5059
-------
5160
>>> import numpy as np
@@ -73,13 +82,15 @@ class Benchmarking:
7382
>>> summary = bench.run() # doctest: +SKIP
7483
"""
7584

76-
def __init__(self, estimators, metrics, X, y, cv=None):
85+
def __init__(self, estimators, metrics, X, y, cv=None, labels=None):
7786
self.estimators = estimators if isinstance(estimators, list) else [estimators]
7887
self.metrics = metrics if isinstance(metrics, list) else [metrics]
7988
self.X = X
8089
self.y = y
8190
self.cv = cv
91+
self.labels = labels
8292
self.results = None
93+
self.raw_results_ = None
8394

8495
def _to_scorers(self, metrics):
8596
"""Convert metric callables to a dict of scorers."""
@@ -96,7 +107,7 @@ def _to_scorers(self, metrics):
96107
return scorers
97108

98109
def _to_df(self, results):
99-
"""Convert nested results to a unified DataFrame."""
110+
"""Convert nested mean results to a summary DataFrame."""
100111
records = []
101112
index = []
102113

@@ -108,28 +119,74 @@ def _to_df(self, results):
108119
index = pd.MultiIndex.from_tuples(index, names=["estimator", "metric"])
109120
return pd.DataFrame(records, index=index, columns=["train", "test"])
110121

111-
def run(self):
122+
def _to_raw_df(self, raw_results):
123+
"""Convert nested per-fold results to a raw DataFrame."""
124+
records = []
125+
index = []
126+
127+
for est_name, est_scores in raw_results.items():
128+
for metric_name, fold_scores in est_scores.items():
129+
for fold_idx, (train_score, test_score) in enumerate(
130+
zip(fold_scores["train"], fold_scores["test"])
131+
):
132+
records.append({"train": train_score, "test": test_score})
133+
index.append((est_name, metric_name, fold_idx))
134+
135+
index = pd.MultiIndex.from_tuples(
136+
index, names=["estimator", "metric", "fold"]
137+
)
138+
return pd.DataFrame(records, index=index, columns=["train", "test"])
139+
140+
def run(self, return_raw=False):
112141
"""
113142
Train each estimator and evaluate with cross-validation.
114143
144+
Parameters
145+
----------
146+
return_raw : bool, default=False
147+
If `False` (default), returns only the summary DataFrame.
148+
If `True`, also returns `raw_results_` as the second element
149+
of a tuple, containing per-fold scores keyed by
150+
`(estimator, metric, fold)`.
151+
115152
Returns
116153
-------
117154
results : pd.DataFrame
155+
Summary DataFrame with mean scores.
118156
119-
- Index: pandas.MultiIndex with two levels (names shown in parentheses)
120-
- level 0 "estimator": estimator name
121-
- level 1 "metric": evaluator name
122-
- Columns: ["train", "test"] (both floats)
123-
- Cell values: mean scores (float) computed across CV folds:
124-
- "train" = mean of cross_validate(...)[f"train_{metric}"]
125-
- "test" = mean of cross_validate(...)[f"test_{metric}"]
157+
- Index: pandas.MultiIndex `(estimator, metric)`
158+
- Columns: ["train", "test"] (floats)
126159
160+
(results, raw_results) : tuple[pd.DataFrame, pd.DataFrame]
161+
Returned only when `return_raw=True`. `raw_results` has a
162+
three-level MultiIndex `(estimator, metric, fold)` and contains
163+
the raw per-fold scores.
127164
"""
128165
self.scorers_ = self._to_scorers(self.metrics)
129166
results = {}
130-
131-
for estimator in self.estimators:
132-
est_name = estimator.__class__.__name__
167+
raw_results = {}
168+
169+
if self.labels is not None:
170+
if len(self.labels) != len(self.estimators):
171+
raise ValueError("Length of labels must match length of estimators.")
172+
names = self.labels
173+
else:
174+
counts = {}
175+
for est in self.estimators:
176+
name = est.__class__.__name__
177+
counts[name] = counts.get(name, 0) + 1
178+
179+
names = []
180+
seen = {}
181+
for est in self.estimators:
182+
name = est.__class__.__name__
183+
if counts[name] > 1:
184+
seen[name] = seen.get(name, 0) + 1
185+
names.append(f"{name}_{seen[name]}")
186+
else:
187+
names.append(name)
188+
189+
for estimator, est_name in zip(self.estimators, names):
133190

134191
cv_results = cross_validate(
135192
estimator,
@@ -140,15 +197,26 @@ def run(self):
140197
return_train_score=True,
141198
)
142199

143-
# average across folds
144200
est_scores = {}
201+
est_raw_scores = {}
145202
for metric in self.scorers_.keys():
203+
train_folds = cv_results[f"train_{metric}"]
204+
test_folds = cv_results[f"test_{metric}"]
146205
est_scores[metric] = {
147-
"train": float(np.mean(cv_results[f"train_{metric}"])),
148-
"test": float(np.mean(cv_results[f"test_{metric}"])),
206+
"train": float(np.mean(train_folds)),
207+
"test": float(np.mean(test_folds)),
208+
}
209+
est_raw_scores[metric] = {
210+
"train": train_folds.tolist(),
211+
"test": test_folds.tolist(),
149212
}
150213

151214
results[est_name] = est_scores
215+
raw_results[est_name] = est_raw_scores
152216

153217
self.results = self._to_df(results)
218+
self.raw_results_ = self._to_raw_df(raw_results)
219+
220+
if return_raw:
221+
return self.results, self.raw_results_
154222
return self.results

0 commit comments

Comments
 (0)