Skip to content

Commit 1008cf2

Browse files
wtnclaude
andcommitted
fix(python): Suggest str.contains for string containment in map_elements
Co-authored-by: Claude <noreply@anthropic.com>
1 parent d92ff0e commit 1008cf2

File tree

3 files changed

+95
-6
lines changed

3 files changed

+95
-6
lines changed

.github/scripts/test_bytecode_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
import pytest
2222
from polars._utils.udfs import BytecodeParser
2323
from tests.unit.operations.map.test_inefficient_map_warning import (
24+
MY_COLLECTION,
2425
MY_DICT,
26+
MY_STRING,
27+
MY_SUBSTRING,
2528
NOOP_TEST_CASES,
2629
TEST_CASES,
2730
)
@@ -52,7 +55,10 @@ def test_bytecode_parser_expression_in_ipython(
5255
"from datetime import datetime; "
5356
"import numpy as np; "
5457
"import json; "
55-
f"MY_DICT = {MY_DICT};"
58+
f"MY_DICT = {MY_DICT}; "
59+
f"MY_COLLECTION = {MY_COLLECTION}; "
60+
f"MY_STRING = {repr(MY_STRING)}; "
61+
f"MY_SUBSTRING = {repr(MY_SUBSTRING)}; "
5662
f'bytecode_parser = BytecodeParser({func}, map_target="expr");'
5763
f'print(bytecode_parser.to_expression("{col}"));'
5864
)

py-polars/src/polars/_utils/udfs.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -696,12 +696,44 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
696696
not_ = "" if op == "is" else "not_"
697697
return f"{e1}.is_{not_}null()"
698698
elif op in ("in", "not in"):
699-
not_ = "" if op == "in" else "~"
700-
return (
701-
f"{not_}({e1}.is_in({e2}))"
702-
if " " in e1
703-
else f"{not_}{e1}.is_in({e2})"
699+
e2_stripped = e2.lstrip()
700+
is_collection_literal = e2_stripped.startswith(
701+
("(", "[", "{", "frozenset(")
704702
)
703+
704+
is_collection_variable = False
705+
if not is_collection_literal and not e2.startswith(
706+
("pl.col(", "'")
707+
):
708+
if not self._caller_variables:
709+
self._caller_variables = _get_all_caller_variables()
710+
var_value = self._caller_variables.get(e2)
711+
if isinstance(var_value, (list, tuple, set, frozenset)):
712+
is_collection_variable = True
713+
714+
if is_collection_literal or is_collection_variable:
715+
not_ = "" if op == "in" else "~"
716+
return (
717+
f"{not_}({e1}.is_in({e2}))"
718+
if " " in e1
719+
else f"{not_}{e1}.is_in({e2})"
720+
)
721+
else:
722+
e2_is_col = e2.startswith("pl.col(")
723+
e1_is_col = e1.startswith("pl.col(")
724+
725+
if e2_is_col:
726+
needle = f"pl.lit({e1})" if not e1_is_col else e1
727+
haystack = e2
728+
else:
729+
needle = e1
730+
haystack = f"pl.lit({e2})"
731+
732+
contains_expr = f"{haystack}.str.contains({needle})"
733+
734+
if op == "not in":
735+
return f"~{contains_expr}"
736+
return contains_expr
705737
elif op == "replace_strict":
706738
if not self._caller_variables:
707739
self._caller_variables = _get_all_caller_variables()

py-polars/tests/unit/operations/map/test_inefficient_map_warning.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
MY_CONSTANT = 3
2222
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
2323
MY_LIST = [1, 2, 3]
24+
MY_STRING = "qwerty"
25+
MY_SUBSTRING = "we"
26+
MY_COLLECTION = [2, 3, 4]
2427

2528
# column_name, function, expected_suggestion
2629
TEST_CASES = [
@@ -67,12 +70,57 @@
6770
),
6871
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),
6972
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),
73+
("a", "lambda x: x in MY_COLLECTION", 'pl.col("a").is_in(MY_COLLECTION)', None),
74+
(
75+
"a",
76+
"lambda x: (x + 1) in (1, 2, 3)",
77+
'((pl.col("a") + 1).is_in((1, 2, 3)))',
78+
None,
79+
),
7080
(
7181
"a",
7282
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
7383
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
7484
None,
7585
),
86+
# ---------------------------------------------
87+
# string containment with 'in' operator
88+
# ---------------------------------------------
89+
(
90+
"b",
91+
"lambda x: x in MY_STRING",
92+
'pl.lit(MY_STRING).str.contains(pl.col("b"))',
93+
None,
94+
),
95+
(
96+
"b",
97+
"lambda x: MY_SUBSTRING in x",
98+
'pl.col("b").str.contains(pl.lit(MY_SUBSTRING))',
99+
None,
100+
),
101+
("b", 'lambda x: "A" in x', "pl.col(\"b\").str.contains(pl.lit('A'))", None),
102+
(
103+
"b",
104+
"lambda x: x not in MY_STRING",
105+
'~pl.lit(MY_STRING).str.contains(pl.col("b"))',
106+
None,
107+
),
108+
("b", "lambda x: x in x", 'pl.col("b").str.contains(pl.col("b"))', None),
109+
(
110+
"b",
111+
'lambda x: "test" in x',
112+
"pl.col(\"b\").str.contains(pl.lit('test'))",
113+
None,
114+
),
115+
(
116+
"b",
117+
'lambda x: x not in "hello"',
118+
"~pl.lit('hello').str.contains(pl.col(\"b\"))",
119+
None,
120+
),
121+
# ---------------------------------------------
122+
# constants
123+
# ---------------------------------------------
76124
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),
77125
(
78126
"a",
@@ -310,6 +358,9 @@
310358
"MY_CONSTANT": MY_CONSTANT,
311359
"MY_DICT": MY_DICT,
312360
"MY_LIST": MY_LIST,
361+
"MY_STRING": MY_STRING,
362+
"MY_SUBSTRING": MY_SUBSTRING,
363+
"MY_COLLECTION": MY_COLLECTION,
313364
"cosh": cosh,
314365
"datetime": datetime,
315366
"dt": dt,

0 commit comments

Comments
 (0)