Skip to content

Commit c7d8574

Browse files
wtnclaude
andcommitted
fix(python): Reject Expr objects in replace sequence arguments
Co-authored-by: Claude <noreply@anthropic.com>
1 parent b7c28d8 commit c7d8574

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

py-polars/src/polars/expr/expr.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11076,8 +11076,20 @@ def replace(
1107611076
old = list(old.keys())
1107711077
else:
1107811078
if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)):
11079+
if any(isinstance(v, Expr) for v in old):
11080+
msg = (
11081+
"passing expressions to `old` is not supported when `old` "
11082+
"is a sequence; use a scalar expression or literal values"
11083+
)
11084+
raise TypeError(msg)
1107911085
old = pl.Series(old)
1108011086
if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)):
11087+
if any(isinstance(v, Expr) for v in new):
11088+
msg = (
11089+
"passing expressions to `new` is not supported when `new` "
11090+
"is a sequence; use a scalar expression or literal values"
11091+
)
11092+
raise TypeError(msg)
1108111093
new = pl.Series(new)
1108211094

1108311095
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
@@ -11270,6 +11282,21 @@ def replace_strict(
1127011282
raise TypeError(msg)
1127111283
new = list(old.values())
1127211284
old = list(old.keys())
11285+
else:
11286+
if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)):
11287+
if any(isinstance(v, Expr) for v in old):
11288+
msg = (
11289+
"passing expressions to `old` is not supported when `old` "
11290+
"is a sequence; use a scalar expression or literal values"
11291+
)
11292+
raise TypeError(msg)
11293+
if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)):
11294+
if any(isinstance(v, Expr) for v in new):
11295+
msg = (
11296+
"passing expressions to `new` is not supported when `new` "
11297+
"is a sequence; use a scalar expression or literal values"
11298+
)
11299+
raise TypeError(msg)
1127311300

1127411301
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
1127511302
new_pyexpr = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type]

py-polars/tests/unit/operations/test_replace.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,23 @@ def test_replace_single_argument_not_mapping() -> None:
289289
match="`new` argument is required if `old` argument is not a Mapping type",
290290
):
291291
df.select(pl.col("a").replace("b"))
292+
293+
294+
def test_replace_expr_in_sequence_raises() -> None:
295+
s = pl.Series([1, 2, 3])
296+
with pytest.raises(
297+
TypeError,
298+
match="passing expressions to `old` is not supported when `old` is a sequence",
299+
):
300+
s.replace([pl.lit(1)], [2])
301+
302+
with pytest.raises(
303+
TypeError,
304+
match="passing expressions to `new` is not supported when `new` is a sequence",
305+
):
306+
s.replace([1], [pl.lit(2)])
307+
308+
# Scalar expressions are still allowed
309+
result = s.replace(pl.lit(1), pl.lit(10))
310+
expected = pl.Series([10, 2, 3])
311+
assert_series_equal(result, expected)

py-polars/tests/unit/operations/test_replace_strict.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,23 @@ def test_replace_strict_nested_mapping_22554() -> None:
419419
),
420420
pl.Series([[42], [13], [37]]),
421421
)
422+
423+
424+
def test_replace_strict_expr_in_sequence_raises() -> None:
425+
s = pl.Series([1, 2, 3])
426+
with pytest.raises(
427+
TypeError,
428+
match="passing expressions to `old` is not supported when `old` is a sequence",
429+
):
430+
s.replace_strict([pl.lit(1)], [2])
431+
432+
with pytest.raises(
433+
TypeError,
434+
match="passing expressions to `new` is not supported when `new` is a sequence",
435+
):
436+
s.replace_strict([1], [pl.lit(2)])
437+
438+
# Scalar expressions are still allowed
439+
result = s.replace_strict(pl.lit(1), pl.lit(10), default=pl.lit(None))
440+
expected = pl.Series([10, None, None], dtype=pl.Int32)
441+
assert_series_equal(result, expected)

0 commit comments

Comments
 (0)