Skip to content

Commit 9b89f3c

Browse files
codrut3pavoljuhas
andauthored
Fix measurement and control key commutativity bug in insertion_sort_transformer. (#7822)
Check measurement and control key commutativity before swapping two operations. Fixes #7472 --------- Co-authored-by: Pavol Juhas <juhas@google.com>
1 parent 0e86b24 commit 9b89f3c

File tree

2 files changed

+81
-17
lines changed

2 files changed

+81
-17
lines changed

cirq-core/cirq/transformers/insertion_sort.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,41 @@ def insertion_sort_transformer(
4242
q: idx for idx, q in enumerate(sorted(circuit.all_qubits()))
4343
}
4444
cached_qubit_indices: dict[int, list[int]] = {}
45+
cached_measurement_keys: dict[int, frozenset[cirq.MeasurementKey]] = {}
46+
cached_control_keys: dict[int, frozenset[cirq.MeasurementKey]] = {}
4547
for pos, op in enumerate(circuit.all_operations()):
4648
# here `pos` is at the append position of final_operations
47-
if (op_qubit_indices := cached_qubit_indices.get(id(op))) is None:
48-
op_qubit_indices = cached_qubit_indices[id(op)] = sorted(
49+
op_id = id(op)
50+
if (op_qubit_indices := cached_qubit_indices.get(op_id)) is None:
51+
op_qubit_indices = cached_qubit_indices[op_id] = sorted(
4952
qubit_index[q] for q in op.qubits
5053
)
54+
if (op_measurement_keys := cached_measurement_keys.get(op_id)) is None:
55+
op_measurement_keys = cached_measurement_keys[op_id] = protocols.measurement_key_objs(
56+
op
57+
)
58+
if (op_control_keys := cached_control_keys.get(op_id)) is None:
59+
op_control_keys = cached_control_keys[op_id] = protocols.control_keys(op)
60+
5161
for tail_op in reversed(final_operations):
52-
tail_qubit_indices = cached_qubit_indices[id(tail_op)]
53-
if op_qubit_indices < tail_qubit_indices and (
54-
# special case for zero-qubit gates
55-
not op_qubit_indices
56-
# check if two sorted sequences are disjoint
57-
or op_qubit_indices[-1] < tail_qubit_indices[0]
58-
or set(op_qubit_indices).isdisjoint(tail_qubit_indices)
59-
# fallback to more expensive commutation check
60-
or protocols.commutes(op, tail_op, default=False)
62+
tail_id = id(tail_op)
63+
tail_qubit_indices = cached_qubit_indices[tail_id]
64+
tail_measurement_keys = cached_measurement_keys[tail_id]
65+
tail_control_keys = cached_control_keys[tail_id]
66+
if (
67+
op_qubit_indices < tail_qubit_indices
68+
and op_measurement_keys.isdisjoint(tail_measurement_keys)
69+
and op_control_keys.isdisjoint(tail_measurement_keys)
70+
and tail_control_keys.isdisjoint(op_measurement_keys)
71+
and (
72+
# special case for zero-qubit gates
73+
not op_qubit_indices
74+
# check if two sorted sequences are disjoint
75+
or op_qubit_indices[-1] < tail_qubit_indices[0]
76+
or set(op_qubit_indices).isdisjoint(tail_qubit_indices)
77+
# fallback to more expensive commutation check
78+
or protocols.commutes(op, tail_op, default=False)
79+
)
6180
):
6281
pos -= 1
6382
continue

cirq-core/cirq/transformers/insertion_sort_test.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,55 @@ def test_insertion_sort() -> None:
2727
cirq.GlobalPhaseGate(1j).on(),
2828
)
2929
sorted_circuit = cirq.transformers.insertion_sort_transformer(c)
30-
assert sorted_circuit == cirq.Circuit(
31-
cirq.GlobalPhaseGate(1j).on(),
32-
cirq.CZ(cirq.q(0), cirq.q(1)),
33-
cirq.CZ(cirq.q(2), cirq.q(1)),
34-
cirq.CZ(cirq.q(2), cirq.q(1)),
35-
cirq.CZ(cirq.q(2), cirq.q(4)),
30+
cirq.testing.assert_same_circuits(
31+
sorted_circuit,
32+
cirq.Circuit(
33+
cirq.GlobalPhaseGate(1j).on(),
34+
cirq.CZ(cirq.q(0), cirq.q(1)),
35+
cirq.CZ(cirq.q(2), cirq.q(1)),
36+
cirq.CZ(cirq.q(2), cirq.q(1)),
37+
cirq.CZ(cirq.q(2), cirq.q(4)),
38+
),
39+
)
40+
41+
42+
def test_insertion_sort_same_measurement_key() -> None:
43+
q0, q1 = cirq.LineQubit.range(2)
44+
c = cirq.Circuit(cirq.measure(q1, key='k'), cirq.measure(q0, key='k'))
45+
cirq.testing.assert_same_circuits(cirq.transformers.insertion_sort_transformer(c), c)
46+
47+
48+
def test_insertion_sort_measurement_and_control_key_conflict() -> None:
49+
q0, q1 = cirq.LineQubit.range(2)
50+
c = cirq.Circuit(cirq.measure(q1, key='k'), cirq.X(q0).with_classical_controls('k'))
51+
# Second operation depends on the first so they don't commute.
52+
cirq.testing.assert_same_circuits(cirq.transformers.insertion_sort_transformer(c), c)
53+
54+
55+
def test_insertion_sort_measurement_and_control_key_conflict_other_way_around() -> None:
56+
q0, q1 = cirq.LineQubit.range(2)
57+
c = cirq.Circuit(
58+
cirq.measure(q0, key='k'),
59+
cirq.X(q1).with_classical_controls('k'),
60+
cirq.measure(q0, key='k'),
61+
)
62+
cirq.testing.assert_same_circuits(cirq.transformers.insertion_sort_transformer(c), c)
63+
64+
65+
def test_insertion_sort_distinct_measurement_keys() -> None:
66+
q0, q1 = cirq.LineQubit.range(2)
67+
c = cirq.Circuit(cirq.measure(q1, key='k1'), cirq.measure(q0, key='k0'))
68+
# Measurement keys are distinct, so the measurements commute.
69+
expected = cirq.Circuit(cirq.measure(q0, key='k0'), cirq.measure(q1, key='k1'))
70+
assert cirq.transformers.insertion_sort_transformer(c)[0].operations == expected[0].operations
71+
72+
73+
def test_insertion_sort_shared_control_key() -> None:
74+
q0, q1 = cirq.LineQubit.range(2)
75+
c = cirq.Circuit(
76+
cirq.X(q1).with_classical_controls('k'), cirq.X(q0).with_classical_controls('k')
77+
)
78+
expected = cirq.Circuit(
79+
cirq.X(q0).with_classical_controls('k'), cirq.X(q1).with_classical_controls('k')
3680
)
81+
assert cirq.transformers.insertion_sort_transformer(c)[0].operations == expected[0].operations

0 commit comments

Comments
 (0)