@@ -1,150 +1,58 @@
 from __future__ import annotations
 
-import re
-from functools import partial
 from typing import TYPE_CHECKING
-from typing import Any
-from typing import Callable
-from typing import Sequence
-
-from pyspark.sql import functions as F  # noqa: N812
-
-from narwhals._expression_parsing import is_simple_aggregation
-from narwhals._spark_like.utils import _std
-from narwhals._spark_like.utils import _var
-from narwhals.utils import parse_version
 
 if TYPE_CHECKING:
-    from pyspark.sql import Column
-    from pyspark.sql import GroupedData
     from typing_extensions import Self
 
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-    from narwhals._spark_like.typing import SparkLikeExpr
-    from narwhals.typing import CompliantExpr
+    from narwhals._spark_like.expr import SparkLikeExpr
 
 
 class SparkLikeLazyGroupBy:
     def __init__(
         self: Self,
-        df: SparkLikeLazyFrame,
+        compliant_frame: SparkLikeLazyFrame,
         keys: list[str],
         drop_null_keys: bool,  # noqa: FBT001
     ) -> None:
-        self._df = df
-        self._keys = keys
         if drop_null_keys:
-            self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
-                *self._keys
-            )
+            self._compliant_frame = compliant_frame.drop_nulls(subset=None)
         else:
-            self._grouped = self._df._native_frame.groupBy(*self._keys)
-
-    def agg(
-        self: Self,
-        *exprs: SparkLikeExpr,
-    ) -> SparkLikeLazyFrame:
-        return agg_pyspark(
-            self._df,
-            self._grouped,
-            exprs,
-            self._keys,
-            self._from_native_frame,
-        )
-
-    def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
-        from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-
-        return SparkLikeLazyFrame(
-            df, backend_version=self._df._backend_version, version=self._df._version
-        )
-
-
-def get_spark_function(function_name: str) -> Column:
-    if (stem := function_name.split("[", maxsplit=1)[0]) in ("std", "var"):
-        import numpy as np  # ignore-banned-import
-
-        return partial(
-            _std if stem == "std" else _var,
-            ddof=int(function_name.split("[", maxsplit=1)[1].rstrip("]")),
-            np_version=parse_version(np.__version__),
-        )
-
-    elif function_name == "len":
-        # Use count(*) to count all rows including nulls
-        def _count(*_args: Any, **_kwargs: Any) -> Column:
-            return F.count("*")
-
-        return _count
-
-    elif function_name == "n_unique":
-        from pyspark.sql.types import IntegerType
-
-        def _n_unique(_input: Column) -> Column:
-            return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
-
-        return _n_unique
-
-    else:
-        return getattr(F, function_name)
-
-
-def agg_pyspark(
-    df: SparkLikeLazyFrame,
-    grouped: GroupedData,
-    exprs: Sequence[CompliantExpr[Column]],
-    keys: list[str],
-    from_dataframe: Callable[[Any], SparkLikeLazyFrame],
-) -> SparkLikeLazyFrame:
-    if not exprs:
-        # No aggregation provided
-        return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
+            self._compliant_frame = compliant_frame
+        self._keys = keys
 
-    for expr in exprs:
-        if not is_simple_aggregation(expr):  # pragma: no cover
-            msg = (
-                "Non-trivial complex aggregation found.\n\n"
-                "Hint: you were probably trying to apply a non-elementary aggregation with a "
-                "dask dataframe.\n"
-                "Please rewrite your query such that group-by aggregations "
-                "are elementary. For example, instead of:\n\n"
-                "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
-                "use:\n\n"
-                "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
+    def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
+        agg_columns = []
+        df = self._compliant_frame
+        for expr in exprs:
+            output_names = expr._evaluate_output_names(df)
+            aliases = (
+                output_names
+                if expr._alias_output_names is None
+                else expr._alias_output_names(output_names)
             )
-            raise ValueError(msg)
-
-    simple_aggregations: dict[str, Column] = {}
-    for expr in exprs:
-        output_names = expr._evaluate_output_names(df)
-        aliases = (
-            output_names
-            if expr._alias_output_names is None
-            else expr._alias_output_names(output_names)
-        )
-        if len(output_names) > 1:
-            # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
-            # the keys, else they would appear duplicated in the output.
-            output_names, aliases = zip(
-                *[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
+            native_expressions = expr(df)
+            exclude = (
+                self._keys
+                if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
+                else []
+            )
+            agg_columns.extend(
+                [
+                    native_expression.alias(alias)
+                    for native_expression, output_name, alias in zip(
+                        native_expressions, output_names, aliases
+                    )
+                    if output_name not in exclude
+                ]
             )
-        if expr._depth == 0:  # pragma: no cover
-            # e.g. agg(nw.len())  # noqa: ERA001
-            agg_func = get_spark_function(expr._function_name)
-            simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
-            continue
 
-        # e.g. agg(nw.mean('a'))  # noqa: ERA001
-        function_name = re.sub(r"(\w+->)", "", expr._function_name)
-        agg_func = get_spark_function(function_name)
+        if not agg_columns:
+            return self._compliant_frame._from_native_frame(
+                self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
+            )
 
-        simple_aggregations.update(
-            {
-                alias: agg_func(output_name)
-                for alias, output_name in zip(aliases, output_names)
-            }
+        return self._compliant_frame._from_native_frame(
+            self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
         )
-
-    agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
-    result_simple = grouped.agg(*agg_columns)
-    return from_dataframe(result_simple)
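Not part of the diff: a minimal usage sketch of the new code path, assuming a narwhals version with PySpark support; the SparkSession setup, column names "a"/"b", and the "b_max" alias are illustrative, not taken from the change.

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame([(1, 4.0), (1, 6.0), (2, 9.0)], ["a", "b"])

# PySpark frames go through narwhals' lazy path, so group_by(...).agg(...) below
# exercises the SparkLikeLazyGroupBy.agg rewritten in this diff.
result = (
    nw.from_native(native)
    .group_by("a")
    .agg(nw.col("b").mean(), nw.col("b").max().alias("b_max"))
    .to_native()
)
result.show()

Design note: instead of mapping function names to PySpark functions via get_spark_function, the rewritten agg evaluates each narwhals expression to native PySpark columns, aliases them, and issues a single groupBy(keys).agg(...). Multi-output expressions such as nw.all().mean() (function names rooted at "all" or "selector") drop the group keys from the aggregation list so they are not duplicated in the output, and an empty aggregation list falls back to select(*keys).dropDuplicates().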