@@ -137,85 +137,85 @@ def agg_arrow(
137
137
all_simple_aggs = False
138
138
break
139
139
140
- if all_simple_aggs :
141
- # Mapping from output name to
142
- # (aggregation_args, pyarrow_output_name) # noqa: ERA001
143
- simple_aggregations : dict [str , tuple [tuple [Any , ...], str ]] = {}
144
- for expr in exprs :
145
- if expr ._depth == 0 :
146
- # e.g. agg(nw.len()) # noqa: ERA001
147
- if (
148
- expr ._output_names is None or expr ._function_name != "len"
149
- ): # pragma: no cover
150
- msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
151
- raise AssertionError (msg )
152
- simple_aggregations [expr ._output_names [0 ]] = (
153
- (keys [0 ], "count" , pc .CountOptions (mode = "all" )),
154
- f"{ keys [0 ]} _count" ,
155
- )
156
- continue
140
+ if not all_simple_aggs :
141
+ msg = (
142
+ "Non-trivial complex aggregation found.\n \n "
143
+ "Hint: you were probably trying to apply a non-elementary aggregation with a "
144
+ "pyarrow table.\n "
145
+ "Please rewrite your query such that group-by aggregations "
146
+ "are elementary. For example, instead of:\n \n "
147
+ " df.group_by('a').agg(nw.col('b').round(2).mean())\n \n "
148
+ "use:\n \n "
149
+ " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n \n "
150
+ )
151
+ raise ValueError (msg )
157
152
158
- # e.g. agg(nw.mean('a')) # noqa: ERA001
153
+ # Mapping from output name to
154
+ # (aggregation_args, pyarrow_output_name) # noqa: ERA001
155
+ simple_aggregations : dict [str , tuple [tuple [Any , ...], str ]] = {}
156
+ for expr in exprs :
157
+ if expr ._depth == 0 :
158
+ # e.g. agg(nw.len()) # noqa: ERA001
159
159
if (
160
- expr ._depth != 1 or expr . _root_names is None or expr ._output_names is None
160
+ expr ._output_names is None or expr ._function_name != "len"
161
161
): # pragma: no cover
162
162
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
163
163
raise AssertionError (msg )
164
-
165
- function_name = remove_prefix (expr ._function_name , "col->" )
166
- function_name , option = polars_to_arrow_aggregations ().get (
167
- function_name , (function_name , None )
164
+ simple_aggregations [expr ._output_names [0 ]] = (
165
+ (keys [0 ], "count" , pc .CountOptions (mode = "all" )),
166
+ f"{ keys [0 ]} _count" ,
168
167
)
168
+ continue
169
169
170
- for root_name , output_name in zip (expr ._root_names , expr ._output_names ):
171
- simple_aggregations [output_name ] = (
172
- (root_name , function_name , option ),
173
- f"{ root_name } _{ function_name } " ,
174
- )
175
-
176
- aggs : list [Any ] = []
177
- expected_pyarrow_column_names = keys .copy ()
178
- new_column_names = keys .copy ()
179
- for output_name , (
180
- aggregation_args ,
181
- pyarrow_output_name ,
182
- ) in simple_aggregations .items ():
183
- aggs .append (aggregation_args )
184
- expected_pyarrow_column_names .append (pyarrow_output_name )
185
- new_column_names .append (output_name )
186
-
187
- result_simple = grouped .aggregate (aggs )
188
-
189
- # Rename columns, being very careful
190
- expected_old_names_indices : dict [str , list [int ]] = collections .defaultdict (list )
191
- for idx , item in enumerate (expected_pyarrow_column_names ):
192
- expected_old_names_indices [item ].append (idx )
193
- if not (
194
- set (result_simple .column_names ) == set (expected_pyarrow_column_names )
195
- and len (result_simple .column_names ) == len (expected_pyarrow_column_names )
170
+ # e.g. agg(nw.mean('a')) # noqa: ERA001
171
+ if (
172
+ expr ._depth != 1 or expr ._root_names is None or expr ._output_names is None
196
173
): # pragma: no cover
197
- msg = (
198
- f"Safety assertion failed, expected { expected_pyarrow_column_names } "
199
- f"got { result_simple .column_names } , "
200
- "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
201
- )
174
+ msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
202
175
raise AssertionError (msg )
203
- index_map : list [int ] = [
204
- expected_old_names_indices [item ].pop (0 ) for item in result_simple .column_names
205
- ]
206
- new_column_names = [new_column_names [i ] for i in index_map ]
207
-
208
- result_simple = result_simple .rename_columns (new_column_names )
209
- return from_dataframe (result_simple )
210
-
211
- msg = (
212
- "Non-trivial complex aggregation found.\n \n "
213
- "Hint: you were probably trying to apply a non-elementary aggregation with a "
214
- "pyarrow table.\n "
215
- "Please rewrite your query such that group-by aggregations "
216
- "are elementary. For example, instead of:\n \n "
217
- " df.group_by('a').agg(nw.col('b').round(2).mean())\n \n "
218
- "use:\n \n "
219
- " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n \n "
220
- )
221
- raise ValueError (msg )
176
+
177
+ function_name = remove_prefix (expr ._function_name , "col->" )
178
+ function_name , option = polars_to_arrow_aggregations ().get (
179
+ function_name , (function_name , None )
180
+ )
181
+
182
+ for root_name , output_name in zip (expr ._root_names , expr ._output_names ):
183
+ simple_aggregations [output_name ] = (
184
+ (root_name , function_name , option ),
185
+ f"{ root_name } _{ function_name } " ,
186
+ )
187
+
188
+ aggs : list [Any ] = []
189
+ expected_pyarrow_column_names = keys .copy ()
190
+ new_column_names = keys .copy ()
191
+ for output_name , (
192
+ aggregation_args ,
193
+ pyarrow_output_name ,
194
+ ) in simple_aggregations .items ():
195
+ aggs .append (aggregation_args )
196
+ expected_pyarrow_column_names .append (pyarrow_output_name )
197
+ new_column_names .append (output_name )
198
+
199
+ result_simple = grouped .aggregate (aggs )
200
+
201
+ # Rename columns, being very careful
202
+ expected_old_names_indices : dict [str , list [int ]] = collections .defaultdict (list )
203
+ for idx , item in enumerate (expected_pyarrow_column_names ):
204
+ expected_old_names_indices [item ].append (idx )
205
+ if not (
206
+ set (result_simple .column_names ) == set (expected_pyarrow_column_names )
207
+ and len (result_simple .column_names ) == len (expected_pyarrow_column_names )
208
+ ): # pragma: no cover
209
+ msg = (
210
+ f"Safety assertion failed, expected { expected_pyarrow_column_names } "
211
+ f"got { result_simple .column_names } , "
212
+ "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
213
+ )
214
+ raise AssertionError (msg )
215
+ index_map : list [int ] = [
216
+ expected_old_names_indices [item ].pop (0 ) for item in result_simple .column_names
217
+ ]
218
+ new_column_names = [new_column_names [i ] for i in index_map ]
219
+
220
+ result_simple = result_simple .rename_columns (new_column_names )
221
+ return from_dataframe (result_simple )
0 commit comments