@@ -202,7 +202,17 @@ def _validate_reports(
202202 else :
203203 deduped_report_names = report_names
204204
205- return reports_list , deduped_report_names
205+ if isinstance (reports_list [0 ], CrossValidationReport ):
206+ reports_type = "CrossValidationReport"
207+ elif isinstance (reports_list [0 ], EstimatorReport ):
208+ reports_type = "EstimatorReport"
209+ else :
210+ raise TypeError (
211+ "Report type is undetermined. "
212+ "This error should have been caught during validation."
213+ )
214+
215+ return reports_list , deduped_report_names , reports_type
206216
207217 def __init__ (
208218 self ,
@@ -226,17 +236,9 @@ def __init__(
226236 - all estimators have non-empty X_test and y_test,
227237 - all estimators have the same X_test and y_test.
228238 """
229- self .reports_ , self .report_names_ = ComparisonReport ._validate_reports (reports )
230-
231- if isinstance (self .reports_ [0 ], CrossValidationReport ):
232- self ._reports_type = "CrossValidationReport"
233- elif isinstance (self .reports_ [0 ], EstimatorReport ):
234- self ._reports_type = "EstimatorReport"
235- else :
236- raise TypeError (
237- "Report type is undetermined. "
238- "This error should have been caught during validation."
239- )
239+ self .reports_ , self .report_names_ , self ._reports_type = (
240+ ComparisonReport ._validate_reports (reports )
241+ )
240242
241243 self ._progress_info : Optional [dict [str , Any ]] = None
242244 self ._parent_progress = None
0 commit comments