Skip to content

[WIP] Fix empty partition prediction with ParallelPostFit #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 88 additions & 29 deletions dask_ml/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,23 +231,22 @@ def transform(self, X):
"""
self._check_method("transform")
X = self._check_array(X)
meta = self.transform_meta
output_meta = self.transform_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_transform, self._postfit_estimator, X
)
return X.map_blocks(
_transform, estimator=self._postfit_estimator, meta=meta
_transform, estimator=self._postfit_estimator, meta=output_meta
)
elif isinstance(X, dd._Frame):
if meta is None:
# dask-dataframe relies on dd.core.no_default
# for inferring meta
meta = dd.core.no_default
return X.map_partitions(
_transform, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_transform,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)
else:
return _transform(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -311,25 +310,30 @@ def predict(self, X):
"""
self._check_method("predict")
X = self._check_array(X)
meta = self.predict_meta
output_meta = self.predict_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_predict, self._postfit_estimator, X
)

result = X.map_blocks(
_predict, estimator=self._postfit_estimator, drop_axis=1, meta=meta
_predict,
estimator=self._postfit_estimator,
drop_axis=1,
meta=output_meta,
)
return result

elif isinstance(X, dd._Frame):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
_predict, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_predict,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)

else:
return _predict(X, estimator=self._postfit_estimator)

Expand All @@ -355,25 +359,26 @@ def predict_proba(self, X):

self._check_method("predict_proba")

meta = self.predict_proba_meta
output_meta = self.predict_proba_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_predict_proba, self._postfit_estimator, X
)
# XXX: multiclass
return X.map_blocks(
_predict_proba,
estimator=self._postfit_estimator,
meta=meta,
meta=output_meta,
chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),
)
elif isinstance(X, dd._Frame):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
_predict_proba, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_predict_proba,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)
else:
return _predict_proba(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -626,18 +631,63 @@ def _first_block(dask_object):
return dask_object


def _predict(part, estimator):
def _predict(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output
return estimator.predict(part)


def _predict_proba(part, estimator):
def _predict_proba(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output

return estimator.predict_proba(part)


def _transform(part, estimator):
def _transform(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output

return estimator.transform(part)


def handle_empty_partitions(output_meta):
    """Build a zero-row result matching ``output_meta``'s type/dtype/width.

    Parameters
    ----------
    output_meta : array-like, scipy/cupy sparse matrix, or DataFrame-like
        A (possibly non-empty) sample of the expected output.

    Returns
    -------
    An empty (zero rows) object of the same kind as ``output_meta``, or
    ``None`` when the meta type is unrecognized, letting the caller fall
    back to running the estimator on the empty partition.
    """
    if hasattr(output_meta, "__array_function__"):
        # NumPy-like arrays (including cupy): ``like=`` preserves the
        # array backend (requires NumPy >= 1.20).
        if len(output_meta.shape) == 1:
            shape = 0
        else:
            shape = list(output_meta.shape)
            shape[0] = 0
        return np.zeros(shape=shape, dtype=output_meta.dtype, like=output_meta)
    elif "scipy.sparse" in type(output_meta).__module__:
        # Sparse matrices don't support ``like`` due to a non-implemented
        # __array_function__; see https://github.com/scipy/scipy/issues/10362
        # Works for both cupy and scipy sparse matrices.
        # NOTE: the shape must be a tuple — passing a list would be
        # interpreted as dense data by the sparse constructor.
        shape = list(output_meta.shape)
        shape[0] = 0
        return type(output_meta)(tuple(shape), dtype=output_meta.dtype)
    elif hasattr(output_meta, "iloc"):
        # pandas-like DataFrame: slice off all rows, keep columns/dtypes.
        return output_meta.iloc[:0, :]
    return None


def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
"""
Returns the output metadata array
Expand Down Expand Up @@ -692,3 +742,12 @@ def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
warnings.warn(msg)
ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)
return model_fn(ar, estimator)


def _get_output_df_for_estimator(model_fn, X, output_meta, estimator):
if output_meta is None:
# dask-dataframe relies on dd.core.no_default
# for infering meta
output_meta = model_fn(X._meta_nonempty, estimator)

return X.map_partitions(model_fn, estimator, output_meta, meta=output_meta)
21 changes: 19 additions & 2 deletions tests/test_parallel_post_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def test_predict_meta_override():
# Failure when not providing predict_meta
# because of value dependent model
wrap = ParallelPostFit(base)
with pytest.raises(ValueError):
# TODO: Fix
with pytest.raises(IndexError):
wrap.predict(dd_X)

# Success when providing meta over-ride
Expand All @@ -89,7 +90,8 @@ def test_predict_proba_meta_override():
# Failure when not providing predict_proba_meta
# because of value dependent model
wrap = ParallelPostFit(base)
with pytest.raises(ValueError):
# TODO: Fix below
with pytest.raises(IndexError):
wrap.predict_proba(dd_X)

# Success when providing meta over-ride
Expand Down Expand Up @@ -289,3 +291,18 @@ def shape(self):
match="provide explicit `predict_proba_meta` to the dask_ml.wrapper",
):
clf.predict_proba(fake_dask_ar)


def test_predict_empty_partitions():
    # Filtering can leave some dask partitions with zero rows; predicting
    # on such a frame must not raise, and must match a direct prediction
    # on the materialized (computed) data.
    frame = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [True, False] * 4})
    dframe = dd.from_pandas(frame, npartitions=4)

    wrapper = ParallelPostFit(LogisticRegression())
    wrapper = wrapper.fit(frame[["x"]], frame["y"])

    filtered = dframe[dframe.x < 5][["x"]]
    result = wrapper.predict(filtered).compute()

    expected = wrapper.estimator.predict(filtered.compute())

    assert_eq_ar(result, expected)