|
# Module-level values that the lambda strings in the test cases refer to by
# name (e.g. "lambda x: MY_CONSTANT + x", "lambda x: x in MY_DICT").
MY_CONSTANT = 3
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = [1, 2, 3]
# Operands for the string-containment cases ("x in MY_STRING",
# "MY_SUBSTRING in x"); note that MY_SUBSTRING occurs inside MY_STRING.
MY_STRING = "qwerty"
MY_SUBSTRING = "we"
# Named (non-literal) collection for the "x in MY_COLLECTION" membership case.
MY_COLLECTION = [2, 3, 4]
27 | 30 |
|
28 | 31 | # column_name, function, expected_suggestion |
29 | 32 | TEST_CASES = [ |
|
70 | 73 | ), |
71 | 74 | ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None), |
72 | 75 | ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None), |
| 76 | + ("a", "lambda x: x in MY_COLLECTION", 'pl.col("a").is_in(MY_COLLECTION)', None), |
| 77 | + ("a", "lambda x: x in MY_DICT", 'pl.col("a").is_in(MY_DICT)', None), |
| 78 | + ( |
| 79 | + "a", |
| 80 | + "lambda x: (x + 1) in (1, 2, 3)", |
| 81 | + '((pl.col("a") + 1).is_in((1, 2, 3)))', |
| 82 | + None, |
| 83 | + ), |
73 | 84 | ( |
74 | 85 | "a", |
75 | 86 | "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", |
76 | 87 | 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', |
77 | 88 | None, |
78 | 89 | ), |
| 90 | + # --------------------------------------------- |
| 91 | + # string containment with 'in' operator |
| 92 | + # --------------------------------------------- |
| 93 | + ( |
| 94 | + "b", |
| 95 | + "lambda x: x in MY_STRING", |
| 96 | + 'pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)', |
| 97 | + None, |
| 98 | + ), |
| 99 | + ( |
| 100 | + "b", |
| 101 | + "lambda x: MY_SUBSTRING in x", |
| 102 | + 'pl.col("b").str.contains(pl.lit(MY_SUBSTRING), literal=True)', |
| 103 | + None, |
| 104 | + ), |
| 105 | + ( |
| 106 | + "b", |
| 107 | + 'lambda x: "A" in x', |
| 108 | + "pl.col(\"b\").str.contains(pl.lit('A'), literal=True)", |
| 109 | + None, |
| 110 | + ), |
| 111 | + ( |
| 112 | + "b", |
| 113 | + "lambda x: x not in MY_STRING", |
| 114 | + '~pl.lit(MY_STRING).str.contains(pl.col("b"), literal=True)', |
| 115 | + None, |
| 116 | + ), |
| 117 | + ( |
| 118 | + "b", |
| 119 | + "lambda x: x in x", |
| 120 | + 'pl.col("b").str.contains(pl.col("b"), literal=True)', |
| 121 | + None, |
| 122 | + ), |
| 123 | + ( |
| 124 | + "b", |
| 125 | + 'lambda x: "test" in x', |
| 126 | + "pl.col(\"b\").str.contains(pl.lit('test'), literal=True)", |
| 127 | + None, |
| 128 | + ), |
| 129 | + ( |
| 130 | + "b", |
| 131 | + 'lambda x: x not in "hello"', |
| 132 | + "~pl.lit('hello').str.contains(pl.col(\"b\"), literal=True)", |
| 133 | + None, |
| 134 | + ), |
| 135 | + # --------------------------------------------- |
| 136 | + # constants |
| 137 | + # --------------------------------------------- |
79 | 138 | ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None), |
80 | 139 | ( |
81 | 140 | "a", |
|
313 | 372 | "MY_CONSTANT": MY_CONSTANT, |
314 | 373 | "MY_DICT": MY_DICT, |
315 | 374 | "MY_LIST": MY_LIST, |
| 375 | + "MY_STRING": MY_STRING, |
| 376 | + "MY_SUBSTRING": MY_SUBSTRING, |
| 377 | + "MY_COLLECTION": MY_COLLECTION, |
316 | 378 | "cosh": cosh, |
317 | 379 | "datetime": datetime, |
318 | 380 | "dt": dt, |
@@ -604,3 +666,51 @@ def plus(value: int, amount: int) -> int: |
604 | 666 | df = pl.DataFrame(data) |
605 | 667 | # should not warn |
606 | 668 | _ = df["a"].map_elements(partial(plus, amount=1)) |
| 669 | + |
| 670 | + |
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.parametrize("pattern", [".", "^", "$", "[0]", "a|b", "a+", "a?"])
def test_string_containment_regex_metacharacters_17182(pattern: str) -> None:
    """Check that the suggestion for ``pattern in x`` matches plain Python.

    Each ``pattern`` contains characters that have a special meaning in
    regular expressions. The expression string produced by ``BytecodeParser``
    for ``lambda x: pattern in x`` is evaluated against the same frame as a
    per-element ``str.find`` lookup, and both results must be equal.
    """
    # Rows: pattern embedded in text, text without it, the bare pattern, and
    # an unrelated string.
    df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})

    # Reference result: literal substring search done element-by-element in
    # Python. The warnings this may raise are silenced by the mark above.
    result_lambda = df.select(
        pl.col("b").map_elements(
            lambda x: x.find(pattern) >= 0,
            return_dtype=pl.Boolean,
        )
    )

    # Keep this as a lambda closing over `pattern`: the function object itself
    # is what the parser is given to translate.
    func = lambda x: pattern in x  # noqa: E731
    parser = BytecodeParser(func, map_target="expr")
    suggested = parser.to_expression("b")
    assert suggested is not None

    # `suggested` is an expression string generated from the lambda above (not
    # external input), so evaluating it here is safe. `pattern` is supplied in
    # the namespace alongside `pl` — presumably because the suggestion refers
    # to the closed-over variable by name.
    result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
    assert_frame_equal(result_lambda, result_suggested)
| 693 | + |
| 694 | + |
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.PolarsInefficientMapWarning",
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.parametrize("pattern", ["*", "(", ")", "["])
def test_string_containment_invalid_regex_17182(pattern: str) -> None:
    """Check the ``pattern in x`` suggestion for non-regex-safe patterns.

    As the test name indicates, each ``pattern`` here is a string that is not
    a valid regular expression by itself. The expression string produced by
    ``BytecodeParser`` for ``lambda x: pattern in x`` must still evaluate
    without error and give the same result as a per-element ``str.find``
    lookup.
    """
    # Rows: pattern embedded in text, text without it, the bare pattern, and
    # an unrelated string.
    df = pl.DataFrame({"b": [f"x{pattern}y", "xyz", pattern, "hello"]})

    # Reference result: literal substring search done element-by-element in
    # Python. The warnings this may raise are silenced by the mark above.
    result_lambda = df.select(
        pl.col("b").map_elements(
            lambda x: x.find(pattern) >= 0,
            return_dtype=pl.Boolean,
        )
    )

    # Keep this as a lambda closing over `pattern`: the function object itself
    # is what the parser is given to translate.
    func = lambda x: pattern in x  # noqa: E731
    parser = BytecodeParser(func, map_target="expr")
    suggested = parser.to_expression("b")
    assert suggested is not None

    # `suggested` is an expression string generated from the lambda above (not
    # external input), so evaluating it here is safe. `pattern` is supplied in
    # the namespace alongside `pl` — presumably because the suggestion refers
    # to the closed-over variable by name.
    result_suggested = df.select(eval(suggested, {"pl": pl, "pattern": pattern}))
    assert_frame_equal(result_lambda, result_suggested)
0 commit comments