Skip to content

Commit cf9edc3

Browse files
wtn and claude committed
fix(python): Suggest str.contains for string containment in map_elements
Co-authored-by: Claude <noreply@anthropic.com>
1 parent b25b4a8 commit cf9edc3

File tree

3 files changed

+156
-6
lines changed

3 files changed

+156
-6
lines changed

.github/scripts/test_bytecode_parser.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,10 @@
2121
import pytest
2222
from polars._utils.udfs import BytecodeParser
2323
from tests.unit.operations.map.test_inefficient_map_warning import (
24+
MY_COLLECTION,
2425
MY_DICT,
26+
MY_STRING,
27+
MY_SUBSTRING,
2528
NOOP_TEST_CASES,
2629
TEST_CASES,
2730
)
@@ -52,7 +55,10 @@ def test_bytecode_parser_expression_in_ipython(
5255
"from datetime import datetime; "
5356
"import numpy as np; "
5457
"import json; "
55-
f"MY_DICT = {MY_DICT};"
58+
f"MY_DICT = {MY_DICT}; "
59+
f"MY_COLLECTION = {MY_COLLECTION}; "
60+
f"MY_STRING = {repr(MY_STRING)}; "
61+
f"MY_SUBSTRING = {repr(MY_SUBSTRING)}; "
5662
f'bytecode_parser = BytecodeParser({func}, map_target="expr");'
5763
f'print(bytecode_parser.to_expression("{col}"));'
5864
)

py-polars/src/polars/_utils/udfs.py

