@@ -150,9 +150,10 @@ def agg_arrow(
150
150
)
151
151
raise ValueError (msg )
152
152
153
- # Mapping from output name to
154
- # (aggregation_args, pyarrow_output_name) # noqa: ERA001
155
- simple_aggregations : dict [str , tuple [tuple [Any , ...], str ]] = {}
153
+ aggs : list [tuple [str , str , pc .FunctionOptions | None ]] = []
154
+ expected_pyarrow_column_names : list [str ] = keys .copy ()
155
+ new_column_names : list [str ] = keys .copy ()
156
+
156
157
for expr in exprs :
157
158
if expr ._depth == 0 :
158
159
# e.g. agg(nw.len()) # noqa: ERA001
@@ -161,10 +162,11 @@ def agg_arrow(
161
162
): # pragma: no cover
162
163
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
163
164
raise AssertionError (msg )
164
- simple_aggregations [expr ._output_names [0 ]] = (
165
- (keys [0 ], "count" , pc .CountOptions (mode = "all" )),
166
- f"{ keys [0 ]} _count" ,
167
- )
165
+
166
+ new_column_names .append (expr ._output_names [0 ])
167
+ expected_pyarrow_column_names .append (f"{ keys [0 ]} _count" )
168
+ aggs .append ((keys [0 ], "count" , pc .CountOptions (mode = "all" )))
169
+
168
170
continue
169
171
170
172
# e.g. agg(nw.mean('a')) # noqa: ERA001
@@ -179,22 +181,13 @@ def agg_arrow(
179
181
function_name , (function_name , None )
180
182
)
181
183
182
- for root_name , output_name in zip (expr ._root_names , expr ._output_names ):
183
- simple_aggregations [output_name ] = (
184
- (root_name , function_name , option ),
185
- f"{ root_name } _{ function_name } " ,
186
- )
187
-
188
- aggs : list [Any ] = []
189
- expected_pyarrow_column_names = keys .copy ()
190
- new_column_names = keys .copy ()
191
- for output_name , (
192
- aggregation_args ,
193
- pyarrow_output_name ,
194
- ) in simple_aggregations .items ():
195
- aggs .append (aggregation_args )
196
- expected_pyarrow_column_names .append (pyarrow_output_name )
197
- new_column_names .append (output_name )
184
+ new_column_names .extend (expr ._output_names )
185
+ expected_pyarrow_column_names .extend (
186
+ [f"{ root_name } _{ function_name } " for root_name in expr ._root_names ]
187
+ )
188
+ aggs .extend (
189
+ [(root_name , function_name , option ) for root_name in expr ._root_names ]
190
+ )
198
191
199
192
result_simple = grouped .aggregate (aggs )
200
193
0 commit comments