Skip to content

Commit 7ffb08f

Browse files
fix string arith tests
1 parent 3c4d782 commit 7ffb08f

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

pandas/tests/arrays/string_/test_string.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ def cls(dtype):
5353
]
5454

5555

56+
def string_dtype_highest_priority(dtype1, dtype2):
57+
h1 = DTYPE_HIERARCHY.index(dtype1)
58+
h2 = DTYPE_HIERARCHY.index(dtype2)
59+
return DTYPE_HIERARCHY[max(h1, h2)]
60+
61+
5662
def test_dtype_constructor():
5763
pytest.importorskip("pyarrow")
5864

@@ -347,9 +353,7 @@ def test_comparison_methods_array(comparison_op, dtype, dtype2):
347353
tm.assert_numpy_array_equal(result, expected)
348354

349355
else:
350-
h1 = DTYPE_HIERARCHY.index(dtype)
351-
h2 = DTYPE_HIERARCHY.index(dtype2)
352-
max_dtype = DTYPE_HIERARCHY[max(h1, h2)]
356+
max_dtype = string_dtype_highest_priority(dtype, dtype2)
353357
if max_dtype.storage == "python":
354358
expected_dtype = "boolean"
355359
else:

pandas/tests/extension/test_string.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,14 @@
2222
import numpy as np
2323
import pytest
2424

25-
from pandas.compat import HAS_PYARROW
26-
2725
from pandas.core.dtypes.base import StorageExtensionDtype
2826

2927
import pandas as pd
3028
import pandas._testing as tm
3129
from pandas.api.types import is_string_dtype
3230
from pandas.core.arrays import ArrowStringArray
3331
from pandas.core.arrays.string_ import StringDtype
32+
from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority
3433
from pandas.tests.extension import base
3534

3635

@@ -202,10 +201,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
202201
dtype = cast(StringDtype, tm.get_dtype(obj))
203202
if op_name in ["__add__", "__radd__"]:
204203
cast_to = dtype
204+
dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None
205+
if isinstance(dtype_other, StringDtype):
206+
cast_to = string_dtype_highest_priority(dtype, dtype_other)
205207
elif dtype.na_value is np.nan:
206208
cast_to = np.bool_ # type: ignore[assignment]
207209
elif dtype.storage == "pyarrow":
208-
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
210+
cast_to = "bool[pyarrow]" # type: ignore[assignment]
209211
else:
210212
cast_to = "boolean" # type: ignore[assignment]
211213
return pointwise_result.astype(cast_to)
@@ -236,9 +238,7 @@ def test_arith_series_with_array(
236238
if (
237239
using_infer_string
238240
and all_arithmetic_operators == "__radd__"
239-
and (
240-
(dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW)
241-
)
241+
and dtype.na_value is pd.NA
242242
):
243243
mark = pytest.mark.xfail(
244244
reason="The pointwise operation result will be inferred to "

0 commit comments

Comments
 (0)