Support rolling aggregations in in-memory cudf-polars execution #18681
Conversation
- if is_top:
-     # In polars sum(empty_group) => 0, but in libcudf sum(empty_group) => null
-     # So must post-process by replacing nulls, but only if we're a "top-level" agg.
-     rep = expr.Literal(
-         agg.dtype, pa.scalar(0, type=plc.interop.to_arrow(agg.dtype))
-     )
-     return (
-         [named_expr],
-         named_expr.reconstruct(
-             expr.UnaryFunction(agg.dtype, "fill_null", (), col, rep)
-         ),
-         True,
-     )
- else:
-     return [named_expr], expr.NamedExpr(name, col), True
+ return [(named_expr, True)], expr.NamedExpr(
+     name,
+     # In polars sum(empty_group) => 0, but in libcudf
+     # sum(empty_group) => null. So must post-process by
+     # replacing nulls, but only if we're a "top-level"
+     # agg.
+     replace_nulls(
+         col,
+         pa.scalar(0, type=plc.interop.to_arrow(agg.dtype)),
+         is_top=is_top,
+     ),
+ )
Simple refactor now that replace_nulls is used in two places.
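The semantic difference driving this post-processing step: polars defines the sum of an empty (or all-null) group as 0, while libcudf returns null. Below is a minimal pure-Python sketch of that rule; the helper name `fill_sum_nulls` is made up for illustration and is not part of cudf-polars.

```python
# Hypothetical sketch of the null-replacement rule described above:
# polars says sum(empty_group) == 0, libcudf yields null, so a
# top-level sum aggregation must replace nulls in its result column.
def fill_sum_nulls(values, is_top):
    """Replace nulls (None) with 0, but only for top-level aggregations."""
    if not is_top:
        # Nested aggregations are left alone; the outer agg handles nulls.
        return values
    return [0 if v is None else v for v in values]


print(fill_sum_nulls([3, None, 7], is_top=True))   # [3, 0, 7]
print(fill_sum_nulls([3, None, 7], is_top=False))  # [3, None, 7]
```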
def replace(nodes: Sequence[NodeT], replacements: Mapping[NodeT, NodeT]) -> list[NodeT]:
    """
    Replace nodes in expressions.

    Parameters
    ----------
    nodes
        Sequence of nodes to perform replacements in.
    replacements
        Mapping from nodes to be replaced to their replacements.

    Returns
    -------
    list
        Of nodes with replacements performed.
    """
    mapper: GenericTransformer[NodeT, NodeT] = CachingVisitor(
        _replace, state={"replacements": replacements}
    )
    return [mapper(node) for node in nodes]
I use this when rewriting the rolling expression.
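The idea is easiest to see with a toy stand-in. In the sketch below, the `Node` class and the plain dict cache are hypothetical substitutes for cudf-polars' expression nodes and its `CachingVisitor`/`GenericTransformer` machinery: nodes found in the mapping are swapped out, everything else is rebuilt bottom-up, and results are memoized so shared subtrees are rewritten only once.

```python
# Hypothetical sketch of cached node replacement; `Node` and the manual
# dict cache stand in for the real expression nodes and CachingVisitor.
class Node:
    def __init__(self, name, *children):
        self.name = name
        self.children = children

    def reconstruct(self, children):
        # Rebuild this node with new children.
        return Node(self.name, *children)


def replace(nodes, replacements):
    cache = {}  # memoize so each distinct node is visited once

    def mapper(node):
        if node in cache:
            return cache[node]
        if node in replacements:
            result = replacements[node]
        else:
            result = node.reconstruct([mapper(c) for c in node.children])
        cache[node] = result
        return result

    return [mapper(n) for n in nodes]


# Replace the leaf `a` with `c` inside the tree (+ a b).
a, b = Node("a"), Node("b")
root = Node("+", a, b)
(new_root,) = replace([root], {a: Node("c")})
print([c.name for c in new_root.children])  # ['c', 'b']
```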
__all__ = ["rewrite_rolling"]


def rewrite_rolling(
Same idea as the groupby, but with slightly different inputs to the agg decomposition.
def duration_to_int(
    dtype: plc.DataType,
    months: int,
    weeks: int,
    days: int,
    nanoseconds: int,
    parsed_int: bool,  # noqa: FBT001
    negative: bool,  # noqa: FBT001
) -> int:
I would like a Duration object in libcudf so I can say "Add 1 week to this date", but that doesn't exist, so we need to convert to single durations.
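A sketch of what such a conversion looks like: collapse the parsed duration components into one integer count at the orderby column's resolution. This is a hypothetical illustration, not the real cudf-polars signature; the `NS_PER_UNIT` table, the `resolution` parameter, and the error behavior are all assumptions. Calendar-dependent months have no fixed nanosecond width, so they are rejected here.

```python
# Hypothetical sketch: collapse weeks/days/nanoseconds into a single
# integer duration at a given timestamp resolution.
NS_PER_UNIT = {"ns": 1, "us": 1_000, "ms": 1_000_000}
NS_PER_DAY = 86_400_000_000_000


def duration_to_int(resolution, months, weeks, days, nanoseconds, negative=False):
    if months != 0:
        # Months vary in length; they need calendar arithmetic, which a
        # plain integer duration cannot express.
        raise NotImplementedError("months need calendar arithmetic")
    total_ns = nanoseconds + (days + 7 * weeks) * NS_PER_DAY
    scale = NS_PER_UNIT[resolution]
    if total_ns % scale != 0:
        raise ValueError("duration not representable at this resolution")
    value = total_ns // scale
    return -value if negative else value


# 1 week + 2 days, expressed in microseconds.
print(duration_to_int("us", 0, 1, 2, 0))  # 777600000000
```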
@@ -132,7 +132,6 @@ def pytest_configure(config: pytest.Config) -> None:
     "tests/unit/lazyframe/test_lazyframe.py::test_round[dtype1-123.55-1-123.6]": "Rounding midpoints is handled incorrectly",
     "tests/unit/lazyframe/test_lazyframe.py::test_cast_frame": "Casting that raises not supported on GPU",
     "tests/unit/lazyframe/test_lazyframe.py::test_lazy_cache_hit": "Debug output on stderr doesn't match",
-    "tests/unit/operations/aggregation/test_aggregations.py::test_duration_function_literal": "Broadcasting inside groupby-agg not supported",
We now notice and raise, so fallback works.
@@ -176,8 +175,8 @@ def pytest_configure(config: pytest.Config) -> None:
    "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input16-expected16-input_dtype16-output_dtype16]": "Unsupported groupby-agg for a particular dtype",
    "tests/unit/operations/test_group_by.py::test_group_by_binary_agg_with_literal": "Incorrect broadcasting of literals in groupby-agg",
    "tests/unit/operations/test_group_by.py::test_group_by_lit_series": "Incorrect broadcasting of literals in groupby-agg",
    "tests/unit/operations/test_group_by.py::test_aggregated_scalar_elementwise_15602": "Unsupported boolean function/dtype combination in groupby-agg",
Likewise.
    "tests/unit/operations/test_join.py::test_cross_join_slice_pushdown": "Need to implement slice pushdown for cross joins",
    "tests/unit/operations/test_rolling.py::test_rolling_group_by_empty_groups_by_take_6330": "Ordering difference, might be polars bug",
Need to open a polars issue to check.
Implicit casting to boolean of the aggregation::Kind enum meant that there was no compiler warning here, but the function always returned true.
Needed to determine the type of the orderby column for rolling windows.
If we have `col("a") + col("b").max()` the aggregated column should be broadcast across the collected list column and summed, but we do not support this, so notice and raise.
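The broadcasting semantics described above can be illustrated in pure Python (a hypothetical sketch, not cudf-polars code): within one group, `col("a")` is the group's list of values while `col("b").max()` is a scalar, so the scalar must be broadcast across the list before the outer sum.

```python
# Pure-Python sketch of evaluating col("a") + col("b").max() within a
# single group: broadcast the scalar max over the group, then sum.
def agg_broadcast(group_a, group_b):
    scalar = max(group_b)                     # col("b").max() -> scalar
    return sum(x + scalar for x in group_a)   # broadcast, then sum


print(agg_broadcast([1, 2], [10, 20]))  # (1 + 20) + (2 + 20) = 43
```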
We can't use expressions as strings in the test names because those sometimes have object addresses in them.
Approving C++ changes
/merge
Description
Building on the groupby rewrite infrastructure, we pull essentially the same trick for rolling aggregation.
Checklist