Skip to content

Commit 3bb2a29

Browse files
move _reports_type initialization to _validate_reports
1 parent 1c1f614 commit 3bb2a29

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

skore/src/skore/sklearn/_comparison/report.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,17 @@ def _validate_reports(
202202
else:
203203
deduped_report_names = report_names
204204

205-
return reports_list, deduped_report_names
205+
if isinstance(reports_list[0], CrossValidationReport):
206+
reports_type = "CrossValidationReport"
207+
elif isinstance(reports_list[0], EstimatorReport):
208+
reports_type = "EstimatorReport"
209+
else:
210+
raise TypeError(
211+
"Report type is undetermined. "
212+
"This error should have been caught during validation."
213+
)
214+
215+
return reports_list, deduped_report_names, reports_type
206216

207217
def __init__(
208218
self,
@@ -226,17 +236,9 @@ def __init__(
226236
- all estimators have non-empty X_test and y_test,
227237
- all estimators have the same X_test and y_test.
228238
"""
229-
self.reports_, self.report_names_ = ComparisonReport._validate_reports(reports)
230-
231-
if isinstance(self.reports_[0], CrossValidationReport):
232-
self._reports_type = "CrossValidationReport"
233-
elif isinstance(self.reports_[0], EstimatorReport):
234-
self._reports_type = "EstimatorReport"
235-
else:
236-
raise TypeError(
237-
"Report type is undetermined. "
238-
"This error should have been caught during validation."
239-
)
239+
self.reports_, self.report_names_, self._reports_type = (
240+
ComparisonReport._validate_reports(reports)
241+
)
240242

241243
self._progress_info: Optional[dict[str, Any]] = None
242244
self._parent_progress = None

0 commit comments

Comments
 (0)