@@ -292,20 +292,33 @@ def __array_ufunc__(
292292 is_custom_ufunc = getattr (ufunc , "signature" ) is not None # noqa: B009
293293 num_expr = sum (isinstance (inp , Expr ) for inp in inputs )
294294 exprs = [
295- (inp , Expr , i ) if isinstance (inp , Expr ) else (inp , None , i )
295+ (inp , True , i ) if isinstance (inp , Expr ) else (inp , False , i )
296296 for i , inp in enumerate (inputs )
297297 ]
298+
298299 if num_expr == 1 :
299- root_expr = next (expr [0 ] for expr in exprs if expr [1 ] == Expr )
300+ root_expr = next (expr [0 ] for expr in exprs if expr [1 ])
300301 else :
301- root_expr = F .struct (expr [0 ] for expr in exprs if expr [1 ] == Expr )
302+ # We rename all but the first expression in case someone did e.g.
303+ # np.divide(pl.col("a"), pl.col("a")); we'll be creating a struct
304+ # below, and structs can't have duplicate names.
305+ first_renamable_expr = True
306+ actual_exprs = []
307+ for inp , is_actual_expr , index in exprs :
308+ if is_actual_expr :
309+ if first_renamable_expr :
310+ first_renamable_expr = False
311+ else :
312+ inp = inp .alias (f"argument_{ index } " )
313+ actual_exprs .append (inp )
314+ root_expr = F .struct (actual_exprs )
302315
303316 def function (s : Series ) -> Series : # pragma: no cover
304317 args = []
305318 for i , expr in enumerate (exprs ):
306- if expr [1 ] == Expr and num_expr > 1 :
319+ if expr [1 ] and num_expr > 1 :
307320 args .append (s .struct [i ])
308- elif expr [1 ] == Expr :
321+ elif expr [1 ]:
309322 args .append (s )
310323 else :
311324 args .append (expr [0 ])
@@ -323,10 +336,8 @@ def function(s: Series) -> Series: # pragma: no cover
323336 CustomUFuncWarning ,
324337 stacklevel = find_stacklevel (),
325338 )
326- return root_expr .map_batches (
327- function , is_elementwise = False
328- ).meta .undo_aliases ()
329- return root_expr .map_batches (function , is_elementwise = True ).meta .undo_aliases ()
339+ return root_expr .map_batches (function , is_elementwise = False )
340+ return root_expr .map_batches (function , is_elementwise = True )
330341
331342 @classmethod
332343 def deserialize (
0 commit comments