Skip to content

Commit 75c3e7c

Browse files
authored
perf: streamline pyarrow and pandas group by agg (#1621)
1 parent 12078a9 commit 75c3e7c

File tree

2 files changed

+25
-33
lines changed

2 files changed

+25
-33
lines changed

narwhals/_arrow/group_by.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,10 @@ def agg_arrow(
150150
)
151151
raise ValueError(msg)
152152

153-
# Mapping from output name to
154-
# (aggregation_args, pyarrow_output_name) # noqa: ERA001
155-
simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
153+
aggs: list[tuple[str, str, pc.FunctionOptions | None]] = []
154+
expected_pyarrow_column_names: list[str] = keys.copy()
155+
new_column_names: list[str] = keys.copy()
156+
156157
for expr in exprs:
157158
if expr._depth == 0:
158159
# e.g. agg(nw.len()) # noqa: ERA001
@@ -161,10 +162,11 @@ def agg_arrow(
161162
): # pragma: no cover
162163
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
163164
raise AssertionError(msg)
164-
simple_aggregations[expr._output_names[0]] = (
165-
(keys[0], "count", pc.CountOptions(mode="all")),
166-
f"{keys[0]}_count",
167-
)
165+
166+
new_column_names.append(expr._output_names[0])
167+
expected_pyarrow_column_names.append(f"{keys[0]}_count")
168+
aggs.append((keys[0], "count", pc.CountOptions(mode="all")))
169+
168170
continue
169171

170172
# e.g. agg(nw.mean('a')) # noqa: ERA001
@@ -179,22 +181,13 @@ def agg_arrow(
179181
function_name, (function_name, None)
180182
)
181183

182-
for root_name, output_name in zip(expr._root_names, expr._output_names):
183-
simple_aggregations[output_name] = (
184-
(root_name, function_name, option),
185-
f"{root_name}_{function_name}",
186-
)
187-
188-
aggs: list[Any] = []
189-
expected_pyarrow_column_names = keys.copy()
190-
new_column_names = keys.copy()
191-
for output_name, (
192-
aggregation_args,
193-
pyarrow_output_name,
194-
) in simple_aggregations.items():
195-
aggs.append(aggregation_args)
196-
expected_pyarrow_column_names.append(pyarrow_output_name)
197-
new_column_names.append(output_name)
184+
new_column_names.extend(expr._output_names)
185+
expected_pyarrow_column_names.extend(
186+
[f"{root_name}_{function_name}" for root_name in expr._root_names]
187+
)
188+
aggs.extend(
189+
[(root_name, function_name, option) for root_name in expr._root_names]
190+
)
198191

199192
result_simple = grouped.aggregate(aggs)
200193

narwhals/_pandas_like/group_by.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,11 @@ def agg_pandas( # noqa: PLR0915
166166
# We need to do this separately from the rest so that we
167167
# can pass the `dropna` kwargs.
168168
nunique_aggs: dict[str, str] = {}
169+
simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
170+
expected_old_names: list[str] = []
171+
new_names: list[str] = []
169172

170173
if all_aggs_are_simple:
171-
simple_aggregations: dict[str, tuple[str, str]] = {}
172174
for expr in exprs:
173175
if expr._depth == 0:
174176
# e.g. agg(nw.len()) # noqa: ERA001
@@ -180,7 +182,9 @@ def agg_pandas( # noqa: PLR0915
180182
expr._function_name, expr._function_name
181183
)
182184
for output_name in expr._output_names:
183-
simple_aggregations[output_name] = (keys[0], function_name)
185+
new_names.append(output_name)
186+
expected_old_names.append(f"{keys[0]}_{function_name}")
187+
simple_aggs[keys[0]].append(function_name)
184188
continue
185189

186190
# e.g. agg(nw.mean('a')) # noqa: ERA001
@@ -199,15 +203,10 @@ def agg_pandas( # noqa: PLR0915
199203
if is_n_unique:
200204
nunique_aggs[output_name] = root_name
201205
else:
202-
simple_aggregations[output_name] = (root_name, function_name)
206+
new_names.append(output_name)
207+
expected_old_names.append(f"{root_name}_{function_name}")
208+
simple_aggs[root_name].append(function_name)
203209

204-
simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
205-
expected_old_names: list[str] = []
206-
new_names: list[str] = []
207-
for output_name, (col_name, function) in simple_aggregations.items():
208-
simple_aggs[col_name].append(function)
209-
new_names.append(output_name)
210-
expected_old_names.append(f"{col_name}_{function}")
211210
if simple_aggs:
212211
result_simple_aggs = grouped.agg(simple_aggs)
213212
result_simple_aggs.columns = [

0 commit comments

Comments
 (0)