Skip to content

Commit 897eea1

Browse files
wtnclaude
andcommitted
fix(python): Suggest str.contains for string containment in map_elements
Co-authored-by: Claude <noreply@anthropic.com>
1 parent b7c28d8 commit 897eea1

File tree

3 files changed

+156
-6
lines changed

3 files changed

+156
-6
lines changed

.github/scripts/test_bytecode_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
import pytest
2222
from polars._utils.udfs import BytecodeParser
2323
from tests.unit.operations.map.test_inefficient_map_warning import (
24+
MY_COLLECTION,
2425
MY_DICT,
26+
MY_STRING,
27+
MY_SUBSTRING,
2528
NOOP_TEST_CASES,
2629
TEST_CASES,
2730
)
@@ -52,7 +55,10 @@ def test_bytecode_parser_expression_in_ipython(
5255
"from datetime import datetime; "
5356
"import numpy as np; "
5457
"import json; "
55-
f"MY_DICT = {MY_DICT};"
58+
f"MY_DICT = {MY_DICT}; "
59+
f"MY_COLLECTION = {MY_COLLECTION}; "
60+
f"MY_STRING = {repr(MY_STRING)}; "
61+
f"MY_SUBSTRING = {repr(MY_SUBSTRING)}; "
5662
f'bytecode_parser = BytecodeParser({func}, map_target="expr");'
5763
f'print(bytecode_parser.to_expression("{col}"));'
5864
)

