Skip to content

Commit 0e86b24

Browse files
Fix inconsistency with set equality in approx_eq (#7792)
Compare set and frozenset items in a sorted order. Fall back to exact equality for sets that cannot be sorted. Fixes #6376 --------- Co-authored-by: Pavol Juhas <juhas@google.com>
1 parent 38318ba commit 0e86b24

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

cirq-core/cirq/protocols/approximate_equality_protocol.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def approx_eq(val: Any, other: Any, *, atol: float = 1e-8) -> bool:
115115

116116
# If the values are iterable, try comparing recursively on items.
117117
if isinstance(val, Iterable) and isinstance(other, Iterable):
118-
return _approx_eq_iterables(val, other, atol=atol)
118+
result = _approx_eq_iterables(val, other, atol=atol)
119+
if result is not NotImplemented:
120+
return result
119121

120122
# Last resort: exact equality.
121123
return val == other
@@ -141,6 +143,17 @@ def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: float) -> bool
141143
types.
142144
"""
143145

146+
if isinstance(val, (set, frozenset)):
147+
try:
148+
val = sorted(val)
149+
except TypeError:
150+
return NotImplemented
151+
if isinstance(other, (set, frozenset)):
152+
try:
153+
other = sorted(other)
154+
except TypeError:
155+
return NotImplemented
156+
144157
iter1 = iter(val)
145158
iter2 = iter(other)
146159
done = object()

cirq-core/cirq/protocols/approximate_equality_protocol_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,33 @@ def test_approx_eq_list() -> None:
162162
assert not cirq.approx_eq([1.1, 1.2, 1.3], [1, 1, 1], atol=0.2)
163163

164164

165+
def test_approx_eq_set() -> None:
166+
# create two equal sets with a different order
167+
found_differently_ordered_sets = False
168+
generate_pairs = ((i, j) for i in range(20) for j in range(i + 1, 20))
169+
for i, j in generate_pairs:
170+
sij = {cirq.LineQubit(i), cirq.LineQubit(j)}
171+
sji = {cirq.LineQubit(j), cirq.LineQubit(i)}
172+
if list(sij) != list(sji):
173+
found_differently_ordered_sets = True
174+
break
175+
assert found_differently_ordered_sets, "fix code for differently ordered sets"
176+
177+
# here sij, sji are equal, but have a different order
178+
assert cirq.approx_eq(sij, sji)
179+
assert cirq.approx_eq(sij, frozenset(sji))
180+
assert cirq.approx_eq(frozenset(sij), frozenset(sji))
181+
182+
# ensure approx_eq handles non-sortable sets
183+
unsortable = {"a", 0}
184+
assert cirq.approx_eq(unsortable, unsortable)
185+
assert cirq.approx_eq(unsortable, frozenset(unsortable))
186+
assert cirq.approx_eq(frozenset(unsortable), frozenset(unsortable))
187+
assert not cirq.approx_eq(unsortable, {"a", 1})
188+
# complete coverage for only the second argument being unsortable
189+
assert not cirq.approx_eq({"a", "b"}, unsortable)
190+
191+
165192
def test_approx_eq_symbol() -> None:
166193
q = cirq.GridQubit(0, 0)
167194
s = sympy.Symbol("s")

0 commit comments

Comments
 (0)