|
19 | 19 | from narwhals._duckdb.expr_list import DuckDBExprListNamespace
|
20 | 20 | from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
|
21 | 21 | from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
|
| 22 | +from narwhals._duckdb.utils import UnorderableWindowInputs |
22 | 23 | from narwhals._duckdb.utils import WindowInputs
|
23 | 24 | from narwhals._duckdb.utils import col
|
24 | 25 | from narwhals._duckdb.utils import ensure_type
|
|
41 | 42 | from narwhals._compliant.typing import EvalSeries
|
42 | 43 | from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
43 | 44 | from narwhals._duckdb.namespace import DuckDBNamespace
|
| 45 | + from narwhals._duckdb.typing import UnorderableWindowFunction |
44 | 46 | from narwhals._duckdb.typing import WindowFunction
|
45 | 47 | from narwhals._expression_parsing import ExprMetadata
|
46 | 48 | from narwhals.dtypes import DType
|
@@ -79,6 +81,10 @@ def __init__(
|
79 | 81 | # This can only be set by `_with_window_function`.
|
80 | 82 | self._window_function: WindowFunction | None = None
|
81 | 83 |
|
| 84 | + # These can only be set by `_with_unorderable_window_function` |
| 85 | + self._unorderable_window_function: UnorderableWindowFunction | None = None |
| 86 | + self._previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression] | None = None |
| 87 | + |
82 | 88 | def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
|
83 | 89 | return self._call(df)
|
84 | 90 |
|
@@ -263,6 +269,22 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
|
263 | 269 | result._window_function = window_function
|
264 | 270 | return result
|
265 | 271 |
|
| 272 | + def _with_unorderable_window_function( |
| 273 | + self, |
| 274 | + unorderable_window_function: UnorderableWindowFunction, |
| 275 | + previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression], |
| 276 | + ) -> Self: |
| 277 | + result = self.__class__( |
| 278 | + self._call, |
| 279 | + evaluate_output_names=self._evaluate_output_names, |
| 280 | + alias_output_names=self._alias_output_names, |
| 281 | + backend_version=self._backend_version, |
| 282 | + version=self._version, |
| 283 | + ) |
| 284 | + result._unorderable_window_function = unorderable_window_function |
| 285 | + result._previous_call = previous_call |
| 286 | + return result |
| 287 | + |
266 | 288 | @classmethod
|
267 | 289 | def _alias_native(cls, expr: duckdb.Expression, name: str) -> duckdb.Expression:
|
268 | 290 | return expr.alias(name)
|
@@ -495,6 +517,19 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
|
495 | 517 | window_function(WindowInputs(expr, partition_by, order_by))
|
496 | 518 | for expr in self._call(df)
|
497 | 519 | ]
|
| 520 | + elif ( |
| 521 | + unorderable_window_function := self._unorderable_window_function |
| 522 | + ) is not None: |
| 523 | + assert order_by is None # noqa: S101 |
| 524 | + |
| 525 | + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: |
| 526 | + assert self._previous_call is not None # noqa: S101 |
| 527 | + return [ |
| 528 | + unorderable_window_function( |
| 529 | + UnorderableWindowInputs(expr, partition_by) |
| 530 | + ) |
| 531 | + for expr in self._previous_call(df) |
| 532 | + ] |
498 | 533 | else:
|
499 | 534 | partition_by_sql = generate_partition_by_sql(*partition_by)
|
500 | 535 | template = f"{{expr}} over ({partition_by_sql})"
|
@@ -728,30 +763,54 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
|
728 | 763 | else: # method == "ordinal"
|
729 | 764 | func = FunctionExpression("row_number")
|
730 | 765 |
|
731 |
| - def _rank(_input: duckdb.Expression) -> duckdb.Expression: |
732 |
| - if descending: |
733 |
| - by_sql = f"{_input} desc nulls last" |
734 |
| - else: |
735 |
| - by_sql = f"{_input} asc nulls last" |
736 |
| - order_by_sql = f"order by {by_sql}" |
| 766 | + def _rank( |
| 767 | + _input: duckdb.Expression, |
| 768 | + *, |
| 769 | + descending: bool, |
| 770 | + partition_by: Sequence[str | duckdb.Expression] | None = None, |
| 771 | + ) -> duckdb.Expression: |
| 772 | + order_by_sql = ( |
| 773 | + f"order by {_input} desc nulls last" |
| 774 | + if descending |
| 775 | + else f"order by {_input} asc nulls last" |
| 776 | + ) |
737 | 777 | count_expr = FunctionExpression("count", StarExpression())
|
738 |
| - |
| 778 | + if partition_by is not None: |
| 779 | + window = f"{generate_partition_by_sql(*partition_by)} {order_by_sql}" |
| 780 | + count_window = f"{generate_partition_by_sql(*partition_by, _input)}" |
| 781 | + else: |
| 782 | + window = order_by_sql |
| 783 | + count_window = generate_partition_by_sql(_input) |
739 | 784 | if method == "max":
|
740 | 785 | expr = (
|
741 |
| - SQLExpression(f"{func} OVER ({order_by_sql})") |
742 |
| - + SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") |
| 786 | + SQLExpression(f"{func} OVER ({window})") |
| 787 | + + SQLExpression(f"{count_expr} over ({count_window})") |
743 | 788 | - lit(1)
|
744 | 789 | )
|
745 | 790 | elif method == "average":
|
746 |
| - expr = SQLExpression(f"{func} OVER ({order_by_sql})") + ( |
747 |
| - SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") - lit(1) |
| 791 | + expr = SQLExpression(f"{func} OVER ({window})") + ( |
| 792 | + SQLExpression(f"{count_expr} over ({count_window})") - lit(1) |
748 | 793 | ) / lit(2.0)
|
749 | 794 | else:
|
750 |
| - expr = SQLExpression(f"{func} OVER ({order_by_sql})") |
751 |
| - |
| 795 | + expr = SQLExpression(f"{func} OVER ({window})") |
752 | 796 | return when(_input.isnotnull(), expr)
|
753 | 797 |
|
754 |
| - return self._with_callable(_rank) |
| 798 | + def _unpartitioned_rank(_input: duckdb.Expression) -> duckdb.Expression: |
| 799 | + return _rank(_input, descending=descending) |
| 800 | + |
| 801 | + def _partitioned_rank( |
| 802 | + window_inputs: UnorderableWindowInputs, |
| 803 | + ) -> duckdb.Expression: |
| 804 | + return _rank( |
| 805 | + window_inputs.expr, |
| 806 | + descending=descending, |
| 807 | + partition_by=window_inputs.partition_by, |
| 808 | + ) |
| 809 | + |
| 810 | + return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function( |
| 811 | + _partitioned_rank, |
| 812 | + self._call, |
| 813 | + ) |
755 | 814 |
|
756 | 815 | def log(self, base: float) -> Self:
|
757 | 816 | def _log(_input: duckdb.Expression) -> duckdb.Expression:
|
|
0 commit comments