Skip to content

Commit 86097b4

Browse files
wtn and claude committed
fix(python): Support Expr objects in replace sequence arguments
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 405b194 commit 86097b4

File tree

3 files changed

+124
-5
lines changed

3 files changed

+124
-5
lines changed

py-polars/src/polars/expr/expr.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11074,16 +11074,38 @@ def replace(
1107411074
raise TypeError(msg)
1107511075
new = list(old.values())
1107611076
old = list(old.keys())
11077+
11078+
old_is_seq = isinstance(old, Sequence) and not isinstance(old, (str, pl.Series))
11079+
new_is_seq = isinstance(new, Sequence) and not isinstance(new, (str, pl.Series))
11080+
has_expr_in_old = old_is_seq and any(isinstance(v, Expr) for v in old) # type: ignore[union-attr]
11081+
has_expr_in_new = new_is_seq and any(isinstance(v, Expr) for v in new) # type: ignore[union-attr]
11082+
11083+
if has_expr_in_old or has_expr_in_new:
11084+
old_list = list(old) if old_is_seq else [old] # type: ignore[arg-type, misc]
11085+
new_list = list(new) if new_is_seq else [new] # type: ignore[arg-type]
11086+
11087+
if len(new_list) == 1 and len(old_list) > 1:
11088+
new_list = new_list * len(old_list)
11089+
11090+
if len(old_list) != len(new_list):
11091+
msg = f"lengths of `old` ({len(old_list)}) and `new` ({len(new_list)}) must match"
11092+
raise ValueError(msg)
11093+
11094+
# when(False).then(self) preserves the column name
11095+
result: Expr = F.when(F.lit(False)).then(self)
11096+
for old_val, new_val in zip(old_list, new_list):
11097+
result = result.when(self == old_val).then(new_val) # type: ignore[attr-defined]
11098+
result = result.otherwise(self) # type: ignore[attr-defined]
1107711099
else:
11078-
if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)):
11100+
if old_is_seq:
1107911101
old = pl.Series(old)
11080-
if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)):
11102+
if new_is_seq:
1108111103
new = pl.Series(new)
1108211104

11083-
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
11084-
new_pyexpr = parse_into_expression(new, str_as_lit=True)
11105+
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
11106+
new_pyexpr = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type]
1108511107

11086-
result = wrap_expr(self._pyexpr.replace(old_pyexpr, new_pyexpr))
11108+
result = wrap_expr(self._pyexpr.replace(old_pyexpr, new_pyexpr))
1108711109

1108811110
if return_dtype is not None:
1108911111
result = result.cast(return_dtype)
@@ -11271,6 +11293,37 @@ def replace_strict(
1127111293
new = list(old.values())
1127211294
old = list(old.keys())
1127311295

11296+
old_is_seq = isinstance(old, Sequence) and not isinstance(old, (str, pl.Series))
11297+
new_is_seq = isinstance(new, Sequence) and not isinstance(new, (str, pl.Series))
11298+
has_expr_in_old = old_is_seq and any(isinstance(v, Expr) for v in old) # type: ignore[union-attr]
11299+
has_expr_in_new = new_is_seq and any(isinstance(v, Expr) for v in new) # type: ignore[union-attr]
11300+
11301+
if has_expr_in_old or has_expr_in_new:
11302+
old_list = list(old) if old_is_seq else [old] # type: ignore[arg-type, misc]
11303+
new_list = list(new) if new_is_seq else [new] # type: ignore[arg-type]
11304+
11305+
if len(new_list) == 1 and len(old_list) > 1:
11306+
new_list = new_list * len(old_list)
11307+
11308+
if len(old_list) != len(new_list):
11309+
msg = f"lengths of `old` ({len(old_list)}) and `new` ({len(new_list)}) must match"
11310+
raise ValueError(msg)
11311+
11312+
# when(False).then(self) preserves the column name
11313+
result: Expr = F.when(F.lit(False)).then(self)
11314+
for old_val, new_val in zip(old_list, new_list):
11315+
result = result.when(self == old_val).then(new_val) # type: ignore[attr-defined]
11316+
11317+
if default is no_default:
11318+
result = result.otherwise(None) # type: ignore[attr-defined]
11319+
else:
11320+
result = result.otherwise(default) # type: ignore[attr-defined]
11321+
11322+
if return_dtype is not None:
11323+
result = result.cast(return_dtype)
11324+
11325+
return result
11326+
1127411327
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
1127511328
new_pyexpr = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type]
1127611329

