Skip to content

Commit 4d35577

Browse files
authored
Support rolling aggregations in in-memory cudf-polars execution (#18681)
Building on the groupby rewrite infrastructure, we pull essentially the same trick for rolling aggregation. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Basit Ayantunde (https://github.com/lamarrr) - Tianyu Liu (https://github.com/kingcrimsontianyu) - Tom Augspurger (https://github.com/TomAugspurger) - David Wendt (https://github.com/davidwendt) - Matthew Roeschke (https://github.com/mroeschke) URL: #18681
1 parent 9d84875 commit 4d35577

File tree

17 files changed

+1229
-199
lines changed

17 files changed

+1229
-199
lines changed

cpp/src/rolling/detail/rolling_utils.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct is_supported_rolling_aggregation_impl {
3030
constexpr bool operator()() const noexcept
3131
{
3232
return (kind == aggregation::Kind::LEAD || kind == aggregation::Kind::LAG ||
33-
kind == aggregation::Kind::COLLECT_LIST || aggregation::Kind::COLLECT_SET) ||
33+
kind == aggregation::Kind::COLLECT_LIST || kind == aggregation::Kind::COLLECT_SET) ||
3434
corresponding_rolling_operator<T, kind>::type::is_supported();
3535
}
3636
};

python/cudf_polars/cudf_polars/containers/column.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,44 @@ def sorted_like(self, like: Column, /) -> Self:
177177
null_order=like.null_order,
178178
)
179179

