|
21 | 21 | MY_CONSTANT = 3 |
22 | 22 | MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} |
23 | 23 | MY_LIST = [1, 2, 3] |
| 24 | +MY_STRING = "qwerty" |
| 25 | +MY_SUBSTRING = "we" |
| 26 | +MY_COLLECTION = [2, 3, 4] |
24 | 27 |
|
25 | 28 | # column_name, function, expected_suggestion |
26 | 29 | TEST_CASES = [ |
|
67 | 70 | ), |
68 | 71 | ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None), |
69 | 72 | ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None), |
| 73 | + ("a", "lambda x: x in MY_COLLECTION", 'pl.col("a").is_in(MY_COLLECTION)', None), |
| 74 | + ("a", "lambda x: x in MY_DICT", 'pl.col("a").is_in(MY_DICT)', None), |
| 75 | + ( |
| 76 | + "a", |
| 77 | + "lambda x: (x + 1) in (1, 2, 3)", |
| 78 | + '((pl.col("a") + 1).is_in((1, 2, 3)))', |
| 79 | + None, |
| 80 | + ), |
70 | 81 | ( |
71 | 82 | "a", |
72 | 83 | "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", |
73 | 84 | 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', |
74 | 85 | None, |
75 | 86 | ), |
| 87 | + # --------------------------------------------- |
| 88 | + # string containment with 'in' operator |
| 89 | + # --------------------------------------------- |
| 90 | + ( |
| 91 | + "b", |
| 92 | + "lambda x: x in MY_STRING", |
| 93 | + 'pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)', |
| 94 | + None, |
| 95 | + ), |
| 96 | + ( |
| 97 | + "b", |
| 98 | + "lambda x: MY_SUBSTRING in x", |
| 99 | + 'pl.col("b").str.contains(pl.lit(MY_SUBSTRING), literal=True)', |
| 100 | + None, |
| 101 | + ), |
| 102 | + ( |
| 103 | + "b", |
| 104 | + 'lambda x: "A" in x', |
| 105 | + "pl.col(\"b\").str.contains(pl.lit('A'), literal=True)", |
| 106 | + None, |
| 107 | + ), |
| 108 | + ( |
| 109 | + "b", |
| 110 | + "lambda x: x not in MY_STRING", |
| 111 | + '~pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)', |
| 112 | + None, |
| 113 | + ), |
| 114 | + ( |
| 115 | + "b", |
| 116 | + "lambda x: x in x", |
| 117 | + 'pl.col("b").str.contains(pl.col("b"), literal=True)', |
| 118 | + None, |
| 119 | + ), |
| 120 | + ( |
| 121 | + "b", |
| 122 | + 'lambda x: "test" in x', |
| 123 | + "pl.col(\"b\").str.contains(pl.lit('test'), literal=True)", |
| 124 | + None, |
| 125 | + ), |
| 126 | + ( |
| 127 | + "b", |
| 128 | + 'lambda x: x not in "hello"', |
| 129 | + "~pl.lit('hello').str.contains(pl.col(\"b\"), literal=True)", |
| 130 | + None, |
| 131 | + ), |
| 132 | + # --------------------------------------------- |
| 133 | + # constants |
| 134 | + # --------------------------------------------- |
76 | 135 | ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None), |
77 | 136 | ( |
78 | 137 | "a", |
|
310 | 369 | "MY_CONSTANT": MY_CONSTANT, |
311 | 370 | "MY_DICT": MY_DICT, |
312 | 371 | "MY_LIST": MY_LIST, |
| 372 | + "MY_STRING": MY_STRING, |
| 373 | + "MY_SUBSTRING": MY_SUBSTRING, |
| 374 | + "MY_COLLECTION": MY_COLLECTION, |
313 | 375 | "cosh": cosh, |
314 | 376 | "datetime": datetime, |
315 | 377 | "dt": dt, |
@@ -601,3 +663,79 @@ def plus(value: int, amount: int) -> int: |
601 | 663 | df = pl.DataFrame(data) |
602 | 664 | # should not warn |
603 | 665 | _ = df["a"].map_elements(partial(plus, amount=1)) |
| 666 | + |
| 667 | + |
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.parametrize(
    "pattern",
    [
        ".",  # regex: matches any character
        "^",  # regex: start of string
        "$",  # regex: end of string
        "[0]",  # regex: character class
        "a|b",  # regex: alternation
        "a+",  # regex: one or more
        "a?",  # regex: zero or one
    ],
)
def test_string_containment_regex_metacharacters_17182(pattern: str) -> None:
    """
    The suggested str.contains must use literal matching, not regex.

    Evaluates the expression suggested for ``lambda x: pattern in x`` and
    checks that it yields the same frame as running the substring check
    row-by-row through ``map_elements``.
    """
    # "a0b" is a decoy row: it contains none of the patterns as a literal
    # substring, yet every pattern above matches it when interpreted as a
    # regex. Without it, "[0]", "a|b" and "a+" give identical results under
    # regex and literal matching on the other rows, so a regression to regex
    # matching would go unnoticed for those cases.
    df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello", "a0b"]})

    # What the lambda actually does (literal matching)
    result_lambda = df.select(
        pl.col("b").map_elements(
            lambda x: x.find(pattern) >= 0,  # equivalent to `pattern in x`
            return_dtype=pl.Boolean,
        )
    )

    # Get the suggested expression from BytecodeParser
    func = lambda x: pattern in x  # noqa: E731
    parser = BytecodeParser(func, map_target="expr")
    suggested = parser.to_expression("b")

    # Fail with a clear message if the parser yields no suggestion, rather
    # than letting eval() raise on a non-string argument.
    assert suggested is not None, "expected a suggestion for `pattern in x`"

    # The suggested expression should produce the same results as the lambda;
    # `pattern` is supplied because the suggestion refers to it by name.
    result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))

    assert_frame_equal(result_lambda, result_suggested)
| 705 | + |
| 706 | + |
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.parametrize(
    "pattern",
    [
        "*",  # regex: zero or more (invalid without preceding expr)
        "(",  # regex: unclosed group
        ")",  # regex: unopened group
        "[",  # regex: unclosed character class
    ],
)
def test_string_containment_invalid_regex_17182(pattern: str) -> None:
    """
    Patterns that are invalid regex but valid for Python's `in` operator.

    Evaluates the expression suggested for ``lambda x: pattern in x`` and
    checks that it runs and yields the same frame as the row-by-row
    substring check done through ``map_elements``.
    """
    df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})

    # What the lambda actually does (literal matching)
    result_lambda = df.select(
        pl.col("b").map_elements(
            lambda x: x.find(pattern) >= 0,  # equivalent to `pattern in x`
            return_dtype=pl.Boolean,
        )
    )

    # Get the suggested expression from BytecodeParser
    func = lambda x: pattern in x  # noqa: E731
    parser = BytecodeParser(func, map_target="expr")
    suggested = parser.to_expression("b")

    # Fail with a clear message if the parser yields no suggestion, rather
    # than letting eval() raise on a non-string argument.
    assert suggested is not None, "expected a suggestion for `pattern in x`"

    # The suggested expression must work without regex errors
    # and produce the same results as the lambda; `pattern` is supplied
    # because the suggestion refers to it by name.
    result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))

    assert_frame_equal(result_lambda, result_suggested)
0 commit comments