py-polars/src/polars/_utils/udfs.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -701,12 +701,46 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
701701
not_ = "" if op == "is" else "not_"
702702
return f"{e1}.is_{not_}null()"
703703
elif op in ("in", "not in"):
704-
not_ = "" if op == "in" else "~"
705-
return (
706-
f"{not_}({e1}.is_in({e2}))"
707-
if " " in e1
708-
else f"{not_}{e1}.is_in({e2})"
704+
e2_stripped = e2.lstrip()
705+
is_collection_literal = e2_stripped.startswith(
706+
("(", "[", "{", "frozenset(")
709707
)
708+
709+
is_collection_variable = False
710+
if not is_collection_literal and not e2.startswith(
711+
("pl.col(", "'")
712+
):
713+
if not self._caller_variables:
714+
self._caller_variables = _get_all_caller_variables()
715+
var_value = self._caller_variables.get(e2)
716+
if isinstance(var_value, (list, tuple, set, frozenset, dict)):
717+
is_collection_variable = True
718+
719+
if is_collection_literal or is_collection_variable:
720+
not_ = "" if op == "in" else "~"
721+
return (
722+
f"{not_}({e1}.is_in({e2}))"
723+
if " " in e1
724+
else f"{not_}{e1}.is_in({e2})"
725+
)
726+
else:
727+
e2_is_col = e2.startswith("pl.col(")
728+
e1_is_col = e1.startswith("pl.col(")
729+
730+
if e2_is_col:
731+
needle = f"pl.lit({e1})" if not e1_is_col else e1
732+
haystack = e2
733+
else:
734+
needle = e1
735+
haystack = f"pl.lit({e2})"
736+
737+
contains_expr = (
738+
f"{haystack}.str.contains({needle}, literal=True)"
739+
)
740+
741+
if op == "not in":
742+
return f"~{contains_expr}"
743+
return contains_expr
710744
elif op == "replace_strict":
711745
if not self._caller_variables:
712746
self._caller_variables = _get_all_caller_variables()

py-polars/tests/unit/operations/map/test_inefficient_map_warning.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
MY_CONSTANT = 3
2222
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
2323
MY_LIST = [1, 2, 3]
24+
MY_STRING = "qwerty"
25+
MY_SUBSTRING = "we"
26+
MY_COLLECTION = [2, 3, 4]
2427

2528
# column_name, function, expected_suggestion
2629
TEST_CASES = [
@@ -67,12 +70,68 @@
6770
),
6871
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),
6972
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),
73+
("a", "lambda x: x in MY_COLLECTION", 'pl.col("a").is_in(MY_COLLECTION)', None),
74+
("a", "lambda x: x in MY_DICT", 'pl.col("a").is_in(MY_DICT)', None),
75+
(
76+
"a",
77+
"lambda x: (x + 1) in (1, 2, 3)",
78+
'((pl.col("a") + 1).is_in((1, 2, 3)))',
79+
None,
80+
),
7081
(
7182
"a",
7283
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
7384
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
7485
None,
7586
),
87+
# ---------------------------------------------
88+
# string containment with 'in' operator
89+
# ---------------------------------------------
90+
(
91+
"b",
92+
"lambda x: x in MY_STRING",
93+
'pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)',
94+
None,
95+
),
96+
(
97+
"b",
98+
"lambda x: MY_SUBSTRING in x",
99+
'pl.col("b").str.contains(pl.lit(MY_SUBSTRING), literal=True)',
100+
None,
101+
),
102+
(
103+
"b",
104+
'lambda x: "A" in x',
105+
"pl.col(\"b\").str.contains(pl.lit('A'), literal=True)",
106+
None,
107+
),
108+
(
109+
"b",
110+
"lambda x: x not in MY_STRING",
111+
'~pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)',
112+
None,
113+
),
114+
(
115+
"b",
116+
"lambda x: x in x",
117+
'pl.col("b").str.contains(pl.col("b"), literal=True)',
118+
None,
119+
),
120+
(
121+
"b",
122+
'lambda x: "test" in x',
123+
"pl.col(\"b\").str.contains(pl.lit('test'), literal=True)",
124+
None,
125+
),
126+
(
127+
"b",
128+
'lambda x: x not in "hello"',
129+
"~pl.lit('hello').str.contains(pl.col(\"b\"), literal=True)",
130+
None,
131+
),
132+
# ---------------------------------------------
133+
# constants
134+
# ---------------------------------------------
76135
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),
77136
(
78137
"a",
@@ -310,6 +369,9 @@
310369
"MY_CONSTANT": MY_CONSTANT,
311370
"MY_DICT": MY_DICT,
312371
"MY_LIST": MY_LIST,
372+
"MY_STRING": MY_STRING,
373+
"MY_SUBSTRING": MY_SUBSTRING,
374+
"MY_COLLECTION": MY_COLLECTION,
313375
"cosh": cosh,
314376
"datetime": datetime,
315377
"dt": dt,
@@ -601,3 +663,51 @@ def plus(value: int, amount: int) -> int:
601663
df = pl.DataFrame(data)
602664
# should not warn
603665
_ = df["a"].map_elements(partial(plus, amount=1))
666+
667+
668+
@pytest.mark.filterwarnings(
669+
"ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
670+
"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
671+
)
672+
@pytest.mark.parametrize("pattern", [".", "^", "$", "[0]", "a|b", "a+", "a?"])
673+
def test_string_containment_regex_metacharacters_17182(pattern: str) -> None:
674+
df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})
675+
676+
result_lambda = df.select(
677+
pl.col("b").map_elements(
678+
lambda x: x.find(pattern) >= 0,
679+
return_dtype=pl.Boolean,
680+
)
681+
)
682+
683+
func = lambda x: pattern in x # noqa: E731
684+
parser = BytecodeParser(func, map_target="expr")
685+
suggested = parser.to_expression("b")
686+
assert suggested is not None
687+
688+
result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
689+
assert_frame_equal(result_lambda, result_suggested)
690+
691+
692+
@pytest.mark.filterwarnings(
693+
"ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
694+
"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
695+
)
696+
@pytest.mark.parametrize("pattern", ["*", "(", ")", "["])
697+
def test_string_containment_invalid_regex_17182(pattern: str) -> None:
698+
df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})
699+
700+
result_lambda = df.select(
701+
pl.col("b").map_elements(
702+
lambda x: x.find(pattern) >= 0,
703+
return_dtype=pl.Boolean,
704+
)
705+
)
706+
707+
func = lambda x: pattern in x # noqa: E731
708+
parser = BytecodeParser(func, map_target="expr")
709+
suggested = parser.to_expression("b")
710+
assert suggested is not None
711+
712+
result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
713+
assert_frame_equal(result_lambda, result_suggested)

0 commit comments

Comments
 (0)