Skip to content

Commit f399bb2

Browse files
feat: modify rank so that it works with over() for lazy backends (#2533)
--------- Co-authored-by: Marco Gorelli <[email protected]>
1 parent c7a6080 commit f399bb2

File tree

7 files changed

+216
-27
lines changed

7 files changed

+216
-27
lines changed

narwhals/_duckdb/expr.py

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
2020
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
2121
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
22+
from narwhals._duckdb.utils import UnorderableWindowInputs
2223
from narwhals._duckdb.utils import WindowInputs
2324
from narwhals._duckdb.utils import col
2425
from narwhals._duckdb.utils import ensure_type
@@ -41,6 +42,7 @@
4142
from narwhals._compliant.typing import EvalSeries
4243
from narwhals._duckdb.dataframe import DuckDBLazyFrame
4344
from narwhals._duckdb.namespace import DuckDBNamespace
45+
from narwhals._duckdb.typing import UnorderableWindowFunction
4446
from narwhals._duckdb.typing import WindowFunction
4547
from narwhals._expression_parsing import ExprMetadata
4648
from narwhals.dtypes import DType
@@ -79,6 +81,10 @@ def __init__(
7981
# This can only be set by `_with_window_function`.
8082
self._window_function: WindowFunction | None = None
8183

84+
# These can only be set by `_with_unorderable_window_function`
85+
self._unorderable_window_function: UnorderableWindowFunction | None = None
86+
self._previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression] | None = None
87+
8288
def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
8389
return self._call(df)
8490

@@ -263,6 +269,22 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
263269
result._window_function = window_function
264270
return result
265271

272+
def _with_unorderable_window_function(
    self,
    unorderable_window_function: UnorderableWindowFunction,
    previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression],
) -> Self:
    """Return a copy of this expression carrying an unorderable window function.

    Mirrors `_with_window_function`, but for window functions that take no
    `order_by` (e.g. `rank`). The copy records both the window function and
    the call it should be applied on top of; `over` consumes them later.
    """
    clone = type(self)(
        self._call,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        backend_version=self._backend_version,
        version=self._version,
    )
    # These attributes are only ever set here (see __init__).
    clone._unorderable_window_function = unorderable_window_function
    clone._previous_call = previous_call
    return clone
287+
266288
@classmethod
267289
def _alias_native(cls, expr: duckdb.Expression, name: str) -> duckdb.Expression:
268290
return expr.alias(name)
@@ -495,6 +517,19 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
495517
window_function(WindowInputs(expr, partition_by, order_by))
496518
for expr in self._call(df)
497519
]
520+
elif (
521+
unorderable_window_function := self._unorderable_window_function
522+
) is not None:
523+
assert order_by is None # noqa: S101
524+
525+
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
526+
assert self._previous_call is not None # noqa: S101
527+
return [
528+
unorderable_window_function(
529+
UnorderableWindowInputs(expr, partition_by)
530+
)
531+
for expr in self._previous_call(df)
532+
]
498533
else:
499534
partition_by_sql = generate_partition_by_sql(*partition_by)
500535
template = f"{{expr}} over ({partition_by_sql})"
@@ -728,30 +763,54 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
728763
else: # method == "ordinal"
729764
func = FunctionExpression("row_number")
730765