Lines changed: 39 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -695,12 +695,46 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
695695
not_ = "" if op == "is" else "not_"
696696
return f"{e1}.is_{not_}null()"
697697
elif op in ("in", "not in"):
698-
not_ = "" if op == "in" else "~"
699-
return (
700-
f"{not_}({e1}.is_in({e2}))"
701-
if " " in e1
702-
else f"{not_}{e1}.is_in({e2})"
698+
e2_stripped = e2.lstrip()
699+
is_collection_literal = e2_stripped.startswith(
700+
("(", "[", "{", "frozenset(")
703701
)
702+
703+
is_collection_variable = False
704+
if not is_collection_literal and not e2.startswith(
705+
("pl.col(", "'")
706+
):
707+
if not self._caller_variables:
708+
self._caller_variables = _get_all_caller_variables()
709+
var_value = self._caller_variables.get(e2)
710+
if isinstance(var_value, (list, tuple, set, frozenset, dict)):
711+
is_collection_variable = True
712+
713+
if is_collection_literal or is_collection_variable:
714+
not_ = "" if op == "in" else "~"
715+
return (
716+
f"{not_}({e1}.is_in({e2}))"
717+
if " " in e1
718+
else f"{not_}{e1}.is_in({e2})"
719+
)
720+
else:
721+
e2_is_col = e2.startswith("pl.col(")
722+
e1_is_col = e1.startswith("pl.col(")
723+
724+
if e2_is_col:
725+
needle = f"pl.lit({e1})" if not e1_is_col else e1
726+
haystack = e2
727+
else:
728+
needle = e1
729+
haystack = f"pl.lit({e2})"
730+
731+
contains_expr = (
732+
f"{haystack}.str.contains({needle}, literal=True)"
733+
)
734+
735+
if op == "not in":
736+
return f"~{contains_expr}"
737+
return contains_expr
704738
elif op == "replace_strict":
705739
if not self._caller_variables:
706740
self._caller_variables = _get_all_caller_variables()

py-polars/tests/unit/operations/map/test_inefficient_map_warning.py

Lines changed: 110 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,9 @@
2424
MY_CONSTANT = 3
2525
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
2626
MY_LIST = [1, 2, 3]
27+
MY_STRING = "qwerty"
28+
MY_SUBSTRING = "we"
29+
MY_COLLECTION = [2, 3, 4]
2730

2831
# column_name, function, expected_suggestion
2932
TEST_CASES = [
@@ -70,12 +73,68 @@
7073
),
7174
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),
7275
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),
76+
("a", "lambda x: x in MY_COLLECTION", 'pl.col("a").is_in(MY_COLLECTION)', None),
77+
("a", "lambda x: x in MY_DICT", 'pl.col("a").is_in(MY_DICT)', None),
78+
(
79+
"a",
80+
"lambda x: (x + 1) in (1, 2, 3)",
81+
'((pl.col("a") + 1).is_in((1, 2, 3)))',
82+
None,
83+
),
7384
(
7485
"a",
7586
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
7687
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
7788
None,
7889
),
90+
# ---------------------------------------------
91+
# string containment with 'in' operator
92+
# ---------------------------------------------
93+
(
94+
"b",
95+
"lambda x: x in MY_STRING",
96+
'pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)',
97+
None,
98+
),
99+
(
100+
"b",
101+
"lambda x: MY_SUBSTRING in x",
102+
'pl.col("b").str.contains(pl.lit(MY_SUBSTRING), literal=True)',
103+
None,
104+
),
105+
(
106+
"b",
107+
'lambda x: "A" in x',
108+
"pl.col(\"b\").str.contains(pl.lit('A'), literal=True)",
109+
None,
110+
),
111+
(
112+
"b",
113+
"lambda x: x not in MY_STRING",
114+
'~pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)',
115+
None,
116+
),
117+
(
118+
"b",
119+
"lambda x: x in x",
120+
'pl.col("b").str.contains(pl.col("b"), literal=True)',
121+
None,
122+
),
123+
(
124+
"b",
125+
'lambda x: "test" in x',
126+
"pl.col(\"b\").str.contains(pl.lit('test'), literal=True)",
127+
None,
128+
),
129+
(
130+
"b",
131+
'lambda x: x not in "hello"',
132+
"~pl.lit('hello').str.contains(pl.col(\"b\"), literal=True)",
133+
None,
134+
),
135+
# ---------------------------------------------
136+
# constants
137+
# ---------------------------------------------
79138
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),
80139
(
81140
"a",
@@ -313,6 +372,9 @@
313372
"MY_CONSTANT": MY_CONSTANT,
314373
"MY_DICT": MY_DICT,
315374
"MY_LIST": MY_LIST,
375+
"MY_STRING": MY_STRING,
376+
"MY_SUBSTRING": MY_SUBSTRING,
377+
"MY_COLLECTION": MY_COLLECTION,
316378
"cosh": cosh,
317379
"datetime": datetime,
318380
"dt": dt,
@@ -604,3 +666,51 @@ def plus(value: int, amount: int) -> int:
604666
df = pl.DataFrame(data)
605667
# should not warn
606668
_ = df["a"].map_elements(partial(plus, amount=1))
669+
670+
671+
@pytest.mark.filterwarnings(
672+
"ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
673+
"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
674+
)
675+
@pytest.mark.parametrize("pattern", [".", "^", "$", "[0]", "a|b", "a+", "a?"])
676+
def test_string_containment_regex_metacharacters_17182(pattern: str) -> None:
677+
df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})
678+
679+
result_lambda = df.select(
680+
pl.col("b").map_elements(
681+
lambda x: x.find(pattern) >= 0,
682+
return_dtype=pl.Boolean,
683+
)
684+
)
685+
686+
func = lambda x: pattern in x # noqa: E731
687+
parser = BytecodeParser(func, map_target="expr")
688+
suggested = parser.to_expression("b")
689+
assert suggested is not None
690+
691+
result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
692+
assert_frame_equal(result_lambda, result_suggested)
693+
694+
695+
@pytest.mark.filterwarnings(
696+
"ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
697+
"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
698+
)
699+
@pytest.mark.parametrize("pattern", ["*", "(", ")", "["])
700+
def test_string_containment_invalid_regex_17182(pattern: str) -> None:
701+
df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})
702+
703+
result_lambda = df.select(
704+
pl.col("b").map_elements(
705+
lambda x: x.find(pattern) >= 0,
706+
return_dtype=pl.Boolean,
707+
)
708+
)
709+
710+
func = lambda x: pattern in x # noqa: E731
711+
parser = BytecodeParser(func, map_target="expr")
712+
suggested = parser.to_expression("b")
713+
assert suggested is not None
714+
715+
result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
716+
assert_frame_equal(result_lambda, result_suggested)

0 commit comments

Comments
 (0)