@@ -158,22 +158,20 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
158158 """
159159 if self ._ml_task in ("binary-classification" , "multiclass-classification" ):
160160 if response_methods == "auto" :
161- response_methods = ( "predict" ,)
161+ response_methods = [ "predict" ]
162162 if hasattr (self ._estimator , "predict_proba" ):
163- response_methods += ( "predict_proba" ,)
163+ response_methods += [ "predict_proba" ]
164164 if hasattr (self ._estimator , "decision_function" ):
165- response_methods += ( "decision_function" ,)
165+ response_methods += [ "decision_function" ]
166166 pos_labels = self ._estimator .classes_
167167 else :
168168 if response_methods == "auto" :
169- response_methods = ( "predict" ,)
169+ response_methods = [ "predict" ]
170170 pos_labels = [None ]
171171
172- data_sources = ("test" ,)
173- Xs = (self ._X_test ,)
172+ data_sources = [("test" , self ._X_test )]
174173 if self ._X_train is not None :
175- data_sources += ("train" ,)
176- Xs += (self ._X_train ,)
174+ data_sources += [("train" , self ._X_train )]
177175
178176 parallel = joblib .Parallel (n_jobs = n_jobs , return_as = "generator_unordered" )
179177 generator = parallel (
@@ -187,7 +185,7 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
187185 data_source = data_source ,
188186 )
189187 for response_method , pos_label , (data_source , X ) in product (
190- response_methods , pos_labels , zip ( data_sources , Xs )
188+ response_methods , pos_labels , data_sources
191189 )
192190 )
193191 # trigger the computation
0 commit comments