Skip to content

Commit 3c4d782

Browse files
API (string dtype): implement hierarchy (NA > NaN, pyarrow > python) for consistent comparisons between different string dtypes
1 parent 8943c97 commit 3c4d782

File tree

4 files changed

+61
-19
lines changed

4 files changed

+61
-19
lines changed

pandas/core/arrays/arrow/array.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
infer_dtype_from_scalar,
3434
)
3535
from pandas.core.dtypes.common import (
36-
CategoricalDtype,
3736
is_array_like,
3837
is_bool_dtype,
3938
is_float_dtype,
@@ -730,9 +729,7 @@ def __setstate__(self, state) -> None:
730729

731730
def _cmp_method(self, other, op) -> ArrowExtensionArray:
732731
pc_func = ARROW_CMP_FUNCS[op.__name__]
733-
if isinstance(
734-
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
735-
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
732+
if isinstance(other, (ExtensionArray, np.ndarray, list)):
736733
try:
737734
result = pc_func(self._pa_array, self._box_pa(other))
738735
except pa.ArrowNotImplementedError:

pandas/core/arrays/string_.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,30 @@ def searchsorted(
10181018
return super().searchsorted(value=value, side=side, sorter=sorter)
10191019

10201020
def _cmp_method(self, other, op):
1021-
from pandas.arrays import BooleanArray
1021+
from pandas.arrays import (
1022+
ArrowExtensionArray,
1023+
BooleanArray,
1024+
)
1025+
1026+
if (
1027+
isinstance(other, BaseStringArray)
1028+
and self.dtype.na_value is not libmissing.NA
1029+
and other.dtype.na_value is libmissing.NA
1030+
):
1031+
# NA has priority of NaN semantics
1032+
return NotImplemented
1033+
1034+
if isinstance(other, ArrowExtensionArray):
1035+
if isinstance(other, BaseStringArray):
1036+
# pyarrow storage has priority over python storage
1037+
# (except if we have NA semantics and other not)
1038+
if not (
1039+
self.dtype.na_value is libmissing.NA
1040+
and other.dtype.na_value is not libmissing.NA
1041+
):
1042+
return NotImplemented
1043+
else:
1044+
return NotImplemented
10221045

10231046
if isinstance(other, StringArray):
10241047
other = other._ndarray

pandas/core/arrays/string_arrow.py

+8
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,14 @@ def value_counts(self, dropna: bool = True) -> Series:
473473
return result
474474

475475
def _cmp_method(self, other, op):
476+
if (
477+
isinstance(other, BaseStringArray)
478+
and self.dtype.na_value is not libmissing.NA
479+
and other.dtype.na_value is libmissing.NA
480+
):
481+
# NA has priority of NaN semantics
482+
return NotImplemented
483+
476484
result = super()._cmp_method(other, op)
477485
if self.dtype.na_value is np.nan:
478486
if op == operator.ne:

pandas/tests/arrays/string_/test_string.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def cls(dtype):
4545
return dtype.construct_array_type()
4646

4747

48+
DTYPE_HIERARCHY = [
49+
pd.StringDtype("python", na_value=np.nan),
50+
pd.StringDtype("pyarrow", na_value=np.nan),
51+
pd.StringDtype("python", na_value=pd.NA),
52+
pd.StringDtype("pyarrow", na_value=pd.NA),
53+
]
54+
55+
4856
def test_dtype_constructor():
4957
pytest.importorskip("pyarrow")
5058

@@ -319,37 +327,43 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
319327
tm.assert_extension_array_equal(result, expected)
320328

321329

322-
def test_comparison_methods_array(comparison_op, dtype):
330+
def test_comparison_methods_array(comparison_op, dtype, dtype2):
323331
op_name = f"__{comparison_op.__name__}__"
324332

325333
a = pd.array(["a", None, "c"], dtype=dtype)
326-
other = [None, None, "c"]
327-
result = getattr(a, op_name)(other)
328-
if dtype.na_value is np.nan:
334+
other = pd.array([None, None, "c"], dtype=dtype2)
335+
result = comparison_op(a, other)
336+
337+
# ensure operation is commutative
338+
result2 = comparison_op(other, a)
339+
tm.assert_equal(result, result2)
340+
341+
if dtype.na_value is np.nan and dtype2.na_value is np.nan:
329342
if operator.ne == comparison_op:
330343
expected = np.array([True, True, False])
331344
else:
332345
expected = np.array([False, False, False])
333346
expected[-1] = getattr(other[-1], op_name)(a[-1])
334347
tm.assert_numpy_array_equal(result, expected)
335348

336-
result = getattr(a, op_name)(pd.NA)
337-
if operator.ne == comparison_op:
338-
expected = np.array([True, True, True])
349+
else:
350+
h1 = DTYPE_HIERARCHY.index(dtype)
351+
h2 = DTYPE_HIERARCHY.index(dtype2)
352+
max_dtype = DTYPE_HIERARCHY[max(h1, h2)]
353+
if max_dtype.storage == "python":
354+
expected_dtype = "boolean"
339355
else:
340-
expected = np.array([False, False, False])
341-
tm.assert_numpy_array_equal(result, expected)
356+
expected_dtype = "bool[pyarrow]"
342357

343-
else:
344-
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
345358
expected = np.full(len(a), fill_value=None, dtype="object")
346359
expected[-1] = getattr(other[-1], op_name)(a[-1])
347360
expected = pd.array(expected, dtype=expected_dtype)
348361
tm.assert_extension_array_equal(result, expected)
349362

350-
result = getattr(a, op_name)(pd.NA)
351-
expected = pd.array([None, None, None], dtype=expected_dtype)
352-
tm.assert_extension_array_equal(result, expected)
363+
# # with list
364+
# other = [None, None, "c"]
365+
# result3 = getattr(a, op_name)(other)
366+
# tm.assert_equal(result, result3)
353367

354368

355369
def test_constructor_raises(cls):

0 commit comments

Comments
 (0)