Skip to content

Commit 86cd0ab

Browse files
authored
feat: Support other Expr/Series in str.contains for polars and SQL-like backends (#3473)
1 parent 228fee8 commit 86cd0ab

11 files changed

Lines changed: 219 additions & 75 deletions

File tree

narwhals/_arrow/series_str.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ def ends_with(self, suffix: str) -> ArrowSeries:
5454
pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
5555
)
5656

57-
def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
58-
check_func = pc.match_substring if literal else pc.match_substring_regex
59-
return self.with_native(check_func(self.native, pattern))
57+
def contains(self, pattern: ArrowSeries, *, literal: bool) -> ArrowSeries:
58+
_, pattern_native = extract_native(self.compliant, pattern)
59+
if not isinstance(pattern_native, pa.StringScalar):
60+
msg = "`.str.contains` only supports str pattern values for pyarrow backend"
61+
raise TypeError(msg)
62+
fn = pc.match_substring if literal else pc.match_substring_regex
63+
return self.with_native(fn(self.native, pattern_native.as_py()))
6064

6165
def slice(self, offset: int, length: int | None) -> ArrowSeries:
6266
stop = offset + length if length is not None else None

narwhals/_compliant/any_namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def replace_all(self, value: T, pattern: str, *, literal: bool) -> T: ...
9898
def strip_chars(self, characters: str | None) -> T: ...
9999
def starts_with(self, prefix: str) -> T: ...
100100
def ends_with(self, suffix: str) -> T: ...
101-
def contains(self, pattern: str, *, literal: bool) -> T: ...
101+
def contains(self, pattern: T, *, literal: bool) -> T: ...
102102
def slice(self, offset: int, length: int | None) -> T: ...
103103
def split(self, by: str) -> T: ...
104104
def to_datetime(self, format: str | None) -> T: ...

narwhals/_compliant/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ def starts_with(self, prefix: str) -> EagerExprT:
11351135
def ends_with(self, suffix: str) -> EagerExprT:
11361136
return self.compliant._reuse_series_namespace("str", "ends_with", suffix=suffix)
11371137

1138-
def contains(self, pattern: str, *, literal: bool) -> EagerExprT:
1138+
def contains(self, pattern: EagerExprT, *, literal: bool) -> EagerExprT:
11391139
return self.compliant._reuse_series_namespace(
11401140
"str", "contains", pattern=pattern, literal=literal
11411141
)

narwhals/_dask/expr_str.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,18 @@ def starts_with(self, prefix: str) -> DaskExpr:
4545
def ends_with(self, suffix: str) -> DaskExpr:
4646
return self.compliant._with_callable(lambda expr: expr.str.endswith(suffix))
4747

48-
def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
49-
return self.compliant._with_callable(
50-
lambda expr: expr.str.contains(pat=pattern, regex=not literal)
51-
)
48+
def contains(self, pattern: DaskExpr, *, literal: bool) -> DaskExpr:
49+
if not pattern._metadata.is_literal:
50+
msg = "dask backed `Expr.str.contains` only supports str replacement values"
51+
raise TypeError(msg)
52+
53+
def _contains(expr: dx.Series, pattern: dx.Series) -> dx.Series:
54+
# OK to call `compute` here as `pattern` is just a literal expression.
55+
return expr.str.contains( # pyright: ignore[reportAttributeAccessIssue]
56+
pat=pattern.compute(), regex=not literal
57+
)
58+
59+
return self.compliant._with_callable(_contains, pattern=pattern)
5260

