1
1
import logging
2
2
from typing import TYPE_CHECKING
3
3
4
+ import numpy as np
4
5
from dask import delayed
5
6
6
7
from dask_sql .datacontainer import DataContainer
@@ -43,7 +44,7 @@ class CreateModelPlugin(BaseRelPlugin):
43
44
unsupervised algorithms). This means, you typically
44
45
want to set this parameter.
45
46
* wrap_predict: Boolean flag, whether to wrap the selected
46
- model with a :class:`dask_ml .wrappers.ParallelPostFit`.
47
+ model with a :class:`dask_sql.physical.rel.custom .wrappers.ParallelPostFit`.
47
48
Have a look into the
48
49
[dask-ml docu](https://ml.dask.org/meta-estimators.html#parallel-prediction-and-transformation)
49
50
to learn more about it. Defaults to false. Typically you set
@@ -165,10 +166,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
165
166
model = Incremental (estimator = model )
166
167
167
168
if wrap_predict :
168
- try :
169
- from dask_ml .wrappers import ParallelPostFit
170
- except ImportError : # pragma: no cover
171
- raise ValueError ("Wrapping requires dask-ml to be installed." )
169
+ from dask_sql .physical .rel .custom .wrappers import ParallelPostFit
172
170
173
171
# When `wrap_predict` is set to True we train on single partition frames
174
172
# because this is only useful for non dask distributed models
@@ -183,7 +181,16 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
183
181
184
182
delayed_model = [delayed (model .fit )(x_p , y_p ) for x_p , y_p in zip (X_d , y_d )]
185
183
model = delayed_model [0 ].compute ()
186
- model = ParallelPostFit (estimator = model )
184
+ if "sklearn" in model_class :
185
+ output_meta = np .array ([])
186
+ model = ParallelPostFit (
187
+ estimator = model ,
188
+ predict_meta = output_meta ,
189
+ predict_proba_meta = output_meta ,
190
+ transform_meta = output_meta ,
191
+ )
192
+ else :
193
+ model = ParallelPostFit (estimator = model )
187
194
188
195
else :
189
196
model .fit (X , y , ** fit_kwargs )
0 commit comments