Skip to content

feat: modify rank so that it works with over() for lazy backends #2533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 73 additions & 14 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
from narwhals._duckdb.utils import UnorderableWindowInputs
from narwhals._duckdb.utils import WindowInputs
from narwhals._duckdb.utils import col
from narwhals._duckdb.utils import ensure_type
Expand All @@ -41,6 +42,7 @@
from narwhals._compliant.typing import EvalSeries
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.typing import UnorderableWindowFunction
from narwhals._duckdb.typing import WindowFunction
from narwhals._expression_parsing import ExprMetadata
from narwhals.dtypes import DType
Expand Down Expand Up @@ -79,6 +81,10 @@ def __init__(
# This can only be set by `_with_window_function`.
self._window_function: WindowFunction | None = None

# These can only be set by `_with_unorderable_window_function`
self._unorderable_window_function: UnorderableWindowFunction | None = None
self._previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression] | None = None

def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
return self._call(df)

Expand Down Expand Up @@ -263,6 +269,22 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
result._window_function = window_function
return result

def _with_unorderable_window_function(
    self,
    unorderable_window_function: UnorderableWindowFunction,
    previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression],
) -> Self:
    """Return a copy of this expression tagged with an unorderable window function.

    The copy keeps the current call and output-name plumbing; `over()`
    later detects `_unorderable_window_function` and evaluates
    `previous_call` to obtain the expressions the window function wraps.
    """
    tagged = self.__class__(
        self._call,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        backend_version=self._backend_version,
        version=self._version,
    )
    tagged._previous_call = previous_call
    tagged._unorderable_window_function = unorderable_window_function
    return tagged

@classmethod
def _alias_native(cls, expr: duckdb.Expression, name: str) -> duckdb.Expression:
return expr.alias(name)
Expand Down Expand Up @@ -495,6 +517,19 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
window_function(WindowInputs(expr, partition_by, order_by))
for expr in self._call(df)
]
elif (
unorderable_window_function := self._unorderable_window_function
) is not None:
assert order_by is None # noqa: S101

def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
assert self._previous_call is not None # noqa: S101
return [
unorderable_window_function(
UnorderableWindowInputs(expr, partition_by)
)
for expr in self._previous_call(df)
]
else:
partition_by_sql = generate_partition_by_sql(*partition_by)
template = f"{{expr}} over ({partition_by_sql})"
Expand Down Expand Up @@ -728,30 +763,54 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
else: # method == "ordinal"
func = FunctionExpression("row_number")

def _rank(_input: duckdb.Expression) -> duckdb.Expression:
if descending:
by_sql = f"{_input} desc nulls last"
else:
by_sql = f"{_input} asc nulls last"
order_by_sql = f"order by {by_sql}"
def _rank(
_input: duckdb.Expression,
*,
descending: bool,
partition_by: Sequence[str | duckdb.Expression] | None = None,
) -> duckdb.Expression:
order_by_sql = (
f"order by {_input} desc nulls last"
if descending
else f"order by {_input} asc nulls last"
)
count_expr = FunctionExpression("count", StarExpression())

