Skip to content

Commit e2fcf12

Browse files
authored
Introduce basic "cudf" backend for Dask Expressions (#14805)
Mostly addresses #15027 dask/dask-expr#728 exposed the necessary mechanisms for us to define a custom dask-expr backend for `cudf`. The new dispatching mechanisms are effectively the same as those in `dask.dataframe`. The only difference is that we are now registering/implementing "expression-based" collections. This PR does the following: - Defines a basic `DataFrameBackendEntrypoint` class for collection creation, and registers new collections using `get_collection_type`. - Refactors the `dask_cudf` import structure to properly support the `"dataframe.query-planning"` configuration. - Modifies CI to test dask-expr support for some of the `dask_cudf` tests. This coverage can be expanded in follow-up work. ~**Experimental Change**: This PR patches `dask_expr._expr.Expr.__new__` to enable type-based dispatching. This effectively allows us to surgically replace problematic `Expr` subclasses that do not work for cudf-backed data. For example, this PR replaces the upstream `TakeLast` expression to avoid using `squeeze` (since this method is not supported by cudf). This particular fix can be moved upstream relatively easily. 
However, having this kind of "patching" mechanism may be valuable for more complicated pandas/cudf discrepancies.~ ## Usage example ```python from dask import config config.set({"dataframe.query-planning": True}) import dask_cudf df = dask_cudf.DataFrame.from_dict( {"x": range(100), "y": [1, 2, 3, 4] * 25, "z": ["1", "2"] * 50}, npartitions=10, ) df["y2"] = df["x"] + df["y"] agg = df.groupby("y").agg({"y2": "mean"})["y2"] agg.simplify().pprint() ``` Dask cuDF should now be using dask-expr for "query planning": ``` Projection: columns='y2' GroupbyAggregation: arg={'y2': 'mean'} observed=True split_out=1'y' Assign: y2= Projection: columns=['y'] FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y'] Add: Projection: columns='x' FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y'] Projection: columns='y' FromPandas: frame='<dataframe>' npartitions=10 columns=['x', 'y'] ``` ## TODO - [x] Add basic tests - [x] Confirm that general design makes sense **Follow Up Work**: - Expand dask-expr test coverage - Fix local and upstream bugs - Add documentation once "critical mass" is reached Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) - Lawrence Mitchell (https://github.com/wence-) - Vyas Ramasubramani (https://github.com/vyasr) - Bradley Dice (https://github.com/bdice) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Ray Douglass (https://github.com/raydouglass) URL: #14805
1 parent 63c9ed7 commit e2fcf12

24 files changed

+545
-123
lines changed

ci/test_python_other.sh

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ rapids-logger "pytest dask_cudf"
2929
--cov-report=xml:"${RAPIDS_COVERAGE_DIR}/dask-cudf-coverage.xml" \
3030
--cov-report=term
3131

32+
# Run tests in dask_cudf/tests and dask_cudf/io/tests with dask-expr
33+
rapids-logger "pytest dask_cudf + dask_expr"
34+
DASK_DATAFRAME__QUERY_PLANNING=True ./ci/run_dask_cudf_pytests.sh \
35+
--junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cudf-expr.xml" \
36+
--numprocesses=8 \
37+
--dist=loadscope \
38+
.
39+
3240
rapids-logger "pytest custreamz"
3341
./ci/run_custreamz_pytests.sh \
3442
--junitxml="${RAPIDS_TESTS_DIR}/junit-custreamz.xml" \

ci/test_wheel_dask_cudf.sh

+9
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@ python -m pytest \
3838
--numprocesses=8 \
3939
.
4040
popd
41+
42+
# Run tests in dask_cudf/tests and dask_cudf/io/tests with dask-expr
43+
rapids-logger "pytest dask_cudf + dask_expr"
44+
pushd python/dask_cudf/dask_cudf
45+
DASK_DATAFRAME__QUERY_PLANNING=True python -m pytest \
46+
--junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cudf-expr.xml" \
47+
--numprocesses=8 \
48+
.
49+
popd
+54-8
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,75 @@
1-
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
22

3+
from dask import config
4+
5+
# For dask>2024.2.0, we can silence the loud deprecation
6+
# warning before importing `dask.dataframe` (this won't
7+
# do anything for dask==2024.2.0)
8+
config.set({"dataframe.query-planning-warning": False})
9+
10+
import dask.dataframe as dd
311
from dask.dataframe import from_delayed
412

513
import cudf
614

715
from . import backends
816
from ._version import __git_commit__, __version__
9-
from .core import DataFrame, Series, concat, from_cudf, from_dask_dataframe
10-
from .groupby import groupby_agg
11-
from .io import read_csv, read_json, read_orc, read_text, to_orc
17+
from .core import concat, from_cudf, from_dask_dataframe
18+
from .expr import QUERY_PLANNING_ON
19+
20+
21+
# Backend override shared by all of the top-level read_* wrappers below.
_CUDF_BACKEND = {"dataframe.backend": "cudf"}


def read_csv(*args, **kwargs):
    """Read CSV data into a cudf-backed Dask collection.

    Thin wrapper around ``dask.dataframe.read_csv`` that forces the
    "cudf" dataframe backend for the duration of the call.
    """
    with config.set(_CUDF_BACKEND):
        return dd.read_csv(*args, **kwargs)


def read_json(*args, **kwargs):
    """Read JSON data into a cudf-backed Dask collection.

    Thin wrapper around ``dask.dataframe.read_json`` with the "cudf"
    backend enabled.
    """
    with config.set(_CUDF_BACKEND):
        return dd.read_json(*args, **kwargs)


def read_orc(*args, **kwargs):
    """Read ORC data into a cudf-backed Dask collection.

    Thin wrapper around ``dask.dataframe.read_orc`` with the "cudf"
    backend enabled.
    """
    with config.set(_CUDF_BACKEND):
        return dd.read_orc(*args, **kwargs)


def read_parquet(*args, **kwargs):
    """Read Parquet data into a cudf-backed Dask collection.

    Thin wrapper around ``dask.dataframe.read_parquet`` with the "cudf"
    backend enabled.
    """
    with config.set(_CUDF_BACKEND):
        return dd.read_parquet(*args, **kwargs)
39+
40+
41+
def raise_not_implemented_error(attr_name):
    """Build a placeholder callable for an unavailable top-level API.

    The returned function accepts any arguments and always raises
    ``NotImplementedError`` naming *attr_name*.
    """
    message = f"Top-level {attr_name} API is not available for dask-expr."

    def _unsupported(*args, **kwargs):
        raise NotImplementedError(message)

    return _unsupported
48+
49+
50+
if QUERY_PLANNING_ON:
51+
from .expr._collection import DataFrame, Index, Series
52+
53+
groupby_agg = raise_not_implemented_error("groupby_agg")
54+
read_text = raise_not_implemented_error("read_text")
55+
to_orc = raise_not_implemented_error("to_orc")
56+
else:
57+
from .core import DataFrame, Index, Series
58+
from .groupby import groupby_agg
59+
from .io import read_text, to_orc
1260

13-
try:
14-
from .io import read_parquet
15-
except ImportError:
16-
pass
1761

1862
__all__ = [
1963
"DataFrame",
2064
"Series",
65+
"Index",
2166
"from_cudf",
2267
"from_dask_dataframe",
2368
"concat",
2469
"from_delayed",
2570
]
2671

72+
2773
if not hasattr(cudf.DataFrame, "mean"):
2874
cudf.DataFrame.mean = None
2975
del cudf

python/dask_cudf/dask_cudf/backends.py

+59-4
Original file line numberDiff line numberDiff line change
@@ -627,13 +627,68 @@ def read_csv(*args, **kwargs):
627627

628628
@staticmethod
def read_hdf(*args, **kwargs):
    """Read HDF5 data via the default (pandas) backend, then convert
    the resulting collection to the "cudf" backend.
    """
    # HDF5 reader not yet implemented in cudf
    warnings.warn(
        "read_hdf is not yet implemented in cudf/dask_cudf. "
        "Moving to cudf from pandas. Expect poor performance!"
    )
    pandas_backed = _default_backend(dd.read_hdf, *args, **kwargs)
    return pandas_backed.to_backend("cudf")
638+
639+
640+
# Define "cudf" backend entrypoint for dask-expr
class CudfDXBackendEntrypoint(DataFrameBackendEntrypoint):
    """Backend-entrypoint class for Dask-Expressions

    This class is registered under the name "cudf" for the
    ``dask-expr.dataframe.backends`` entrypoint in ``setup.cfg``.
    Dask-DataFrame will use the methods defined in this class
    in place of ``dask_expr.<creation-method>`` when the
    "dataframe.backend" configuration is set to "cudf":

    Examples
    --------
    >>> import dask
    >>> import dask_expr as dx
    >>> with dask.config.set({"dataframe.backend": "cudf"}):
    ...     ddf = dx.from_dict({"a": range(10)})
    >>> type(ddf._meta)
    <class 'cudf.core.dataframe.DataFrame'>
    """

    @classmethod
    def to_backend_dispatch(cls):
        # Delegate to the legacy (non-expr) cudf entrypoint's dispatch.
        return CudfBackendEntrypoint.to_backend_dispatch()

    @classmethod
    def to_backend(cls, *args, **kwargs):
        # Delegate to the legacy (non-expr) cudf entrypoint.
        return CudfBackendEntrypoint.to_backend(*args, **kwargs)

    @staticmethod
    def from_dict(
        data,
        npartitions,
        orient="columns",
        dtype=None,
        columns=None,
        constructor=cudf.DataFrame,
    ):
        """Create a cudf-backed collection from a dict.

        Forwards to ``dask_expr.from_dict`` via the default backend,
        with ``cudf.DataFrame`` as the default constructor.
        """
        # Imported lazily so this module stays importable when
        # dask-expr is not installed (see the guarded import below).
        import dask_expr as dx

        return _default_backend(
            dx.from_dict,
            data,
            npartitions=npartitions,
            orient=orient,
            dtype=dtype,
            columns=columns,
            constructor=constructor,
        )
688+
689+
690+
# Import/register cudf-specific classes for dask-expr
691+
try:
692+
import dask_cudf.expr # noqa: F401
693+
except ImportError:
694+
pass

python/dask_cudf/dask_cudf/core.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -685,18 +685,27 @@ def reduction(
685685

686686
@_dask_cudf_nvtx_annotate
687687
def from_cudf(data, npartitions=None, chunksize=None, sort=True, name=None):
688+
from dask_cudf import QUERY_PLANNING_ON
689+
688690
if isinstance(getattr(data, "index", None), cudf.MultiIndex):
689691
raise NotImplementedError(
690692
"dask_cudf does not support MultiIndex Dataframes."
691693
)
692694

693-
name = name or ("from_cudf-" + tokenize(data, npartitions or chunksize))
695+
# Dask-expr doesn't support the `name` argument
696+
name = {}
697+
if not QUERY_PLANNING_ON:
698+
name = {
699+
"name": name
700+
or ("from_cudf-" + tokenize(data, npartitions or chunksize))
701+
}
702+
694703
return dd.from_pandas(
695704
data,
696705
npartitions=npartitions,
697706
chunksize=chunksize,
698707
sort=sort,
699-
name=name,
708+
**name,
700709
)
701710

702711

@@ -711,7 +720,10 @@ def from_cudf(data, npartitions=None, chunksize=None, sort=True, name=None):
711720
rather than pandas objects.\n
712721
"""
713722
)
714-
+ textwrap.dedent(dd.from_pandas.__doc__)
723+
# TODO: `dd.from_pandas.__doc__` is empty when
724+
# `DASK_DATAFRAME__QUERY_PLANNING=True`
725+
# since dask-expr does not provide a docstring for from_pandas.
726+
+ textwrap.dedent(dd.from_pandas.__doc__ or "")
715727
)
716728

717729

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.

from dask import config

# Check if dask-dataframe is using dask-expr.
# For dask>=2024.3.0, a null value will default to True
QUERY_PLANNING_ON = config.get("dataframe.query-planning", None) is not False

# Register custom expressions and collections (imported for their
# registration side effects only).
try:
    import dask_cudf.expr._collection  # noqa: F401
    import dask_cudf.expr._expr  # noqa: F401

except ImportError as err:
    if QUERY_PLANNING_ON:
        # Dask *should* raise an error before this.
        # However, we can still raise here to be certain.
        # Chain the original ImportError explicitly so the root
        # cause is preserved in the traceback.
        raise RuntimeError(
            "Failed to register the 'cudf' backend for dask-expr."
            " Please make sure you have dask-expr installed.\n"
            f"Error Message: {err}"
        ) from err
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
3+
from functools import cached_property
4+
5+
from dask_expr import (
6+
DataFrame as DXDataFrame,
7+
FrameBase,
8+
Index as DXIndex,
9+
Series as DXSeries,
10+
get_collection_type,
11+
)
12+
from dask_expr._collection import new_collection
13+
from dask_expr._util import _raise_if_object_series
14+
15+
from dask import config
16+
from dask.dataframe.core import is_dataframe_like
17+
18+
import cudf
19+
20+
##
21+
## Custom collection classes
22+
##
23+
24+
25+
# VarMixin can be removed if cudf#15179 is addressed.
# See: https://github.com/rapidsai/cudf/issues/15179
class VarMixin:
    """Shared ``var`` override for cudf-backed dask-expr collections."""

    def var(
        self,
        axis=0,
        skipna=True,
        ddof=1,
        numeric_only=False,
        split_every=False,
        **kwargs,
    ):
        _raise_if_object_series(self, "var")
        axis = self._validate_axis(axis)
        # Result discarded — presumably evaluated on the meta object so
        # invalid arguments raise eagerly (TODO confirm).
        self._meta.var(axis=axis, skipna=skipna, numeric_only=numeric_only)
        frame = self
        if is_dataframe_like(self._meta) and numeric_only:
            # Convert to pandas - cudf does something weird here
            numeric_columns = (
                self._meta.to_pandas().var(numeric_only=True).index
            )
            frame = frame[list(numeric_columns)]
        var_expr = frame.expr.var(
            axis, skipna, ddof, numeric_only, split_every=split_every
        )
        return new_collection(var_expr)
50+
51+
52+
class DataFrame(VarMixin, DXDataFrame):
    """cudf-backed dask-expr DataFrame collection."""

    @classmethod
    def from_dict(cls, *args, **kwargs):
        # Force the "cudf" backend so the parent constructor produces
        # a cudf-backed collection.
        with config.set({"dataframe.backend": "cudf"}):
            return DXDataFrame.from_dict(*args, **kwargs)

    def groupby(
        self,
        by,
        group_keys=True,
        sort=None,
        observed=None,
        dropna=None,
        **kwargs,
    ):
        from dask_cudf.expr._groupby import GroupBy

        # Grouping by an arbitrary collection is only supported when it
        # is a Series; reject other FrameBase instances up front.
        if isinstance(by, FrameBase) and not isinstance(by, DXSeries):
            raise ValueError(
                f"`by` must be a column name or list of columns, got {by}."
            )

        grouped = GroupBy(
            self,
            by,
            group_keys=group_keys,
            sort=sort,
            observed=observed,
            dropna=dropna,
            **kwargs,
        )
        return grouped
84+
85+
class Series(VarMixin, DXSeries):
    """cudf-backed dask-expr Series collection."""

    def groupby(self, by, **kwargs):
        from dask_cudf.expr._groupby import SeriesGroupBy

        grouped = SeriesGroupBy(self, by, **kwargs)
        return grouped

    @cached_property
    def list(self):
        # ``ListMethods`` accessor namespace, cached per instance.
        from dask_cudf.accessors import ListMethods

        return ListMethods(self)

    @cached_property
    def struct(self):
        # ``StructMethods`` accessor namespace, cached per instance.
        from dask_cudf.accessors import StructMethods

        return StructMethods(self)
102+
103+
104+
class Index(DXIndex):
    """cudf-backed dask-expr Index collection."""

    # Same as pandas (for now)
    pass


# Tell dask-expr which collection classes wrap cudf-backed expressions.
get_collection_type.register(cudf.BaseIndex, lambda _: Index)
get_collection_type.register(cudf.Series, lambda _: Series)
get_collection_type.register(cudf.DataFrame, lambda _: DataFrame)

0 commit comments

Comments
 (0)