Skip to content

Commit 6584d15

Browse files
wtn and claude committed
fix(python): Reject Expr objects in replace sequence arguments
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 8fc67c7 commit 6584d15

File tree

3 files changed

+158
-5
lines changed

3 files changed

+158
-5
lines changed

py-polars/src/polars/expr/expr.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11165,16 +11165,38 @@ def replace(
1116511165
raise TypeError(msg)
1116611166
new = list(old.values())
1116711167
old = list(old.keys())
11168+
11169+
old_is_seq = isinstance(old, Sequence) and not isinstance(old, (str, pl.Series))
11170+
new_is_seq = isinstance(new, Sequence) and not isinstance(new, (str, pl.Series))
11171+
has_expr_in_old = old_is_seq and any(isinstance(v, Expr) for v in old) # type: ignore[union-attr]
11172+
has_expr_in_new = new_is_seq and any(isinstance(v, Expr) for v in new) # type: ignore[union-attr]
11173+
11174+
if has_expr_in_old or has_expr_in_new:
11175+
old_list = list(old) if old_is_seq else [old] # type: ignore[arg-type, misc]
11176+
new_list = list(new) if new_is_seq else [new] # type: ignore[arg-type]
11177+
11178+
if len(new_list) == 1 and len(old_list) > 1:
11179+
new_list = new_list * len(old_list)
11180+
11181+
if len(old_list) != len(new_list):
11182+
msg = f"lengths of `old` ({len(old_list)}) and `new` ({len(new_list)}) must match"
11183+
raise ValueError(msg)
11184+
11185+
# when(False).then(self) preserves the column name
11186+
result: Expr = F.when(F.lit(False)).then(self)
11187+
for old_val, new_val in zip(old_list, new_list, strict=True):
11188+
result = result.when(self == old_val).then(new_val) # type: ignore[attr-defined]
11189+
result = result.otherwise(self) # type: ignore[attr-defined]
1116811190
else:
11169-
if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)):
11191+
if old_is_seq:
1117011192
old = pl.Series(old)
11171-
if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)):
11193+
if new_is_seq:
1117211194
new = pl.Series(new)
1117311195

11174-
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
11175-
new_pyexpr = parse_into_expression(new, str_as_lit=True)
11196+
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
11197+
new_pyexpr = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type]
1117611198

11177-
result = wrap_expr(self._pyexpr.replace(old_pyexpr, new_pyexpr))
11199+
result = wrap_expr(self._pyexpr.replace(old_pyexpr, new_pyexpr))
1117811200

1117911201
if return_dtype is not None:
1118011202
result = result.cast(return_dtype)
@@ -11362,6 +11384,37 @@ def replace_strict(
1136211384
new = list(old.values())
1136311385
old = list(old.keys())
1136411386

11387+
old_is_seq = isinstance(old, Sequence) and not isinstance(old, (str, pl.Series))
11388+
new_is_seq = isinstance(new, Sequence) and not isinstance(new, (str, pl.Series))
11389+
has_expr_in_old = old_is_seq and any(isinstance(v, Expr) for v in old) # type: ignore[union-attr]
11390+
has_expr_in_new = new_is_seq and any(isinstance(v, Expr) for v in new) # type: ignore[union-attr]
11391+
11392+
if has_expr_in_old or has_expr_in_new:
11393+
old_list = list(old) if old_is_seq else [old] # type: ignore[arg-type, misc]
11394+
new_list = list(new) if new_is_seq else [new] # type: ignore[arg-type]
11395+
11396+
if len(new_list) == 1 and len(old_list) > 1:
11397+
new_list = new_list * len(old_list)
11398+
11399+
if len(old_list) != len(new_list):
11400+
msg = f"lengths of `old` ({len(old_list)}) and `new` ({len(new_list)}) must match"
11401+
raise ValueError(msg)
11402+
11403+
# when(False).then(self) preserves the column name
11404+
result: Expr = F.when(F.lit(False)).then(self)
11405+
for old_val, new_val in zip(old_list, new_list, strict=True):
11406+
result = result.when(self == old_val).then(new_val) # type: ignore[attr-defined]
11407+
11408+
if default is no_default:
11409+
result = result.otherwise(None) # type: ignore[attr-defined]
11410+
else:
11411+
result = result.otherwise(default) # type: ignore[attr-defined]
11412+
11413+
if return_dtype is not None:
11414+
result = result.cast(return_dtype)
11415+
11416+
return result
11417+
1136511418
old_pyexpr = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type]
1136611419
new_pyexpr = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type]
1136711420

