Skip to content

Commit a279663

Browse files
committed
WIP: Implement expression-based rolling
Still need to do the grouped version but that will be easier.
1 parent e11cb8e commit a279663

File tree

4 files changed

+193
-29
lines changed

4 files changed

+193
-29
lines changed

python/cudf_polars/cudf_polars/dsl/expressions/rolling.py

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: Apache-2.0
33
# TODO: remove need for this
44
# ruff: noqa: D101
@@ -8,24 +8,108 @@
88

99
from typing import TYPE_CHECKING, Any
1010

11-
from cudf_polars.dsl.expressions.base import Expr
11+
import pylibcudf as plc
12+
13+
from cudf_polars.containers import Column
14+
from cudf_polars.dsl import expr
15+
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
16+
from cudf_polars.dsl.utils.windows import range_window_bounds
1217

1318
if TYPE_CHECKING:
14-
import pylibcudf as plc
19+
import pyarrow as pa
20+
21+
from cudf_polars.containers import DataFrame
22+
from cudf_polars.typing import ClosedInterval
23+
24+
__all__ = ["GroupedRollingWindow", "RollingWindow", "to_request"]
25+
26+
27+
def to_request(
    value: expr.Expr, orderby: Column, df: DataFrame
) -> plc.rolling.RollingRequest:
    """
    Produce a rolling request for evaluation with pylibcudf.

    Parameters
    ----------
    value
        The expression to perform the rolling aggregation on.
    orderby
        Orderby column, used as input to the request when the aggregation is Len.
    df
        DataFrame used to evaluate the inputs to the aggregation.
    """
    window_min_periods = 1
    if isinstance(value, expr.Len):
        # A count aggregation needs some column to count over; the
        # orderby column always exists, so use that.
        target = orderby
    elif isinstance(value, expr.Agg):
        target = value.children[0].evaluate(df, context=ExecutionContext.ROLLING)
        if value.name == "var":
            # Polars variance produces null if nvalues <= ddof whereas
            # libcudf produces NaN. Requiring at least ddof + 1
            # observations per window reproduces the polars behaviour.
            window_min_periods = value.options + 1
    else:
        target = value.evaluate(df, context=ExecutionContext.ROLLING)
    return plc.rolling.RollingRequest(
        target.obj, window_min_periods, value.agg_request
    )
1758

1859

1960
class RollingWindow(Expr):
    """
    Rolling window aggregation over a sorted orderby column.

    Evaluates a single aggregation expression over range-based windows
    defined by ``preceding``/``following`` offsets relative to the
    values of ``orderby``.
    """

    __slots__ = ("closed_window", "following", "orderby", "preceding")
    _non_child = ("dtype", "preceding", "following", "closed_window", "orderby")

    def __init__(
        self,
        dtype: plc.DataType,
        preceding: pa.Scalar,
        following: pa.Scalar,
        closed_window: ClosedInterval,
        orderby: str,
        agg: Expr,
    ) -> None:
        """
        Construct a rolling window expression.

        Parameters
        ----------
        dtype
            Data type of the aggregation result.
        preceding
            Offset defining the start of each window.
        following
            Offset defining the end of each window.
        closed_window
            Which side(s) of the window interval are closed.
        orderby
            Name of the column defining the window ordering.
        agg
            Aggregation expression to evaluate in each window.

        Raises
        ------
        NotImplementedError
            If libcudf does not support this aggregation as a rolling
            aggregation on the given dtype.
        """
        self.dtype = dtype
        self.preceding = preceding
        self.following = following
        self.closed_window = closed_window
        self.orderby = orderby
        self.children = (agg,)
        self.is_pointwise = False
        # Reject unsupported aggregations at construction time so that
        # translation fails fast rather than at evaluation.
        if not plc.rolling.is_valid_rolling_aggregation(agg.dtype, agg.agg_request):
            raise NotImplementedError(f"Unsupported rolling aggregation {agg}")

    def do_evaluate(
        self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
    ) -> Column:
        """
        Evaluate the rolling aggregation against a dataframe.

        Raises
        ------
        RuntimeError
            If evaluated outside a frame context, or if the orderby
            column contains nulls or is not sorted ascending.
        """
        if context != ExecutionContext.FRAME:
            raise RuntimeError("Rolling aggregation inside groupby/over/rolling")
        (agg,) = self.children
        orderby = df.column_map[self.orderby]
        preceding, following = range_window_bounds(
            self.preceding, self.following, self.closed_window
        )
        if orderby.obj.null_count() != 0:
            raise RuntimeError(
                f"Index column '{self.orderby}' in rolling may not contain nulls"
            )
        if not orderby.check_sorted(
            order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER
        ):
            raise RuntimeError(
                f"Index column '{self.orderby}' in rolling is not sorted, please sort first"
            )
        # An empty key table gives ungrouped (whole-frame) rolling
        # windows; a single request yields a single result column.
        (result,) = plc.rolling.grouped_range_rolling_window(
            plc.Table([]),
            orderby.obj,
            plc.types.Order.ASCENDING,
            plc.types.NullOrder.AFTER,
            preceding,
            following,
            [to_request(agg, orderby, df)],
        ).columns()
        return Column(result)
