@@ -45,6 +45,14 @@ def cls(dtype):
45
45
return dtype .construct_array_type ()
46
46
47
47
48
+ DTYPE_HIERARCHY = [
49
+ pd .StringDtype ("python" , na_value = np .nan ),
50
+ pd .StringDtype ("pyarrow" , na_value = np .nan ),
51
+ pd .StringDtype ("python" , na_value = pd .NA ),
52
+ pd .StringDtype ("pyarrow" , na_value = pd .NA ),
53
+ ]
54
+
55
+
48
56
def test_dtype_constructor ():
49
57
pytest .importorskip ("pyarrow" )
50
58
@@ -319,37 +327,43 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
319
327
tm .assert_extension_array_equal (result , expected )
320
328
321
329
322
- def test_comparison_methods_array (comparison_op , dtype ):
330
+ def test_comparison_methods_array (comparison_op , dtype , dtype2 ):
323
331
op_name = f"__{ comparison_op .__name__ } __"
324
332
325
333
a = pd .array (["a" , None , "c" ], dtype = dtype )
326
- other = [None , None , "c" ]
327
- result = getattr (a , op_name )(other )
328
- if dtype .na_value is np .nan :
334
+ other = pd .array ([None , None , "c" ], dtype = dtype2 )
335
+ result = comparison_op (a , other )
336
+
337
+ # ensure operation is commutative
338
+ result2 = comparison_op (other , a )
339
+ tm .assert_equal (result , result2 )
340
+
341
+ if dtype .na_value is np .nan and dtype2 .na_value is np .nan :
329
342
if operator .ne == comparison_op :
330
343
expected = np .array ([True , True , False ])
331
344
else :
332
345
expected = np .array ([False , False , False ])
333
346
expected [- 1 ] = getattr (other [- 1 ], op_name )(a [- 1 ])
334
347
tm .assert_numpy_array_equal (result , expected )
335
348
336
- result = getattr (a , op_name )(pd .NA )
337
- if operator .ne == comparison_op :
338
- expected = np .array ([True , True , True ])
349
+ else :
350
+ h1 = DTYPE_HIERARCHY .index (dtype )
351
+ h2 = DTYPE_HIERARCHY .index (dtype2 )
352
+ max_dtype = DTYPE_HIERARCHY [max (h1 , h2 )]
353
+ if max_dtype .storage == "python" :
354
+ expected_dtype = "boolean"
339
355
else :
340
- expected = np .array ([False , False , False ])
341
- tm .assert_numpy_array_equal (result , expected )
356
+ expected_dtype = "bool[pyarrow]"
342
357
343
- else :
344
- expected_dtype = "boolean[pyarrow]" if dtype .storage == "pyarrow" else "boolean"
345
358
expected = np .full (len (a ), fill_value = None , dtype = "object" )
346
359
expected [- 1 ] = getattr (other [- 1 ], op_name )(a [- 1 ])
347
360
expected = pd .array (expected , dtype = expected_dtype )
348
361
tm .assert_extension_array_equal (result , expected )
349
362
350
- result = getattr (a , op_name )(pd .NA )
351
- expected = pd .array ([None , None , None ], dtype = expected_dtype )
352
- tm .assert_extension_array_equal (result , expected )
363
+ # # with list
364
+ # other = [None, None, "c"]
365
+ # result3 = getattr(a, op_name)(other)
366
+ # tm.assert_equal(result, result3)
353
367
354
368
355
369
def test_constructor_raises (cls ):
0 commit comments