py-polars/tests/unit/operations/test_replace.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,45 @@ def test_replace_single_argument_not_mapping() -> None:
289289
match="`new` argument is required if `old` argument is not a Mapping type",
290290
):
291291
df.select(pl.col("a").replace("b"))
292+
293+
294+
def test_replace_expr_in_sequence() -> None:
295+
s = pl.Series([1, 2, 3])
296+
297+
# Expression in old sequence
298+
result = s.replace([pl.lit(1)], [10])
299+
expected = pl.Series([10, 2, 3])
300+
assert_series_equal(result, expected)
301+
302+
# Expression in new sequence
303+
result = s.replace([1], [pl.lit(10)])
304+
expected = pl.Series([10, 2, 3])
305+
assert_series_equal(result, expected)
306+
307+
# Multiple values with expressions
308+
result = s.replace([pl.lit(1), 2], [10, pl.lit(20)])
309+
expected = pl.Series([10, 20, 3])
310+
assert_series_equal(result, expected)
311+
312+
# Scalar expressions are still allowed
313+
result = s.replace(pl.lit(1), pl.lit(10))
314+
expected = pl.Series([10, 2, 3])
315+
assert_series_equal(result, expected)
316+
317+
318+
def test_replace_expr_in_sequence_with_column_refs() -> None:
319+
df = pl.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]})
320+
321+
# Replace with column reference expressions
322+
result = df.select(
323+
pl.col("a").replace([pl.col("a").min(), pl.col("a").max()], [100, 300])
324+
)
325+
expected = pl.DataFrame({"a": [100, 2, 300]})
326+
assert_frame_equal(result, expected)
327+
328+
# Replace values with column-derived expressions
329+
result = df.select(
330+
pl.col("a").replace([1, 2], [pl.col("b").first(), pl.col("b").sum()])
331+
)
332+
expected = pl.DataFrame({"a": [10, 60, 3]})
333+
assert_frame_equal(result, expected)

py-polars/tests/unit/operations/test_replace_strict.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,27 @@ def test_replace_strict_nested_mapping_22554() -> None:
419419
),
420420
pl.Series([[42], [13], [37]]),
421421
)
422+
423+
424+
def test_replace_strict_expr_in_sequence() -> None:
425+
s = pl.Series([1, 2, 3])
426+
427+
# Expression in old sequence with default
428+
result = s.replace_strict([pl.lit(1)], [10], default=None)
429+
expected = pl.Series([10, None, None])
430+
assert_series_equal(result, expected, check_dtypes=False)
431+
432+
# Expression in new sequence with default
433+
result = s.replace_strict([1], [pl.lit(10)], default=None)
434+
expected = pl.Series([10, None, None])
435+
assert_series_equal(result, expected, check_dtypes=False)
436+
437+
# Multiple values with expressions and default
438+
result = s.replace_strict([pl.lit(1), 2], [10, pl.lit(20)], default=-1)
439+
expected = pl.Series([10, 20, -1])
440+
assert_series_equal(result, expected, check_dtypes=False)
441+
442+
# Scalar expressions are still allowed
443+
result = s.replace_strict(pl.lit(1), pl.lit(10), default=pl.lit(None))
444+
expected = pl.Series([10, None, None], dtype=pl.Int32)
445+
assert_series_equal(result, expected)

0 commit comments

Comments
 (0)