29113

30114

31115
class GroupedRollingWindow(Expr):

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import cudf_polars.dsl.expr as expr
3232
from cudf_polars.containers import Column, DataFrame
33+
from cudf_polars.dsl.expressions import rolling
3334
from cudf_polars.dsl.nodebase import Node
3435
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
3536
from cudf_polars.dsl.utils.windows import range_window_bounds
@@ -1016,22 +1017,7 @@ def do_evaluate(
10161017
df: DataFrame,
10171018
) -> DataFrame:
10181019
keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1019-
requests: list[plc.rolling.RollingRequest] = []
1020-
names: list[str] = []
10211020
orderby = index.evaluate(df)
1022-
for request in aggs:
1023-
name = request.name
1024-
value = request.value
1025-
if isinstance(value, expr.Len):
1026-
# A count aggregation, we need a column so use the orderby column
1027-
col = orderby.obj
1028-
elif isinstance(value, expr.Agg):
1029-
(child,) = value.children
1030-
col = child.evaluate(df).obj
1031-
else:
1032-
col = value.evaluate(df).obj
1033-
requests.append(plc.rolling.RollingRequest(col, 1, value.agg_request))
1034-
names.append(name)
10351021
preceding_window, following_window = range_window_bounds(
10361022
preceding, following, closed_window
10371023
)
@@ -1048,19 +1034,23 @@ def do_evaluate(
10481034
values = plc.rolling.grouped_range_rolling_window(
10491035
plc.Table([k.obj for k in keys]),
10501036
orderby.obj,
1051-
orderby.order,
1037+
plc.types.Order.ASCENDING, # Polars requires ascending orderby.
10521038
plc.types.NullOrder.AFTER, # Doesn't matter, polars doesn't allow nulls
10531039
preceding_window,
10541040
following_window,
1055-
requests,
1041+
[rolling.to_request(request.value, orderby, df) for request in aggs],
10561042
)
10571043
return DataFrame(
10581044
itertools.chain(
10591045
keys,
10601046
[orderby],
10611047
(
10621048
Column(col, name=name)
1063-
for col, name in zip(values.columns(), names, strict=True)
1049+
for col, name in zip(
1050+
values.columns(),
1051+
(request.name for request in aggs),
1052+
strict=True,
1053+
)
10641054
),
10651055
)
10661056
).slice(zlice)

python/cudf_polars/cudf_polars/dsl/translate.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222

2323
from cudf_polars.dsl import expr, ir
2424
from cudf_polars.dsl.to_ast import insert_colrefs
25+
from cudf_polars.dsl.traversal import traversal
26+
from cudf_polars.dsl.utils.aggregations import decompose_single_agg
2527
from cudf_polars.dsl.utils.groupby import rewrite_groupby
28+
from cudf_polars.dsl.utils.naming import unique_names
29+
from cudf_polars.dsl.utils.replace import replace
2630
from cudf_polars.dsl.utils.rolling import rewrite_rolling
31+
from cudf_polars.dsl.utils.windows import offsets_to_windows
2732
from cudf_polars.typing import Schema
2833
from cudf_polars.utils import config, dtypes, sorting
2934

@@ -51,6 +56,7 @@ def __init__(self, visitor: NodeTraverser, engine: GPUEngine):
5156
self.visitor = visitor
5257
self.config_options = config.ConfigOptions.from_polars_engine(engine)
5358
self.errors: list[Exception] = []
59+
self.schema_stack = []
5460

5561
def translate_ir(self, *, n: int | None = None) -> ir.IR:
5662
"""
@@ -106,10 +112,13 @@ def translate_ir(self, *, n: int | None = None) -> ir.IR:
106112
self.errors.append(e)
107113
return ir.ErrorNode(schema, str(e))
108114
try:
115+
self.schema_stack.append(schema)
109116
result = _translate_ir(node, self, schema)
110117
except Exception as e:
111118
self.errors.append(e)
112119
return ir.ErrorNode(schema, str(e))
120+
finally:
121+
self.schema_stack.pop()
113122
if any(
114123
isinstance(dtype, pl.Null)
115124
for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())
@@ -621,9 +630,44 @@ def _(node: pl_expr.Window, translator: Translator, dtype: plc.DataType) -> expr
621630
# TODO: raise in groupby?
622631
if isinstance(node.options, pl_expr.RollingGroupOptions):
623632
# pl.col("a").rolling(...)
624-
return expr.RollingWindow(
625-
dtype, node.options, translator.translate_expr(n=node.function)
633+
agg = translator.translate_expr(n=node.function)
634+
name_generator = unique_names(
635+
e.name for e in traversal([agg]) if isinstance(e, expr.Col)
626636
)
637+
named_aggs, named_post_agg, _ = decompose_single_agg(
638+
expr.NamedExpr(next(name_generator), agg), name_generator, is_top=True
639+
)
640+
orderby = node.options.index_column
641+
preceding, following = offsets_to_windows(
642+
translator.schema_stack[-1][orderby],
643+
node.options.offset,
644+
node.options.period,
645+
)
646+
closed_window = node.options.closed_window
647+
if isinstance(named_post_agg.value, expr.Col):
648+
(named_agg,) = named_aggs
649+
return expr.RollingWindow(
650+
named_agg.value.dtype,
651+
preceding,
652+
following,
653+
closed_window,
654+
orderby,
655+
named_agg.value,
656+
)
657+
return replace(
658+
[named_post_agg.value], # type: ignore[misc]
659+
{
660+
expr.Col(agg.value.dtype, agg.name): expr.RollingWindow(
661+
agg.value.dtype,
662+
preceding,
663+
following,
664+
closed_window,
665+
orderby,
666+
agg.value,
667+
)
668+
for agg in named_aggs
669+
},
670+
)[0]
627671
elif isinstance(node.options, pl_expr.WindowMapping):
628672
# pl.col("a").over(...)
629673
return expr.GroupedRollingWindow(
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Utilities for replacing nodes in a DAG."""
5+
6+
from __future__ import annotations
7+
8+
from typing import TYPE_CHECKING
9+
10+
from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import Mapping, Sequence
14+
15+
from cudf_polars.typing import GenericTransformer, NodeT
16+
17+
__all__ = ["replace"]
18+
19+
20+
def _replace(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> NodeT:
    # If a replacement is registered for this node, use it; otherwise
    # rebuild the node from its (possibly replaced) children, reusing
    # the original object when nothing changed.
    replacements = fn.state["replacements"]
    if node in replacements:
        return replacements[node]
    return reuse_if_unchanged(node, fn)
25+
26+
27+
def replace(nodes: Sequence[NodeT], replacements: Mapping[NodeT, NodeT]) -> list[NodeT]:
    """
    Replace nodes in expressions.

    Parameters
    ----------
    nodes
        Sequence of nodes to perform replacements in.
    replacements
        Mapping from nodes to be replaced to their replacements.

    Returns
    -------
    list
        Of nodes with replacements performed.
    """
    # A caching visitor memoises results, so shared subexpressions are
    # rewritten exactly once.
    visitor: GenericTransformer[NodeT, NodeT] = CachingVisitor(
        _replace, state={"replacements": replacements}
    )
    return [visitor(node) for node in nodes]

0 commit comments

Comments
 (0)