Skip to content

Commit 4b7e2bd

Browse files
fix(python): Support duplicate expression names when calling ufuncs (#17641)
Co-authored-by: Itamar Turner-Trauring <itamar@pythonspeed.com>
1 parent 3897a37 commit 4b7e2bd

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

py-polars/polars/expr/expr.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -292,20 +292,33 @@ def __array_ufunc__(
292292
is_custom_ufunc = getattr(ufunc, "signature") is not None # noqa: B009
293293
num_expr = sum(isinstance(inp, Expr) for inp in inputs)
294294
exprs = [
295-
(inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i)
295+
(inp, True, i) if isinstance(inp, Expr) else (inp, False, i)
296296
for i, inp in enumerate(inputs)
297297
]
298+
298299
if num_expr == 1:
299-
root_expr = next(expr[0] for expr in exprs if expr[1] == Expr)
300+
root_expr = next(expr[0] for expr in exprs if expr[1])
300301
else:
301-
root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr)
302+
# We rename all but the first expression in case someone did e.g.
303+
# np.divide(pl.col("a"), pl.col("a")); we'll be creating a struct
304+
# below, and structs can't have duplicate names.
305+
first_renamable_expr = True
306+
actual_exprs = []
307+
for inp, is_actual_expr, index in exprs:
308+
if is_actual_expr:
309+
if first_renamable_expr:
310+
first_renamable_expr = False
311+
else:
312+
inp = inp.alias(f"argument_{index}")
313+
actual_exprs.append(inp)
314+
root_expr = F.struct(actual_exprs)
302315

303316
def function(s: Series) -> Series: # pragma: no cover
304317
args = []
305318
for i, expr in enumerate(exprs):
306-
if expr[1] == Expr and num_expr > 1:
319+
if expr[1] and num_expr > 1:
307320
args.append(s.struct[i])
308-
elif expr[1] == Expr:
321+
elif expr[1]:
309322
args.append(s)
310323
else:
311324
args.append(expr[0])
@@ -323,10 +336,8 @@ def function(s: Series) -> Series: # pragma: no cover
323336
CustomUFuncWarning,
324337
stacklevel=find_stacklevel(),
325338
)
326-
return root_expr.map_batches(
327-
function, is_elementwise=False
328-
).meta.undo_aliases()
329-
return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases()
339+
return root_expr.map_batches(function, is_elementwise=False)
340+
return root_expr.map_batches(function, is_elementwise=True)
330341

331342
@classmethod
332343
def deserialize(

py-polars/tests/unit/interop/numpy/test_ufunc_expr.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,14 @@ def test_ufunc_multiple_expressions() -> None:
120120
assert_series_equal(expected, result) # type: ignore[arg-type]
121121

122122

123+
def test_repeated_name_ufunc_17472() -> None:
124+
"""If a ufunc takes multiple inputs that have a repeating name, this works."""
125+
df = pl.DataFrame({"a": [6.0]})
126+
result = df.select(np.divide(pl.col("a"), pl.col("a"))) # type: ignore[call-overload]
127+
expected = pl.DataFrame({"a": [1.0]})
128+
assert_frame_equal(expected, result)
129+
130+
123131
def test_grouped_ufunc() -> None:
124132
df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]})
125133
df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))

0 commit comments

Comments
 (0)