180+
def check_sorted(
181+
self,
182+
*,
183+
order: plc.types.Order,
184+
null_order: plc.types.NullOrder,
185+
) -> bool:
186+
"""
187+
Check if the column is sorted.
188+
189+
Parameters
190+
----------
191+
order
192+
The requested sort order.
193+
null_order
194+
Where nulls sort to.
195+
196+
Returns
197+
-------
198+
True if the column is sorted, false otherwise.
199+
200+
Notes
201+
-----
202+
If the sortedness flag is not set, this launches a kernel to
203+
check sortedness.
204+
"""
205+
if self.obj.size() <= 1 or self.obj.size() == self.obj.null_count():
206+
return True
207+
if self.is_sorted == plc.types.Sorted.YES:
208+
return self.order == order and (
209+
self.obj.null_count() == 0 or self.null_order == null_order
210+
)
211+
if plc.sorting.is_sorted(plc.Table([self.obj]), [order], [null_order]):
212+
self.sorted = plc.types.Sorted.YES
213+
self.order = order
214+
self.null_order = null_order
215+
return True
216+
return False
217+
180218
def astype(self, dtype: plc.DataType) -> Column:
181219
"""
182220
Cast the column to as the requested dtype.

python/cudf_polars/cudf_polars/dsl/expressions/rolling.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: Apache-2.0
33
# TODO: remove need for this
44
# ruff: noqa: D101
@@ -8,24 +8,125 @@
88

99
from typing import TYPE_CHECKING, Any
1010

11-
from cudf_polars.dsl.expressions.base import Expr
11+
import pylibcudf as plc
12+
13+
from cudf_polars.containers import Column
14+
from cudf_polars.dsl import expr
15+
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
16+
from cudf_polars.dsl.utils.windows import range_window_bounds
1217

1318
if TYPE_CHECKING:
14-
import pylibcudf as plc
19+
import pyarrow as pa
20+
21+
from cudf_polars.containers import DataFrame
22+
from cudf_polars.typing import ClosedInterval
23+
24+
__all__ = ["GroupedRollingWindow", "RollingWindow", "to_request"]
25+
26+
27+
def to_request(
28+
value: expr.Expr, orderby: Column, df: DataFrame
29+
) -> plc.rolling.RollingRequest:
30+
"""
31+
Produce a rolling request for evaluation with pylibcudf.
1532
16-
__all__ = ["GroupedRollingWindow", "RollingWindow"]
33+
Parameters
34+
----------
35+
value
36+
The expression to perform the rolling aggregation on.
37+
orderby
38+
Orderby column, used as input to the request when the aggregation is Len.
39+
df
40+
DataFrame used to evaluate the inputs to the aggregation.
41+
"""
42+
min_periods = 1
43+
if isinstance(value, expr.Len):
44+
# A count aggregation, we need a column so use the orderby column
45+
col = orderby
46+
elif isinstance(value, expr.Agg):
47+
child = value.children[0]
48+
col = child.evaluate(df, context=ExecutionContext.ROLLING)
49+
if value.name == "var":
50+
# Polars variance produces null if nvalues <= ddof
51+
# libcudf produces NaN. However, we can get the polars
52+
# behaviour by setting the minimum window size to ddof +
53+
# 1.
54+
min_periods = value.options + 1
55+
else:
56+
col = value.evaluate(
57+
df, context=ExecutionContext.ROLLING
58+
) # pragma: no cover; raise before we get here because we
59+
# don't do correct handling of empty groups
60+
return plc.rolling.RollingRequest(col.obj, min_periods, value.agg_request)
1761

1862

1963
class RollingWindow(Expr):
20-
__slots__ = ("options",)
21-
_non_child = ("dtype", "options")
64+
__slots__ = ("closed_window", "following", "orderby", "preceding")
65+
_non_child = ("dtype", "preceding", "following", "closed_window", "orderby")
2266

23-
def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
67+
def __init__(
68+
self,
69+
dtype: plc.DataType,
70+
preceding: pa.Scalar,
71+
following: pa.Scalar,
72+
closed_window: ClosedInterval,
73+
orderby: str,
74+
agg: Expr,
75+
) -> None:
2476
self.dtype = dtype
25-
self.options = options
77+
self.preceding = preceding
78+
self.following = following
79+
self.closed_window = closed_window
80+
self.orderby = orderby
2681
self.children = (agg,)
2782
self.is_pointwise = False
28-
raise NotImplementedError("Rolling window not implemented")
83+
if agg.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST:
84+
raise NotImplementedError(
85+
"Incorrect handling of empty groups for list collection"
86+
)
87+
if not plc.rolling.is_valid_rolling_aggregation(agg.dtype, agg.agg_request):
88+
raise NotImplementedError(f"Unsupported rolling aggregation {agg}")
89+
90+
def do_evaluate( # noqa: D102
91+
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
92+
) -> Column:
93+
if context != ExecutionContext.FRAME:
94+
raise RuntimeError(
95+
"Rolling aggregation inside groupby/over/rolling"
96+
) # pragma: no cover; translation raises first
97+
(agg,) = self.children
98+
orderby = df.column_map[self.orderby]
99+
# Polars casts integral orderby to int64, but only for calculating window bounds
100+
if (
101+
plc.traits.is_integral(orderby.obj.type())
102+
and orderby.obj.type().id() != plc.TypeId.INT64
103+
):
104+
orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
105+
else:
106+
orderby_obj = orderby.obj
107+
preceding, following = range_window_bounds(
108+
self.preceding, self.following, self.closed_window
109+
)
110+
if orderby.obj.null_count() != 0:
111+
raise RuntimeError(
112+
f"Index column '{self.orderby}' in rolling may not contain nulls"
113+
)
114+
if not orderby.check_sorted(
115+
order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
116+
):
117+
raise RuntimeError(
118+
f"Index column '{self.orderby}' in rolling is not sorted, please sort first"
119+
)
120+
(result,) = plc.rolling.grouped_range_rolling_window(
121+
plc.Table([]),
122+
orderby_obj,
123+
plc.types.Order.ASCENDING,
124+
plc.types.NullOrder.BEFORE,
125+
preceding,
126+
following,
127+
[to_request(agg, orderby, df)],
128+
).columns()
129+
return Column(result)
29130

30131

31132
class GroupedRollingWindow(Expr):

0 commit comments

Comments
 (0)