@@ -42,30 +42,47 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool:
42
42
-- The lists do not contain duplicates
43
43
44
44
Checking equality is then the same as checking that the lists are the same
45
- length, and that one is a subset of the other.
45
+ length, and that both are subsets of the other.
46
46
"""
47
47
48
48
if len (list1 ) != len (list2 ):
49
49
return False
50
50
51
+ matched = [False for _ in list2 ]
51
52
for item1 in list1 :
52
- found = False
53
- for item2 in list2 :
54
- if isinstance (item1 , np .ndarray ) or isinstance (item2 , np .ndarray ):
55
- if (
56
- isinstance (item1 , np .ndarray )
57
- and isinstance (item2 , np .ndarray )
58
- and np .array_equal (item1 , item2 )
59
- ):
60
- found = True
61
- break
62
- elif item1 == item2 :
63
- found = True
53
+ matched_this_item = False
54
+ for i , item2 in enumerate (list2 ):
55
+ if not matched [i ] and is_ax_equal (item1 , item2 ):
56
+ matched [i ] = True
57
+ matched_this_item = True
64
58
break
65
- if not found :
59
+ if not matched_this_item :
66
60
return False
61
+ return all (matched )
67
62
68
- return True
63
+
64
+ # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
65
+ def is_ax_equal (one_val : Any , other_val : Any ) -> bool :
66
+ """Check for equality of two values, handling lists, dicts, dfs, floats,
67
+ dates, and numpy arrays. This method and ``same_elements`` function
68
+ as a recursive unit.
69
+ """
70
+ if isinstance (one_val , list ) and isinstance (other_val , list ):
71
+ return same_elements (one_val , other_val )
72
+ elif isinstance (one_val , dict ) and isinstance (other_val , dict ):
73
+ return sorted (one_val .keys ()) == sorted (other_val .keys ()) and same_elements (
74
+ list (one_val .values ()), list (other_val .values ())
75
+ )
76
+ elif isinstance (one_val , np .ndarray ) and isinstance (other_val , np .ndarray ):
77
+ return np .array_equal (one_val , other_val , equal_nan = True )
78
+ elif isinstance (one_val , datetime ):
79
+ return datetime_equals (one_val , other_val )
80
+ elif isinstance (one_val , float ) and isinstance (other_val , float ):
81
+ return np .isclose (one_val , other_val , equal_nan = True )
82
+ elif isinstance (one_val , pd .DataFrame ) and isinstance (other_val , pd .DataFrame ):
83
+ return dataframe_equals (one_val , other_val )
84
+ else :
85
+ return one_val == other_val
69
86
70
87
71
88
def datetime_equals (dt1 : Optional [datetime ], dt2 : Optional [datetime ]) -> bool :
@@ -198,25 +215,8 @@ def object_attribute_dicts_find_unequal_fields(
198
215
and isinstance (one_val .model , type (other_val .model ))
199
216
)
200
217
201
- elif isinstance (one_val , list ):
202
- equal = isinstance (other_val , list ) and same_elements (one_val , other_val )
203
- elif isinstance (one_val , dict ):
204
- equal = isinstance (other_val , dict ) and sorted (one_val .keys ()) == sorted (
205
- other_val .keys ()
206
- )
207
- equal = equal and same_elements (
208
- list (one_val .values ()), list (other_val .values ())
209
- )
210
- elif isinstance (one_val , np .ndarray ):
211
- equal = np .array_equal (one_val , other_val , equal_nan = True )
212
- elif isinstance (one_val , datetime ):
213
- equal = datetime_equals (one_val , other_val )
214
- elif isinstance (one_val , float ):
215
- equal = np .isclose (one_val , other_val )
216
- elif isinstance (one_val , pd .DataFrame ):
217
- equal = dataframe_equals (one_val , other_val )
218
218
else :
219
- equal = one_val == other_val
219
+ equal = is_ax_equal ( one_val , other_val )
220
220
221
221
if not equal :
222
222
unequal_value [field ] = (one_val , other_val )
0 commit comments