Skip to content

Commit c650e4c

Browse files
committed
fix: handle pandas 3 string predicates
1 parent e6dafec commit c650e4c

1 file changed

Lines changed: 52 additions & 35 deletions

File tree

  • graphistry/compute/predicates

graphistry/compute/predicates/str.py

Lines changed: 52 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Optional, Union
2+
import re
23

34
import pandas as pd
45

@@ -12,6 +13,26 @@ def _cudf_mask_none(result: Any, mask: Any) -> Any:
1213
return result_pd
1314

1415

16+
def _pandas_handle_na(
17+
result: pd.Series,
18+
source: pd.Series,
19+
na: Optional[bool]
20+
) -> pd.Series:
21+
mask = source.isna()
22+
if na is None:
23+
if mask.any():
24+
result = result.astype('object')
25+
result[mask] = None
26+
return result
27+
28+
if mask.any():
29+
result = result.copy()
30+
result[mask] = na
31+
if result.dtype == object:
32+
result = result.infer_objects(copy=False)
33+
return result
34+
35+
1536
class Contains(ASTPredicate):
1637
def __init__(
1738
self,
@@ -58,13 +79,14 @@ def __call__(self, s: SeriesT) -> SeriesT:
5879

5980
return result
6081
else:
61-
return s.str.contains(
82+
result = s.str.contains(
6283
self.pat,
63-
self.case,
64-
self.flags,
65-
self.na,
66-
self.regex
84+
case=self.case,
85+
flags=self.flags,
86+
na=self.na,
87+
regex=self.regex
6788
)
89+
return _pandas_handle_na(result, s, self.na)
6890

6991
def _validate_fields(self) -> None:
7092
"""Validate predicate fields."""
@@ -153,9 +175,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
153175
if not is_cudf and self.case:
154176
# Use pandas native tuple support for case-sensitive
155177
result = s.str.startswith(self.pat)
156-
if self.na is not None:
157-
return result.fillna(self.na)
158-
return result
178+
return _pandas_handle_na(result, s, self.na)
159179
elif not is_cudf and not self.case:
160180
# pandas tuple with case-insensitive - need workaround
161181
if len(self.pat) == 0:
@@ -169,9 +189,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
169189
patterns_lower = tuple(p.lower() for p in self.pat)
170190
# Use pandas native tuple support on lowercased data
171191
result = s_lower.str.startswith(patterns_lower)
172-
if self.na is not None:
173-
return result.fillna(self.na)
174-
return result
192+
return _pandas_handle_na(result, s, self.na)
175193
else:
176194
# cuDF - need manual OR logic (workaround for bug #20237)
177195
if len(self.pat) == 0:
@@ -217,14 +235,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
217235
else:
218236
return result
219237
else:
220-
# pandas supports na parameter for case-sensitive str patterns
221-
if not self.case:
222-
if self.na is not None:
223-
return result.fillna(self.na)
224-
else:
225-
return result
226-
else:
227-
return s.str.startswith(self.pat, self.na)
238+
return _pandas_handle_na(result, s, self.na)
228239

229240
def _validate_fields(self) -> None:
230241
"""Validate predicate fields."""
@@ -319,9 +330,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
319330
if not is_cudf and self.case:
320331
# Use pandas native tuple support for case-sensitive
321332
result = s.str.endswith(self.pat)
322-
if self.na is not None:
323-
return result.fillna(self.na)
324-
return result
333+
return _pandas_handle_na(result, s, self.na)
325334
elif not is_cudf and not self.case:
326335
# pandas tuple with case-insensitive - need workaround
327336
if len(self.pat) == 0:
@@ -336,9 +345,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
336345
patterns_lower = tuple(p.lower() for p in self.pat)
337346
# Use pandas native tuple support on lowercased data
338347
result = s_lower.str.endswith(patterns_lower)
339-
if self.na is not None:
340-
return result.fillna(self.na)
341-
return result
348+
return _pandas_handle_na(result, s, self.na)
342349
else:
343350
# cuDF - need manual OR logic (workaround for bug #20237)
344351
if len(self.pat) == 0:
@@ -384,14 +391,7 @@ def __call__(self, s: SeriesT) -> SeriesT:
384391
else:
385392
return result
386393
else:
387-
# pandas supports na parameter for case-sensitive str patterns
388-
if not self.case:
389-
if self.na is not None:
390-
return result.fillna(self.na)
391-
else:
392-
return result
393-
else:
394-
return s.str.endswith(self.pat, self.na)
394+
return _pandas_handle_na(result, s, self.na)
395395

396396
def _validate_fields(self) -> None:
397397
"""Validate predicate fields."""
@@ -493,7 +493,18 @@ def __call__(self, s: SeriesT) -> SeriesT:
493493

494494
return result
495495
else:
496-
return s.str.match(self.pat, self.case, self.flags, self.na)
496+
if self.flags:
497+
effective_flags = self.flags
498+
if not self.case:
499+
effective_flags |= re.IGNORECASE
500+
pattern = re.compile(self.pat, effective_flags)
501+
result = s.str.match(pattern, na=self.na)
502+
else:
503+
if not self.case:
504+
result = s.str.match(self.pat, case=False, na=self.na)
505+
else:
506+
result = s.str.match(self.pat, na=self.na)
507+
return _pandas_handle_na(result, s, self.na)
497508

498509
def _validate_fields(self) -> None:
499510
"""Validate predicate fields."""
@@ -582,7 +593,13 @@ def __call__(self, s: SeriesT) -> SeriesT:
582593
return result
583594
else:
584595
# pandas has native fullmatch support
585-
return s.str.fullmatch(self.pat, self.case, self.flags, self.na)
596+
result = s.str.fullmatch(
597+
self.pat,
598+
case=self.case,
599+
flags=self.flags,
600+
na=self.na
601+
)
602+
return _pandas_handle_na(result, s, self.na)
586603

587604
def _validate_fields(self) -> None:
588605
"""Validate predicate fields."""

0 commit comments

Comments
 (0)