if partition_by is not None:
window = f"{generate_partition_by_sql(*partition_by)} {order_by_sql}"
count_window = f"{generate_partition_by_sql(*partition_by, _input)}"
else:
window = order_by_sql
count_window = generate_partition_by_sql(_input)
if method == "max":
expr = (
SQLExpression(f"{func} OVER ({order_by_sql})")
+ SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})")
SQLExpression(f"{func} OVER ({window})")
+ SQLExpression(f"{count_expr} over ({count_window})")
- lit(1)
)
elif method == "average":
expr = SQLExpression(f"{func} OVER ({order_by_sql})") + (
SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") - lit(1)
expr = SQLExpression(f"{func} OVER ({window})") + (
SQLExpression(f"{count_expr} over ({count_window})") - lit(1)
) / lit(2.0)
else:
expr = SQLExpression(f"{func} OVER ({order_by_sql})")

expr = SQLExpression(f"{func} OVER ({window})")
return when(_input.isnotnull(), expr)

return self._with_callable(_rank)
def _unpartitioned_rank(_input: duckdb.Expression) -> duckdb.Expression:
return _rank(_input, descending=descending)

def _partitioned_rank(
window_inputs: UnorderableWindowInputs,
) -> duckdb.Expression:
return _rank(
window_inputs.expr,
descending=descending,
partition_by=window_inputs.partition_by,
)

return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function(
_partitioned_rank,
self._call,
)

def log(self, base: float) -> Self:
def _log(_input: duckdb.Expression) -> duckdb.Expression:
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_duckdb/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
if TYPE_CHECKING:
import duckdb

from narwhals._duckdb.utils import UnorderableWindowInputs
from narwhals._duckdb.utils import WindowInputs

class WindowFunction(Protocol):
    """Callable applying an expression over an ordered window (`WindowInputs`)."""

    def __call__(self, window_inputs: WindowInputs) -> duckdb.Expression: ...

class UnorderableWindowFunction(Protocol):
    """Callable applying an expression over a window with no `order_by` clause."""

    def __call__(
        self, window_inputs: UnorderableWindowInputs
    ) -> duckdb.Expression: ...
16 changes: 14 additions & 2 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def __init__(
self.order_by = order_by


class UnorderableWindowInputs:
    """Inputs for a window function that takes no `order_by` columns.

    Attributes:
        expr: Expression the window function is applied to.
        partition_by: Partition keys — column names or ready-made DuckDB
            expressions. Widened from `Sequence[str]` for consistency with
            the Spark counterpart and with `generate_partition_by_sql`,
            which already accepts `str | duckdb.Expression`.
    """

    __slots__ = ("expr", "partition_by")

    def __init__(
        self,
        expr: duckdb.Expression,
        partition_by: Sequence[str | duckdb.Expression],
    ) -> None:
        self.expr = expr
        self.partition_by = partition_by


def concat_str(*exprs: duckdb.Expression, separator: str = "") -> duckdb.Expression:
"""Concatenate many strings, NULL inputs are skipped.

Expand Down Expand Up @@ -244,10 +256,10 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
raise AssertionError(msg)


def generate_partition_by_sql(*partition_by: str | duckdb.Expression) -> str:
    """Build a `partition by ...` SQL fragment from partition keys.

    Args:
        partition_by: Column names (wrapped via `col` for identifier
            handling) or ready-made DuckDB expressions (interpolated as-is).

    Returns:
        The SQL fragment, or an empty string when no keys are given.
    """
    if not partition_by:
        return ""
    by_sql = ", ".join(f"{col(x) if isinstance(x, str) else x}" for x in partition_by)
    return f"partition by {by_sql}"


Expand Down
61 changes: 57 additions & 4 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import UnorderableWindowInputs
from narwhals._spark_like.utils import WindowInputs
from narwhals._spark_like.utils import import_functions
from narwhals._spark_like.utils import import_native_dtypes
Expand All @@ -42,6 +43,7 @@
from narwhals._expression_parsing import ExprMetadata
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._spark_like.typing import UnorderableWindowFunction
from narwhals._spark_like.typing import WindowFunction
from narwhals.dtypes import DType
from narwhals.typing import FillNullStrategy
Expand Down Expand Up @@ -96,6 +98,10 @@ def __init__(
# This can only be set by `_with_window_function`.
self._window_function: WindowFunction | None = None

# These can only be set by `_with_unorderable_window_function`
self._unorderable_window_function: UnorderableWindowFunction | None = None
self._previous_call: EvalSeries[SparkLikeLazyFrame, Column] | None = None

def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]:
return self._call(df)

Expand Down Expand Up @@ -210,6 +216,23 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
result._window_function = window_function
return result

def _with_unorderable_window_function(
    self,
    unorderable_window_function: UnorderableWindowFunction,
    previous_call: EvalSeries[SparkLikeLazyFrame, Column],
) -> Self:
    """Return a copy of this expression tagged with an unorderable window function.

    `over()` checks `_unorderable_window_function` and, when it is set,
    evaluates `previous_call` to obtain the columns the window function
    is applied to.
    """
    tagged = self.__class__(
        self._call,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        backend_version=self._backend_version,
        version=self._version,
        implementation=self._implementation,
    )
    tagged._previous_call = previous_call
    tagged._unorderable_window_function = unorderable_window_function
    return tagged

@classmethod
def _alias_native(cls, expr: Column, name: str) -> Column:
return expr.alias(name)
Expand Down Expand Up @@ -629,6 +652,15 @@ def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> S

def func(df: SparkLikeLazyFrame) -> list[Column]:
return [fn(WindowInputs(expr, partition, order_by)) for expr in self(df)]
elif (fn_unorderable := self._unorderable_window_function) is not None:
assert order_by is None # noqa: S101

def func(df: SparkLikeLazyFrame) -> list[Column]:
assert self._previous_call is not None # noqa: S101
return [
fn_unorderable(UnorderableWindowInputs(expr, partition))
for expr in self._previous_call(df)
]
else:

def func(df: SparkLikeLazyFrame) -> list[Column]:
Expand Down Expand Up @@ -786,10 +818,19 @@ def rolling_std(
def rank(self, method: RankMethod, *, descending: bool) -> Self:
func_name = self._REMAP_RANK_METHOD[method]

def _rank(_input: Column) -> Column:
def _rank(
_input: Column,
*,
descending: bool,
partition_by: Sequence[str | Column] | None = None,
) -> Column:
order_by = self._sort(_input, descending=descending, nulls_last=True)
window = self.partition_by().orderBy(*order_by)
count_window = self.partition_by(_input)
if partition_by is not None:
window = self.partition_by(*partition_by).orderBy(*order_by)
count_window = self.partition_by(*partition_by, _input)
else:
window = self.partition_by().orderBy(*order_by)
count_window = self.partition_by(_input)
if method == "max":
expr = (
getattr(self._F, func_name)().over(window)
Expand All @@ -807,7 +848,19 @@ def _rank(_input: Column) -> Column:

return self._F.when(_input.isNotNull(), expr)

return self._with_callable(_rank)
def _unpartitioned_rank(_input: Column) -> Column:
return _rank(_input, descending=descending)

def _partitioned_rank(window_inputs: UnorderableWindowInputs) -> Column:
return _rank(
window_inputs.expr,
descending=descending,
partition_by=window_inputs.partition_by,
)

return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function(
_partitioned_rank, self._call
)

def log(self, base: float) -> Self:
def _log(_input: Column) -> Column:
Expand Down
4 changes: 4 additions & 0 deletions narwhals/_spark_like/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
if TYPE_CHECKING:
from sqlframe.base.column import Column

from narwhals._spark_like.utils import UnorderableWindowInputs
from narwhals._spark_like.utils import WindowInputs

class WindowFunction(Protocol):
    """Callable applying a column over an ordered window (`WindowInputs`)."""

    def __call__(self, window_inputs: WindowInputs, /) -> Column: ...

class UnorderableWindowFunction(Protocol):
    """Callable applying a column over a window with no `order_by` clause."""

    def __call__(self, window_inputs: UnorderableWindowInputs, /) -> Column: ...
14 changes: 13 additions & 1 deletion narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,26 @@ class WindowInputs:
def __init__(
self,
expr: Column,
partition_by: Sequence[str] | Sequence[Column],
partition_by: Sequence[str | Column],
order_by: Sequence[str],
) -> None:
self.expr = expr
self.partition_by = partition_by
self.order_by = order_by


class UnorderableWindowInputs:
    """Inputs for a window function that takes no `order_by` columns."""

    __slots__ = ("expr", "partition_by")

    def __init__(self, expr: Column, partition_by: Sequence[str | Column]) -> None:
        self.expr, self.partition_by = expr, partition_by


# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype( # noqa: C901, PLR0912
dtype: _NativeDType, version: Version, spark_types: ModuleType
Expand Down
Loading
Loading