Skip to content

Commit 932746d

Browse files
Daniel Cohenfacebook-github-bot
Daniel Cohen
authored andcommitted
Fix equality check for floats (#2507)
Summary: Pull Request resolved: #2507 `same_elements()` wasn't working for `float('nan')`, and it wasn't treating floats with `np.is_close()` like in `object_attribute_dicts_find_unequal_fields()`. Also, `same_elements` was generally broken. Example: {F1676730235} Reviewed By: saitcakmak Differential Revision: D58289519 fbshipit-source-id: 6fd2a253968763a8a52956c87cd2a97975a9755f
1 parent 40ae984 commit 932746d

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

ax/utils/common/equality.py

+33-33
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,47 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool:
4242
-- The lists do not contain duplicates
4343
4444
Checking equality is then the same as checking that the lists are the same
45-
length, and that one is a subset of the other.
45+
length, and that both are subsets of the other.
4646
"""
4747

4848
if len(list1) != len(list2):
4949
return False
5050

51+
matched = [False for _ in list2]
5152
for item1 in list1:
52-
found = False
53-
for item2 in list2:
54-
if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray):
55-
if (
56-
isinstance(item1, np.ndarray)
57-
and isinstance(item2, np.ndarray)
58-
and np.array_equal(item1, item2)
59-
):
60-
found = True
61-
break
62-
elif item1 == item2:
63-
found = True
53+
matched_this_item = False
54+
for i, item2 in enumerate(list2):
55+
if not matched[i] and is_ax_equal(item1, item2):
56+
matched[i] = True
57+
matched_this_item = True
6458
break
65-
if not found:
59+
if not matched_this_item:
6660
return False
61+
return all(matched)
6762

68-
return True
63+
64+
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
65+
def is_ax_equal(one_val: Any, other_val: Any) -> bool:
66+
"""Check for equality of two values, handling lists, dicts, dfs, floats,
67+
dates, and numpy arrays. This method and ``same_elements`` function
68+
as a recursive unit.
69+
"""
70+
if isinstance(one_val, list) and isinstance(other_val, list):
71+
return same_elements(one_val, other_val)
72+
elif isinstance(one_val, dict) and isinstance(other_val, dict):
73+
return sorted(one_val.keys()) == sorted(other_val.keys()) and same_elements(
74+
list(one_val.values()), list(other_val.values())
75+
)
76+
elif isinstance(one_val, np.ndarray) and isinstance(other_val, np.ndarray):
77+
return np.array_equal(one_val, other_val, equal_nan=True)
78+
elif isinstance(one_val, datetime):
79+
return datetime_equals(one_val, other_val)
80+
elif isinstance(one_val, float) and isinstance(other_val, float):
81+
return np.isclose(one_val, other_val, equal_nan=True)
82+
elif isinstance(one_val, pd.DataFrame) and isinstance(other_val, pd.DataFrame):
83+
return dataframe_equals(one_val, other_val)
84+
else:
85+
return one_val == other_val
6986

7087

7188
def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
@@ -198,25 +215,8 @@ def object_attribute_dicts_find_unequal_fields(
198215
and isinstance(one_val.model, type(other_val.model))
199216
)
200217

201-
elif isinstance(one_val, list):
202-
equal = isinstance(other_val, list) and same_elements(one_val, other_val)
203-
elif isinstance(one_val, dict):
204-
equal = isinstance(other_val, dict) and sorted(one_val.keys()) == sorted(
205-
other_val.keys()
206-
)
207-
equal = equal and same_elements(
208-
list(one_val.values()), list(other_val.values())
209-
)
210-
elif isinstance(one_val, np.ndarray):
211-
equal = np.array_equal(one_val, other_val, equal_nan=True)
212-
elif isinstance(one_val, datetime):
213-
equal = datetime_equals(one_val, other_val)
214-
elif isinstance(one_val, float):
215-
equal = np.isclose(one_val, other_val)
216-
elif isinstance(one_val, pd.DataFrame):
217-
equal = dataframe_equals(one_val, other_val)
218218
else:
219-
equal = one_val == other_val
219+
equal = is_ax_equal(one_val, other_val)
220220

221221
if not equal:
222222
unequal_value[field] = (one_val, other_val)

ax/utils/common/tests/test_equality.py

+7
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,15 @@ def eq(x, y):
3434
def test_ListsEquals(self) -> None:
3535
self.assertFalse(same_elements([0], [0, 1]))
3636
self.assertFalse(same_elements([1, 0], [0, 2]))
37+
self.assertFalse(same_elements([1, 1], [1, 2]))
38+
self.assertFalse(same_elements([1, 2], [1, 1]))
39+
self.assertFalse(same_elements([1, 1, 2], [1, 2, 2]))
3740
self.assertTrue(same_elements([1, 0], [0, 1]))
3841

42+
def test_ListsEquals_floats(self) -> None:
43+
self.assertTrue(same_elements([0.0], [0.000000000000001]))
44+
self.assertTrue(same_elements([float("nan")], [float("nan")]))
45+
3946
def test_DatetimeEquals(self) -> None:
4047
now = datetime.now()
4148
self.assertTrue(datetime_equals(None, None))

0 commit comments

Comments
 (0)