Skip to content

Commit 3a97df2

Browse files
committed
chore: More readable version of iterating
1 parent 4d28474 commit 3a97df2

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

skore/src/skore/sklearn/_estimator/report.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,22 +158,20 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
158158
"""
159159
if self._ml_task in ("binary-classification", "multiclass-classification"):
160160
if response_methods == "auto":
161-
response_methods = ("predict",)
161+
response_methods = ["predict"]
162162
if hasattr(self._estimator, "predict_proba"):
163-
response_methods += ("predict_proba",)
163+
response_methods += ["predict_proba"]
164164
if hasattr(self._estimator, "decision_function"):
165-
response_methods += ("decision_function",)
165+
response_methods += ["decision_function"]
166166
pos_labels = self._estimator.classes_
167167
else:
168168
if response_methods == "auto":
169-
response_methods = ("predict",)
169+
response_methods = ["predict"]
170170
pos_labels = [None]
171171

172-
data_sources = ("test",)
173-
Xs = (self._X_test,)
172+
data_sources = [("test", self._X_test)]
174173
if self._X_train is not None:
175-
data_sources += ("train",)
176-
Xs += (self._X_train,)
174+
data_sources += [("train", self._X_train)]
177175

178176
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")
179177
generator = parallel(
@@ -187,7 +185,7 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
187185
data_source=data_source,
188186
)
189187
for response_method, pos_label, (data_source, X) in product(
190-
response_methods, pos_labels, zip(data_sources, Xs)
188+
response_methods, pos_labels, data_sources
191189
)
192190
)
193191
# trigger the computation

0 commit comments

Comments
 (0)