731-
def _rank(_input: duckdb.Expression) -> duckdb.Expression:
732-
if descending:
733-
by_sql = f"{_input} desc nulls last"
734-
else:
735-
by_sql = f"{_input} asc nulls last"
736-
order_by_sql = f"order by {by_sql}"
766+
def _rank(
767+
_input: duckdb.Expression,
768+
*,
769+
descending: bool,
770+
partition_by: Sequence[str | duckdb.Expression] | None = None,
771+
) -> duckdb.Expression:
772+
order_by_sql = (
773+
f"order by {_input} desc nulls last"
774+
if descending
775+
else f"order by {_input} asc nulls last"
776+
)
737777
count_expr = FunctionExpression("count", StarExpression())
738-
778+
if partition_by is not None:
779+
window = f"{generate_partition_by_sql(*partition_by)} {order_by_sql}"
780+
count_window = f"{generate_partition_by_sql(*partition_by, _input)}"
781+
else:
782+
window = order_by_sql
783+
count_window = generate_partition_by_sql(_input)
739784
if method == "max":
740785
expr = (
741-
SQLExpression(f"{func} OVER ({order_by_sql})")
742-
+ SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})")
786+
SQLExpression(f"{func} OVER ({window})")
787+
+ SQLExpression(f"{count_expr} over ({count_window})")
743788
- lit(1)
744789
)
745790
elif method == "average":
746-
expr = SQLExpression(f"{func} OVER ({order_by_sql})") + (
747-
SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") - lit(1)
791+
expr = SQLExpression(f"{func} OVER ({window})") + (
792+
SQLExpression(f"{count_expr} over ({count_window})") - lit(1)
748793
) / lit(2.0)
749794
else:
750-
expr = SQLExpression(f"{func} OVER ({order_by_sql})")
751-
795+
expr = SQLExpression(f"{func} OVER ({window})")
752796
return when(_input.isnotnull(), expr)
753797

754-
return self._with_callable(_rank)
798+
def _unpartitioned_rank(_input: duckdb.Expression) -> duckdb.Expression:
799+
return _rank(_input, descending=descending)
800+
801+
def _partitioned_rank(
802+
window_inputs: UnorderableWindowInputs,
803+
) -> duckdb.Expression:
804+
return _rank(
805+
window_inputs.expr,
806+
descending=descending,
807+
partition_by=window_inputs.partition_by,
808+
)
809+
810+
return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function(
811+
_partitioned_rank,
812+
self._call,
813+
)
755814

756815
def log(self, base: float) -> Self:
757816
def _log(_input: duckdb.Expression) -> duckdb.Expression:

narwhals/_duckdb/typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
if TYPE_CHECKING:
77
import duckdb
88

9+
from narwhals._duckdb.utils import UnorderableWindowInputs
910
from narwhals._duckdb.utils import WindowInputs
1011

1112
class WindowFunction(Protocol):
1213
def __call__(self, window_inputs: WindowInputs) -> duckdb.Expression: ...
14+
15+
class UnorderableWindowFunction(Protocol):
    """Callable applying a window function that has no `order_by` clause.

    `window_inputs` is positional-only (note the `/`), matching the
    SparkLike counterpart, so conforming implementations may name the
    parameter however they like.
    """

    def __call__(
        self, window_inputs: UnorderableWindowInputs, /
    ) -> duckdb.Expression: ...

narwhals/_duckdb/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ def __init__(
5454
self.order_by = order_by
5555

5656

57+
class UnorderableWindowInputs:
    """Inputs for a window function that takes no `order_by` clause.

    Attributes:
        expr: Expression the window function is applied to.
        partition_by: Columns defining the window partitions, as names or
            native expressions.
    """

    __slots__ = ("expr", "partition_by")

    def __init__(
        self,
        expr: duckdb.Expression,
        # Widened from `Sequence[str]` for consistency with the SparkLike
        # counterpart and with `generate_partition_by_sql`, which already
        # accepts native expressions alongside column names.
        partition_by: Sequence[str | duckdb.Expression],
    ) -> None:
        self.expr = expr
        self.partition_by = partition_by
68+
5769
def concat_str(*exprs: duckdb.Expression, separator: str = "") -> duckdb.Expression:
5870
"""Concatenate many strings, NULL inputs are skipped.
5971
@@ -244,10 +256,10 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
244256
raise AssertionError(msg)
245257

246258

247-
def generate_partition_by_sql(*partition_by: str) -> str:
259+
def generate_partition_by_sql(*partition_by: str | duckdb.Expression) -> str:
    """Render a `partition by` SQL fragment for the given columns.

    Plain strings are treated as column names (quoted via `col`); native
    DuckDB expressions are rendered as-is. Returns the empty string when no
    partitioning columns are given.
    """
    if not partition_by:
        return ""
    rendered = (
        f"{col(item)}" if isinstance(item, str) else f"{item}" for item in partition_by
    )
    return f"partition by {', '.join(rendered)}"
252264

253265

narwhals/_spark_like/expr.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
1919
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
2020
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
21+
from narwhals._spark_like.utils import UnorderableWindowInputs
2122
from narwhals._spark_like.utils import WindowInputs
2223
from narwhals._spark_like.utils import import_functions
2324
from narwhals._spark_like.utils import import_native_dtypes
@@ -42,6 +43,7 @@
4243
from narwhals._expression_parsing import ExprMetadata
4344
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
4445
from narwhals._spark_like.namespace import SparkLikeNamespace
46+
from narwhals._spark_like.typing import UnorderableWindowFunction
4547
from narwhals._spark_like.typing import WindowFunction
4648
from narwhals.dtypes import DType
4749
from narwhals.typing import FillNullStrategy
@@ -96,6 +98,10 @@ def __init__(
9698
# This can only be set by `_with_window_function`.
9799
self._window_function: WindowFunction | None = None
98100

101+
# These can only be set by `_with_unorderable_window_function`
102+
self._unorderable_window_function: UnorderableWindowFunction | None = None
103+
self._previous_call: EvalSeries[SparkLikeLazyFrame, Column] | None = None
104+
99105
def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]:
100106
return self._call(df)
101107

@@ -210,6 +216,23 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
210216
result._window_function = window_function
211217
return result
212218

219+
def _with_unorderable_window_function(
    self,
    unorderable_window_function: UnorderableWindowFunction,
    previous_call: EvalSeries[SparkLikeLazyFrame, Column],
) -> Self:
    """Return a copy of this expression carrying an unorderable window function.

    Mirrors `_with_window_function`, but for window functions that take no
    `order_by` (e.g. `rank`). The copy records both the window function and
    the call it should be applied on top of; `over` consumes them later.
    """
    clone = type(self)(
        self._call,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        backend_version=self._backend_version,
        version=self._version,
        implementation=self._implementation,
    )
    # These attributes are only ever set here (see __init__).
    clone._unorderable_window_function = unorderable_window_function
    clone._previous_call = previous_call
    return clone
235+
213236
@classmethod
214237
def _alias_native(cls, expr: Column, name: str) -> Column:
215238
return expr.alias(name)
@@ -629,6 +652,15 @@ def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> S
629652

630653
def func(df: SparkLikeLazyFrame) -> list[Column]:
631654
return [fn(WindowInputs(expr, partition, order_by)) for expr in self(df)]
655+
elif (fn_unorderable := self._unorderable_window_function) is not None:
656+
assert order_by is None # noqa: S101
657+
658+
def func(df: SparkLikeLazyFrame) -> list[Column]:
659+
assert self._previous_call is not None # noqa: S101
660+
return [
661+
fn_unorderable(UnorderableWindowInputs(expr, partition))
662+
for expr in self._previous_call(df)
663+
]
632664
else:
633665

634666
def func(df: SparkLikeLazyFrame) -> list[Column]:
@@ -786,10 +818,19 @@ def rolling_std(
786818
def rank(self, method: RankMethod, *, descending: bool) -> Self:
787819
func_name = self._REMAP_RANK_METHOD[method]
788820

789-
def _rank(_input: Column) -> Column:
821+
def _rank(
822+
_input: Column,
823+
*,
824+
descending: bool,
825+
partition_by: Sequence[str | Column] | None = None,
826+
) -> Column:
790827
order_by = self._sort(_input, descending=descending, nulls_last=True)
791-
window = self.partition_by().orderBy(*order_by)
792-
count_window = self.partition_by(_input)
828+
if partition_by is not None:
829+
window = self.partition_by(*partition_by).orderBy(*order_by)
830+
count_window = self.partition_by(*partition_by, _input)
831+
else:
832+
window = self.partition_by().orderBy(*order_by)
833+
count_window = self.partition_by(_input)
793834
if method == "max":
794835
expr = (
795836
getattr(self._F, func_name)().over(window)
@@ -807,7 +848,19 @@ def _rank(_input: Column) -> Column:
807848

808849
return self._F.when(_input.isNotNull(), expr)
809850

810-
return self._with_callable(_rank)
851+
def _unpartitioned_rank(_input: Column) -> Column:
852+
return _rank(_input, descending=descending)
853+
854+
def _partitioned_rank(window_inputs: UnorderableWindowInputs) -> Column:
855+
return _rank(
856+
window_inputs.expr,
857+
descending=descending,
858+
partition_by=window_inputs.partition_by,
859+
)
860+
861+
return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function(
862+
_partitioned_rank, self._call
863+
)
811864

812865
def log(self, base: float) -> Self:
813866
def _log(_input: Column) -> Column:

narwhals/_spark_like/typing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
if TYPE_CHECKING:
77
from sqlframe.base.column import Column
88

9+
from narwhals._spark_like.utils import UnorderableWindowInputs
910
from narwhals._spark_like.utils import WindowInputs
1011

1112
class WindowFunction(Protocol):
1213
def __call__(self, window_inputs: WindowInputs, /) -> Column: ...
14+
15+
class UnorderableWindowFunction(Protocol):
    """Callable applying a window function that has no `order_by` clause.

    `window_inputs` is positional-only, so conforming implementations may
    name the parameter however they like.
    """

    def __call__(self, window_inputs: UnorderableWindowInputs, /) -> Column: ...

narwhals/_spark_like/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,26 @@ class WindowInputs:
4343
def __init__(
4444
self,
4545
expr: Column,
46-
partition_by: Sequence[str] | Sequence[Column],
46+
partition_by: Sequence[str | Column],
4747
order_by: Sequence[str],
4848
) -> None:
4949
self.expr = expr
5050
self.partition_by = partition_by
5151
self.order_by = order_by
5252

5353

54+
class UnorderableWindowInputs:
    """Inputs for a window function that takes no `order_by` clause."""

    __slots__ = ("expr", "partition_by")

    def __init__(self, expr: Column, partition_by: Sequence[str | Column]) -> None:
        # Expression the window function is applied to.
        self.expr = expr
        # Column names or native columns defining the window partitions.
        self.partition_by = partition_by
64+
65+
5466
# NOTE: don't lru_cache this as `ModuleType` isn't hashable
5567
def native_to_narwhals_dtype( # noqa: C901, PLR0912
5668
dtype: _NativeDType, version: Version, spark_types: ModuleType

0 commit comments

Comments (0)