5361
def slice(self, offset: int, length: int | None) -> DaskExpr:
5462
return self.compliant._with_callable(

narwhals/_pandas_like/series_str.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,14 @@ def starts_with(self, prefix: str) -> PandasLikeSeries:
4545
def ends_with(self, suffix: str) -> PandasLikeSeries:
4646
return self.with_native(self.native.str.endswith(suffix))
4747

48-
def contains(self, pattern: str, *, literal: bool) -> PandasLikeSeries:
49-
return self.with_native(self.native.str.contains(pat=pattern, regex=not literal))
48+
def contains(self, pattern: PandasLikeSeries, *, literal: bool) -> PandasLikeSeries:
49+
_, pattern_native = align_and_extract_native(self.compliant, pattern)
50+
if not isinstance(pattern_native, str):
51+
msg = f"`.str.contains` only supports str pattern values for {self.compliant._implementation} backend"
52+
raise TypeError(msg)
53+
return self.with_native(
54+
self.native.str.contains(pat=pattern_native, regex=not literal)
55+
)
5056

5157
def slice(self, offset: int, length: int | None) -> PandasLikeSeries:
5258
stop = offset + length if length else None

narwhals/_polars/expr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,12 @@ def replace_all(
471471
self.native.str.replace_all(pattern, value_native, literal=literal)
472472
)
473473

474+
def contains(self, pattern: PolarsExpr, *, literal: bool) -> PolarsExpr:
475+
pattern_native = extract_native(pattern)
476+
return self.compliant._with_native(
477+
self.native.str.contains(pattern_native, literal=literal)
478+
)
479+
474480

475481
class PolarsExprCatNamespace(
476482
PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]

narwhals/_polars/series.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,12 @@ def replace_all(
824824
self.native.str.replace_all(pattern, value_native, literal=literal) # type: ignore[arg-type]
825825
)
826826

827+
def contains(self, pattern: PolarsSeries, *, literal: bool) -> PolarsSeries:
828+
pattern_native = extract_native(pattern)
829+
return self.compliant._with_native(
830+
self.native.str.contains(pattern_native, literal=literal) # type: ignore[arg-type]
831+
)
832+
827833

828834
class PolarsSeriesCatNamespace(
829835
PolarsSeriesNamespace, PolarsCatNamespace[PolarsSeries, pl.Series]

narwhals/_sql/expr_str.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@ def _when(
2626
) -> NativeExpr:
2727
return self.compliant._when(condition, value, otherwise) # type: ignore[no-any-return]
2828

29-
def contains(self, pattern: str, *, literal: bool) -> SQLExprT:
30-
def func(expr: NativeExpr) -> NativeExpr:
31-
if literal:
32-
return self._function("contains", expr, self._lit(pattern))
33-
return self._function("regexp_matches", expr, self._lit(pattern))
29+
def contains(self, pattern: SQLExprT, *, literal: bool) -> SQLExprT:
3430

35-
return self.compliant._with_elementwise(func)
31+
def func(expr: NativeExpr, pattern: NativeExpr) -> NativeExpr:
32+
func_name = "contains" if literal else "regexp_matches"
33+
return self._function(func_name, expr, pattern)
34+
35+
compliant_pattern = (
36+
self.compliant.__narwhals_namespace__().lit(pattern, dtype=None)
37+
if isinstance(pattern, str)
38+
else pattern
39+
)
40+
return self.compliant._with_elementwise(func, pattern=compliant_pattern)
3641

3742
def ends_with(self, suffix: str) -> SQLExprT:
3843
return self.compliant._with_elementwise(

narwhals/expr_str.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,19 @@ def ends_with(self, suffix: str) -> ExprT:
186186
ExprNode(ExprKind.ELEMENTWISE, "str.ends_with", suffix=suffix)
187187
)
188188

189-
def contains(self, pattern: str, *, literal: bool = False) -> ExprT:
189+
def contains(self, pattern: str | IntoExpr, *, literal: bool = False) -> ExprT:
190190
r"""Check if string contains a substring that matches a pattern.
191191
192192
Arguments:
193-
pattern: A Character sequence or valid regular expression pattern.
193+
pattern: A Character sequence, valid regular expression pattern, or another
194+
Expr.
194195
literal: If True, treats the pattern as a literal string.
195-
If False, assumes the pattern is a regular expression.
196+
If False, assumes the pattern is a regular expression.
197+
198+
Warning:
199+
Passing an expression as `pattern` is only supported by DuckDB, Ibis, Polars,
200+
PySpark and SQLFrame. Other backends, such as pandas and PyArrow, will raise
201+
a `TypeError`.
196202
197203
Examples:
198204
>>> import pyarrow as pa
@@ -214,7 +220,11 @@ def contains(self, pattern: str, *, literal: bool = False) -> ExprT:
214220
"""
215221
return self._expr._append_node(
216222
ExprNode(
217-
ExprKind.ELEMENTWISE, "str.contains", pattern=pattern, literal=literal
223+
ExprKind.ELEMENTWISE,
224+
"str.contains",
225+
exprs=(pattern,),
226+
literal=literal,
227+
str_as_lit=True,
218228
)
219229
)
220230

narwhals/series_str.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,18 @@ def ends_with(self, suffix: str) -> SeriesT:
152152
self._narwhals_series._compliant_series.str.ends_with(suffix)
153153
)
154154

155-
def contains(self, pattern: str, *, literal: bool = False) -> SeriesT:
155+
def contains(self, pattern: str | SeriesT, *, literal: bool = False) -> SeriesT:
156156
r"""Check if string contains a substring that matches a pattern.
157157
158158
Arguments:
159-
pattern: A Character sequence or valid regular expression pattern.
159+
pattern: A Character sequence, valid regular expression pattern, or another
160+
Series.
160161
literal: If True, treats the pattern as a literal string.
161-
If False, assumes the pattern is a regular expression.
162+
If False, assumes the pattern is a regular expression.
163+
164+
Warning:
165+
Passing a Series as `pattern` is only supported by Polars. Other backends
166+
will raise a `TypeError`.
162167
163168
Examples:
164169
>>> import pyarrow as pa
@@ -176,7 +181,9 @@ def contains(self, pattern: str, *, literal: bool = False) -> SeriesT:
176181
]
177182
"""
178183
return self._narwhals_series._with_compliant(
179-
self._narwhals_series._compliant_series.str.contains(pattern, literal=literal)
184+
self._narwhals_series._compliant_series.str.contains(
185+
self._extract_compliant(pattern), literal=literal
186+
)
180187
)
181188

182189
def slice(self, offset: int, length: int | None = None) -> SeriesT:

0 commit comments

Comments
 (0)