
Commit 702eea5

fix: support various reductions in pyspark (#1870)
1 parent 267eb53 commit 702eea5

File tree

9 files changed: +142 -174 lines changed


narwhals/_spark_like/dataframe.py (+40 -13)
@@ -6,8 +6,12 @@
 from typing import Literal
 from typing import Sequence

+from pyspark.sql import Window
+from pyspark.sql import functions as F  # noqa: N812
+
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
+from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
 from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@
     from narwhals._spark_like.namespace import SparkLikeNamespace
     from narwhals.dtypes import DType
     from narwhals.utils import Version
-    from narwhals.typing import CompliantLazyFrame


 class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +97,9 @@ def select(
         *exprs: SparkLikeExpr,
         **named_exprs: SparkLikeExpr,
     ) -> Self:
-        new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
+        )

         if not new_columns:
             # return empty dataframe, like Polars does
@@ -105,8 +110,38 @@ def select(

             return self._from_native_frame(spark_df)

-        new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        if all(returns_scalar):
+            new_columns_list = [
+                col.alias(col_name) for col_name, col in new_columns.items()
+            ]
+            return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        else:
+            new_columns_list = [
+                col.over(Window.partitionBy(F.lit(1))).alias(col_name)
+                if _returns_scalar
+                else col.alias(col_name)
+                for (col_name, col), _returns_scalar in zip(
+                    new_columns.items(), returns_scalar
+                )
+            ]
+            return self._from_native_frame(self._native_frame.select(*new_columns_list))
+
+    def with_columns(
+        self: Self,
+        *exprs: SparkLikeExpr,
+        **named_exprs: SparkLikeExpr,
+    ) -> Self:
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
+        )
+
+        new_columns_map = {
+            col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
+            for (col_name, col), _returns_scalar in zip(
+                new_columns.items(), returns_scalar
+            )
+        }
+        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

     def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
         plx = self.__narwhals_namespace__()
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
     def collect_schema(self: Self) -> dict[str, DType]:
         return self.schema

-    def with_columns(
-        self: Self,
-        *exprs: SparkLikeExpr,
-        **named_exprs: SparkLikeExpr,
-    ) -> Self:
-        new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
-        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))
-
     def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
@@ -155,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroup
         from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

         return SparkLikeLazyGroupBy(
-            df=self, keys=list(keys), drop_null_keys=drop_null_keys
+            compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
         )

     def sort(
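
Note on the select / with_columns changes above: when every parsed expression is a reduction, the whole frame is collapsed with DataFrame.agg; when reductions are mixed with row-wise expressions, each scalar result is broadcast back onto every row through an all-encompassing window. The following is a minimal standalone PySpark sketch of that broadcasting trick (illustrative only, with made-up column names; it is not code from this commit):

    # Reductions alone collapse to a single row, as in the `all(returns_scalar)` branch.
    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 4.0), (2, 6.0), (3, 8.0)], ["a", "b"])
    df.agg(F.mean("b").alias("b_mean")).show()

    # Mixed with a row-wise column, the scalar is broadcast over a constant
    # partition so it lines up with every row, as in the `else` branch.
    df.select(
        F.col("a"),
        F.mean("b").over(Window.partitionBy(F.lit(1))).alias("b_mean"),
    ).show()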

narwhals/_spark_like/expr.py (+4 -4)
@@ -122,7 +122,7 @@ def _from_call(
         def func(df: SparkLikeLazyFrame) -> list[Column]:
             native_series_list = self._call(df)
             other_native_series = {
-                key: maybe_evaluate(df, value)
+                key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
                 for key, value in expressifiable_args.items()
             }
             return [
@@ -136,7 +136,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             function_name=f"{self._function_name}->{expr_name}",
             evaluate_output_names=self._evaluate_output_names,
             alias_output_names=self._alias_output_names,
-            returns_scalar=self._returns_scalar or returns_scalar,
+            returns_scalar=returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
         )
@@ -349,7 +349,7 @@ def std(self: Self, ddof: int) -> Self:

         func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

-        return self._from_call(func, f"std[{ddof}]", returns_scalar=True)
+        return self._from_call(func, "std", returns_scalar=True)

     def var(self: Self, ddof: int) -> Self:
         from functools import partial
@@ -360,7 +360,7 @@ def var(self: Self, ddof: int) -> Self:

         func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

-        return self._from_call(func, f"var[{ddof}]", returns_scalar=True)
+        return self._from_call(func, "var", returns_scalar=True)

     def clip(
         self: Self,
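
Note on the expr.py changes above: `std` and `var` no longer encode the ddof in the function name (e.g. "std[2]") because the rewritten group-by evaluates the expression directly instead of re-parsing the name, and `maybe_evaluate` now receives the overall `returns_scalar` flag so scalar sub-expressions are broadcast exactly once. For orientation only, here is a hedged sketch of what a ddof-aware standard deviation can look like in plain PySpark; it rescales the built-in sample statistic and is an illustration, not the narwhals `_std` helper itself:

    from pyspark.sql import Column
    from pyspark.sql import functions as F

    def std_with_ddof(col: Column, ddof: int) -> Column:
        # var_ddof = var_samp * (n - 1) / (n - ddof), so the std scales by the sqrt.
        n = F.count(col)
        return F.stddev_samp(col) * F.sqrt((n - F.lit(1)) / (n - F.lit(ddof)))

    # e.g. df.agg(std_with_ddof(F.col("b"), ddof=0).alias("b_std"))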

narwhals/_spark_like/group_by.py (+34 -126)
@@ -1,150 +1,58 @@
 from __future__ import annotations

-import re
-from functools import partial
 from typing import TYPE_CHECKING
-from typing import Any
-from typing import Callable
-from typing import Sequence
-
-from pyspark.sql import functions as F  # noqa: N812
-
-from narwhals._expression_parsing import is_simple_aggregation
-from narwhals._spark_like.utils import _std
-from narwhals._spark_like.utils import _var
-from narwhals.utils import parse_version

 if TYPE_CHECKING:
-    from pyspark.sql import Column
-    from pyspark.sql import GroupedData
     from typing_extensions import Self

     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-    from narwhals._spark_like.typing import SparkLikeExpr
-    from narwhals.typing import CompliantExpr
+    from narwhals._spark_like.expr import SparkLikeExpr


 class SparkLikeLazyGroupBy:
     def __init__(
         self: Self,
-        df: SparkLikeLazyFrame,
+        compliant_frame: SparkLikeLazyFrame,
         keys: list[str],
         drop_null_keys: bool,  # noqa: FBT001
     ) -> None:
-        self._df = df
-        self._keys = keys
         if drop_null_keys:
-            self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
-                *self._keys
-            )
+            self._compliant_frame = compliant_frame.drop_nulls(subset=None)
         else:
-            self._grouped = self._df._native_frame.groupBy(*self._keys)
-
-    def agg(
-        self: Self,
-        *exprs: SparkLikeExpr,
-    ) -> SparkLikeLazyFrame:
-        return agg_pyspark(
-            self._df,
-            self._grouped,
-            exprs,
-            self._keys,
-            self._from_native_frame,
-        )
-
-    def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
-        from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-
-        return SparkLikeLazyFrame(
-            df, backend_version=self._df._backend_version, version=self._df._version
-        )
-
-
-def get_spark_function(function_name: str) -> Column:
-    if (stem := function_name.split("[", maxsplit=1)[0]) in ("std", "var"):
-        import numpy as np  # ignore-banned-import
-
-        return partial(
-            _std if stem == "std" else _var,
-            ddof=int(function_name.split("[", maxsplit=1)[1].rstrip("]")),
-            np_version=parse_version(np.__version__),
-        )
-
-    elif function_name == "len":
-        # Use count(*) to count all rows including nulls
-        def _count(*_args: Any, **_kwargs: Any) -> Column:
-            return F.count("*")
-
-        return _count
-
-    elif function_name == "n_unique":
-        from pyspark.sql.types import IntegerType
-
-        def _n_unique(_input: Column) -> Column:
-            return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
-
-        return _n_unique
-
-    else:
-        return getattr(F, function_name)
-
-
-def agg_pyspark(
-    df: SparkLikeLazyFrame,
-    grouped: GroupedData,
-    exprs: Sequence[CompliantExpr[Column]],
-    keys: list[str],
-    from_dataframe: Callable[[Any], SparkLikeLazyFrame],
-) -> SparkLikeLazyFrame:
-    if not exprs:
-        # No aggregation provided
-        return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
+            self._compliant_frame = compliant_frame
+        self._keys = keys

-    for expr in exprs:
-        if not is_simple_aggregation(expr):  # pragma: no cover
-            msg = (
-                "Non-trivial complex aggregation found.\n\n"
-                "Hint: you were probably trying to apply a non-elementary aggregation with a "
-                "dask dataframe.\n"
-                "Please rewrite your query such that group-by aggregations "
-                "are elementary. For example, instead of:\n\n"
-                "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
-                "use:\n\n"
-                "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
+    def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
+        agg_columns = []
+        df = self._compliant_frame
+        for expr in exprs:
+            output_names = expr._evaluate_output_names(df)
+            aliases = (
+                output_names
+                if expr._alias_output_names is None
+                else expr._alias_output_names(output_names)
             )
-            raise ValueError(msg)
-
-    simple_aggregations: dict[str, Column] = {}
-    for expr in exprs:
-        output_names = expr._evaluate_output_names(df)
-        aliases = (
-            output_names
-            if expr._alias_output_names is None
-            else expr._alias_output_names(output_names)
-        )
-        if len(output_names) > 1:
-            # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
-            # the keys, else they would appear duplicated in the output.
-            output_names, aliases = zip(
-                *[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
+            native_expressions = expr(df)
+            exclude = (
+                self._keys
+                if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
+                else []
+            )
+            agg_columns.extend(
+                [
+                    native_expression.alias(alias)
+                    for native_expression, output_name, alias in zip(
+                        native_expressions, output_names, aliases
+                    )
+                    if output_name not in exclude
+                ]
             )
-        if expr._depth == 0:  # pragma: no cover
-            # e.g. agg(nw.len())  # noqa: ERA001
-            agg_func = get_spark_function(expr._function_name)
-            simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
-            continue

-        # e.g. agg(nw.mean('a'))  # noqa: ERA001
-        function_name = re.sub(r"(\w+->)", "", expr._function_name)
-        agg_func = get_spark_function(function_name)
+        if not agg_columns:
+            return self._compliant_frame._from_native_frame(
+                self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
+            )

-        simple_aggregations.update(
-            {
-                alias: agg_func(output_name)
-                for alias, output_name in zip(aliases, output_names)
-            }
+        return self._compliant_frame._from_native_frame(
+            self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
         )
-
-    agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
-    result_simple = grouped.agg(*agg_columns)
-    return from_dataframe(result_simple)
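
Note on the group_by.py rewrite above: instead of mapping expression names onto PySpark functions (get_spark_function) and restricting agg to "simple" aggregations, the new SparkLikeLazyGroupBy.agg evaluates each narwhals expression against the frame and passes the resulting columns straight to groupBy(...).agg(...). A usage sketch of the kind of query this enables, assuming the narwhals public API at the time of this commit (data and names are illustrative):

    import narwhals as nw
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    native = spark.createDataFrame([("x", 1.0), ("x", 3.0), ("y", 2.0)], ["a", "b"])

    result = (
        nw.from_native(native)
        .group_by("a")
        .agg(
            nw.col("b").mean().alias("b_mean"),
            nw.col("b").std(ddof=1).alias("b_std"),
            nw.col("b").n_unique().alias("b_n_unique"),
        )
        .to_native()
    )
    result.show()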

narwhals/_spark_like/utils.py (+32 -7)
@@ -5,9 +5,10 @@
 from typing import Any
 from typing import Callable

+from pyspark.sql import Column
+from pyspark.sql import Window
 from pyspark.sql import functions as F  # noqa: N812
 from pyspark.sql import types as pyspark_types
-from pyspark.sql.window import Window

 from narwhals.exceptions import UnsupportedDTypeError
 from narwhals.utils import import_dtypes_module
@@ -109,9 +110,16 @@

 def parse_exprs_and_named_exprs(
     df: SparkLikeLazyFrame,
-) -> Callable[..., dict[str, Column]]:
-    def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]:
+) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
+    def func(
+        *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
+    ) -> tuple[dict[str, Column], list[bool]]:
         native_results: dict[str, list[Column]] = {}
+
+        # `returns_scalar` keeps track if an expression returns a scalar and is not lit.
+        # Notice that lit is quite special case, since it gets broadcasted by pyspark
+        # without the need of adding `.over(Window.partitionBy(F.lit(1)))`
+        returns_scalar: list[bool] = []
         for expr in exprs:
             native_series_list = expr._call(df)
             output_names = expr._evaluate_output_names(df)
@@ -121,18 +129,30 @@ def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Colum
                 msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
                 raise AssertionError(msg)
             native_results.update(zip(output_names, native_series_list))
+            returns_scalar.extend(
+                [
+                    expr._returns_scalar
+                    and expr._function_name.split("->", maxsplit=1)[0] != "lit"
+                ]
+                * len(output_names)
+            )
         for col_alias, expr in named_exprs.items():
             native_series_list = expr._call(df)
             if len(native_series_list) != 1:  # pragma: no cover
                 msg = "Named expressions must return a single column"
                 raise ValueError(msg)
             native_results[col_alias] = native_series_list[0]
-        return native_results
+            returns_scalar.append(
+                expr._returns_scalar
+                and expr._function_name.split("->", maxsplit=1)[0] != "lit"
+            )
+
+        return native_results, returns_scalar

     return func


-def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
+def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
     from narwhals._spark_like.expr import SparkLikeExpr

     if isinstance(obj, SparkLikeExpr):
@@ -141,8 +161,13 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
             msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
             raise NotImplementedError(msg)
         column_result = column_results[0]
-        if obj._returns_scalar:
-            # Return scalar, let PySpark do its broadcasting
+        if (
+            obj._returns_scalar
+            and obj._function_name.split("->", maxsplit=1)[0] != "lit"
+            and not returns_scalar
+        ):
+            # Returns scalar, but overall expression doesn't.
+            # Let PySpark do its broadcasting
             return column_result.over(Window.partitionBy(F.lit(1)))
         return column_result
     return F.lit(obj)
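
Note on the utils.py changes above: parse_exprs_and_named_exprs now also reports, per output column, whether the expression is a reduction (with `lit` excluded, since PySpark broadcasts literals on its own), and maybe_evaluate only adds the window broadcast when a scalar sub-expression feeds into an expression that is itself row-wise. A small illustrative PySpark sketch of the three situations (not code from this commit):

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["b"])

    # Row-wise expression with a scalar operand: the reduction needs the
    # constant-partition window before it can be subtracted row by row.
    df.select(
        (F.col("b") - F.mean("b").over(Window.partitionBy(F.lit(1)))).alias("centered")
    ).show()

    # Expression that is itself a reduction: no window needed, it goes through agg.
    df.agg((F.mean("b") - F.min("b")).alias("spread")).show()

    # Literals broadcast natively, hence the special-casing of `lit`.
    df.select((F.col("b") + F.lit(1)).alias("b_plus_one")).show()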
