11from typing import Any , Optional , Union
2+ import re
23
34import pandas as pd
45
@@ -12,6 +13,26 @@ def _cudf_mask_none(result: Any, mask: Any) -> Any:
1213 return result_pd
1314
1415
16+ def _pandas_handle_na (
17+ result : pd .Series ,
18+ source : pd .Series ,
19+ na : Optional [bool ]
20+ ) -> pd .Series :
21+ mask = source .isna ()
22+ if na is None :
23+ if mask .any ():
24+ result = result .astype ('object' )
25+ result [mask ] = None
26+ return result
27+
28+ if mask .any ():
29+ result = result .copy ()
30+ result [mask ] = na
31+ if result .dtype == object :
32+ result = result .infer_objects (copy = False )
33+ return result
34+
35+
1536class Contains (ASTPredicate ):
1637 def __init__ (
1738 self ,
@@ -58,13 +79,14 @@ def __call__(self, s: SeriesT) -> SeriesT:
5879
5980 return result
6081 else :
61- return s .str .contains (
82+ result = s .str .contains (
6283 self .pat ,
63- self .case ,
64- self .flags ,
65- self .na ,
66- self .regex
84+ case = self .case ,
85+ flags = self .flags ,
86+ na = self .na ,
87+ regex = self .regex
6788 )
89+ return _pandas_handle_na (result , s , self .na )
6890
6991 def _validate_fields (self ) -> None :
7092 """Validate predicate fields."""
@@ -153,9 +175,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
153175 if not is_cudf and self .case :
154176 # Use pandas native tuple support for case-sensitive
155177 result = s .str .startswith (self .pat )
156- if self .na is not None :
157- return result .fillna (self .na )
158- return result
178+ return _pandas_handle_na (result , s , self .na )
159179 elif not is_cudf and not self .case :
160180 # pandas tuple with case-insensitive - need workaround
161181 if len (self .pat ) == 0 :
@@ -169,9 +189,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
169189 patterns_lower = tuple (p .lower () for p in self .pat )
170190 # Use pandas native tuple support on lowercased data
171191 result = s_lower .str .startswith (patterns_lower )
172- if self .na is not None :
173- return result .fillna (self .na )
174- return result
192+ return _pandas_handle_na (result , s , self .na )
175193 else :
176194 # cuDF - need manual OR logic (workaround for bug #20237)
177195 if len (self .pat ) == 0 :
@@ -217,14 +235,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
217235 else :
218236 return result
219237 else :
220- # pandas supports na parameter for case-sensitive str patterns
221- if not self .case :
222- if self .na is not None :
223- return result .fillna (self .na )
224- else :
225- return result
226- else :
227- return s .str .startswith (self .pat , self .na )
238+ return _pandas_handle_na (result , s , self .na )
228239
229240 def _validate_fields (self ) -> None :
230241 """Validate predicate fields."""
@@ -319,9 +330,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
319330 if not is_cudf and self .case :
320331 # Use pandas native tuple support for case-sensitive
321332 result = s .str .endswith (self .pat )
322- if self .na is not None :
323- return result .fillna (self .na )
324- return result
333+ return _pandas_handle_na (result , s , self .na )
325334 elif not is_cudf and not self .case :
326335 # pandas tuple with case-insensitive - need workaround
327336 if len (self .pat ) == 0 :
@@ -336,9 +345,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
336345 patterns_lower = tuple (p .lower () for p in self .pat )
337346 # Use pandas native tuple support on lowercased data
338347 result = s_lower .str .endswith (patterns_lower )
339- if self .na is not None :
340- return result .fillna (self .na )
341- return result
348+ return _pandas_handle_na (result , s , self .na )
342349 else :
343350 # cuDF - need manual OR logic (workaround for bug #20237)
344351 if len (self .pat ) == 0 :
@@ -384,14 +391,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
384391 else :
385392 return result
386393 else :
387- # pandas supports na parameter for case-sensitive str patterns
388- if not self .case :
389- if self .na is not None :
390- return result .fillna (self .na )
391- else :
392- return result
393- else :
394- return s .str .endswith (self .pat , self .na )
394+ return _pandas_handle_na (result , s , self .na )
395395
396396 def _validate_fields (self ) -> None :
397397 """Validate predicate fields."""
@@ -493,7 +493,18 @@ def __call__(self, s: SeriesT) -> SeriesT:
493493
494494 return result
495495 else :
496- return s .str .match (self .pat , self .case , self .flags , self .na )
496+ if self .flags :
497+ effective_flags = self .flags
498+ if not self .case :
499+ effective_flags |= re .IGNORECASE
500+ pattern = re .compile (self .pat , effective_flags )
501+ result = s .str .match (pattern , na = self .na )
502+ else :
503+ if not self .case :
504+ result = s .str .match (self .pat , case = False , na = self .na )
505+ else :
506+ result = s .str .match (self .pat , na = self .na )
507+ return _pandas_handle_na (result , s , self .na )
497508
498509 def _validate_fields (self ) -> None :
499510 """Validate predicate fields."""
@@ -582,7 +593,13 @@ def __call__(self, s: SeriesT) -> SeriesT:
582593 return result
583594 else :
584595 # pandas has native fullmatch support
585- return s .str .fullmatch (self .pat , self .case , self .flags , self .na )
596+ result = s .str .fullmatch (
597+ self .pat ,
598+ case = self .case ,
599+ flags = self .flags ,
600+ na = self .na
601+ )
602+ return _pandas_handle_na (result , s , self .na )
586603
587604 def _validate_fields (self ) -> None :
588605 """Validate predicate fields."""
0 commit comments