py-polars/tests/unit/operations/test_replace.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,54 @@ def test_replace_single_argument_not_mapping() -> None:
289289
match="`new` argument is required if `old` argument is not a Mapping type",
290290
):
291291
df.select(pl.col("a").replace("b"))
292+
293+
294+
def test_replace_expr_in_sequence() -> None:
295+
s = pl.Series([1, 2, 3])
296+
297+
result = s.replace([pl.lit(1)], [10])
298+
expected = pl.Series([10, 2, 3])
299+
assert_series_equal(result, expected)
300+
301+
result = s.replace([1], [pl.lit(10)])
302+
expected = pl.Series([10, 2, 3])
303+
assert_series_equal(result, expected)
304+
305+
result = s.replace([pl.lit(1), 2], [10, pl.lit(20)])
306+
expected = pl.Series([10, 20, 3])
307+
assert_series_equal(result, expected)
308+
309+
result = s.replace(pl.lit(1), pl.lit(10))
310+
expected = pl.Series([10, 2, 3])
311+
assert_series_equal(result, expected)
312+
313+
314+
def test_replace_expr_in_sequence_with_column_refs() -> None:
315+
df = pl.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]})
316+
317+
result = df.select(
318+
pl.col("a").replace([pl.col("a").min(), pl.col("a").max()], [100, 300])
319+
)
320+
expected = pl.DataFrame({"a": [100, 2, 300]})
321+
assert_frame_equal(result, expected)
322+
323+
result = df.select(
324+
pl.col("a").replace([1, 2], [pl.col("b").first(), pl.col("b").sum()])
325+
)
326+
expected = pl.DataFrame({"a": [10, 60, 3]})
327+
assert_frame_equal(result, expected)
328+
329+
330+
def test_replace_expr_in_sequence_many_to_one() -> None:
331+
s = pl.Series([1, 2, 3])
332+
result = s.replace([pl.lit(1), 2], [10])
333+
expected = pl.Series([10, 10, 3])
334+
assert_series_equal(result, expected)
335+
336+
337+
def test_replace_expr_in_sequence_length_mismatch() -> None:
338+
s = pl.Series([1, 2, 3])
339+
with pytest.raises(
340+
ValueError, match=r"lengths of `old` \(2\) and `new` \(3\) must match"
341+
):
342+
s.replace([pl.lit(1), 2], [10, 20, 30])

py-polars/tests/unit/operations/test_replace_strict.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,52 @@ def test_replace_strict_nested_mapping_22554() -> None:
419419
),
420420
pl.Series([[42], [13], [37]]),
421421
)
422+
423+
424+
def test_replace_strict_expr_in_sequence() -> None:
425+
s = pl.Series([1, 2, 3])
426+
427+
result = s.replace_strict([pl.lit(1)], [10], default=None)
428+
expected = pl.Series([10, None, None])
429+
assert_series_equal(result, expected, check_dtypes=False)
430+
431+
result = s.replace_strict([1], [pl.lit(10)], default=None)
432+
expected = pl.Series([10, None, None])
433+
assert_series_equal(result, expected, check_dtypes=False)
434+
435+
result = s.replace_strict([pl.lit(1), 2], [10, pl.lit(20)], default=-1)
436+
expected = pl.Series([10, 20, -1])
437+
assert_series_equal(result, expected, check_dtypes=False)
438+
439+
result = s.replace_strict(pl.lit(1), pl.lit(10), default=pl.lit(None))
440+
expected = pl.Series([10, None, None], dtype=pl.Int32)
441+
assert_series_equal(result, expected)
442+
443+
444+
def test_replace_strict_expr_in_sequence_many_to_one() -> None:
445+
s = pl.Series([1, 2, 3])
446+
result = s.replace_strict([pl.lit(1), 2], [10], default=-1)
447+
expected = pl.Series([10, 10, -1])
448+
assert_series_equal(result, expected, check_dtypes=False)
449+
450+
451+
def test_replace_strict_expr_in_sequence_length_mismatch() -> None:
452+
s = pl.Series([1, 2, 3])
453+
with pytest.raises(
454+
ValueError, match=r"lengths of `old` \(2\) and `new` \(3\) must match"
455+
):
456+
s.replace_strict([pl.lit(1), 2], [10, 20, 30], default=None)
457+
458+
459+
def test_replace_strict_expr_in_sequence_no_default() -> None:
460+
s = pl.Series([1, 2, 3])
461+
result = s.replace_strict([pl.lit(1)], [10])
462+
expected = pl.Series([10, None, None])
463+
assert_series_equal(result, expected, check_dtypes=False)
464+
465+
466+
def test_replace_strict_expr_in_sequence_return_dtype() -> None:
467+
s = pl.Series([1, 2, 3])
468+
result = s.replace_strict([pl.lit(1)], [10], default=None, return_dtype=pl.Float64)
469+
expected = pl.Series([10.0, None, None], dtype=pl.Float64)
470+
assert_series_equal(result, expected)

0 commit comments

Comments
 (0)