Skip to content

Commit 9c443b6

Browse files
authored
Fix: use context argument in drop_diagonal_before_measurement (#7837)
Respect the `tags_to_ignore` and `deep` options in `drop_diagonal_before_measurement` Fixes #7831
1 parent 5c23169 commit 9c443b6

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed

cirq-core/cirq/transformers/diagonal_optimization.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _is_z_or_cz_pow_gate(op: cirq.Operation) -> bool:
3535
return isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate))
3636

3737

38-
@transformer_api.transformer
38+
@transformer_api.transformer(add_deep_support=True)
3939
def drop_diagonal_before_measurement(
4040
circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None
4141
) -> cirq.Circuit:
@@ -96,6 +96,9 @@ def drop_diagonal_before_measurement(
9696
if context is None:
9797
context = transformer_api.TransformerContext()
9898

99+
# Extract tags_to_ignore for efficient lookup (frozenset for immutability)
100+
tags_to_ignore = frozenset(context.tags_to_ignore)
101+
99102
# Phase 1: Push Z gates later in the circuit to maximize removal opportunities.
100103
circuit = transformers.eject_z(circuit, context=context)
101104

@@ -116,12 +119,13 @@ def drop_diagonal_before_measurement(
116119
new_ops.append(op)
117120
# If this is a diagonal gate and ALL of its qubits will be measured, remove it
118121
# (diagonal gates only affect phase, which doesn't impact computational basis
119-
# measurements)
122+
# measurements). Skip removal if operation has tags_to_ignore.
120123
elif _is_z_or_cz_pow_gate(op):
121-
# CRITICAL: we can only remove if all qubits involved are measured.
124+
# CRITICAL: we can only remove if all qubits involved are measured
125+
# AND the operation is not tagged to be ignored.
122126
# if even one qubit is NOT measured, the gate must stay to preserve
123127
# the state of that unmeasured qubit (due to phase kickback/entanglement).
124-
if measured_qubits.issuperset(op.qubits):
128+
if tags_to_ignore.isdisjoint(op.tags) and measured_qubits.issuperset(op.qubits):
125129
continue # Drop the operation
126130

127131
new_ops.append(op)

cirq-core/cirq/transformers/diagonal_optimization_test.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,124 @@ def test_is_z_or_cz_pow_gate_helper_edge_cases():
236236
# Other diagonal gates (like CCZ) are not detected by the optimized version
237237
# This is intentional - eject_z is only effective for Z and CZ anyway
238238
assert not _is_z_or_cz_pow_gate(cirq.CCZ(q0, q1, q))
239+
240+
241+
def test_tags_to_ignore_preserves_tagged_operations():
242+
"""Test that operations with tags_to_ignore are preserved and not optimized."""
243+
q0 = cirq.LineQubit(0)
244+
245+
# Circuit with a Z gate tagged with "ignore" followed by measurement
246+
# Without tags_to_ignore, the Z would be removed
247+
circuit = cirq.Circuit(cirq.Z(q0).with_tags("ignore"), cirq.measure(q0, key='m'))
248+
249+
# Apply transformer with tags_to_ignore
250+
context = cirq.TransformerContext(tags_to_ignore=("ignore",))
251+
optimized = drop_diagonal_before_measurement(circuit, context=context)
252+
253+
# The tagged Z gate should be preserved
254+
cirq.testing.assert_same_circuits(optimized, circuit)
255+
256+
257+
def test_tags_to_ignore_does_not_break_optimization_chain():
258+
"""Test that tagged diagonal operations don't break the optimization chain.
259+
260+
For Z(q) -> Z[ignore](q) -> M(q), the first Z should still be removed because:
261+
1. Diagonal gates commute with each other
262+
2. The tagged Z is preserved but doesn't block earlier diagonal gates
263+
"""
264+
q0 = cirq.LineQubit(0)
265+
266+
# Circuit: Z -> Z(tagged) -> measure
267+
circuit = cirq.Circuit(cirq.Z(q0), cirq.Z(q0).with_tags("ignore"), cirq.measure(q0, key='m'))
268+
269+
context = cirq.TransformerContext(tags_to_ignore=("ignore",))
270+
optimized = drop_diagonal_before_measurement(circuit, context=context)
271+
272+
# The first Z is removed, but tagged Z is preserved
273+
expected = cirq.Circuit(cirq.Z(q0).with_tags("ignore"), cirq.measure(q0, key='m'))
274+
cirq.testing.assert_same_circuits(optimized, expected)
275+
276+
277+
def test_tags_to_ignore_only_affects_tagged_operations():
278+
"""Test that untagged operations are still optimized when tags_to_ignore is set."""
279+
q0, q1 = cirq.LineQubit.range(2)
280+
281+
# Circuit with one tagged Z (preserved) and one untagged Z (should be removed)
282+
circuit = cirq.Circuit(
283+
cirq.Z(q0).with_tags("ignore"),
284+
cirq.Z(q1),
285+
cirq.measure(q0, key='m0'),
286+
cirq.measure(q1, key='m1'),
287+
)
288+
289+
context = cirq.TransformerContext(tags_to_ignore=("ignore",))
290+
optimized = drop_diagonal_before_measurement(circuit, context=context)
291+
292+
# q0's Z is preserved (tagged), q1's Z is removed (untagged)
293+
# The tagged Z breaks the chain for q0, so it stays in its own moment
294+
expected = cirq.Circuit(
295+
cirq.Moment(cirq.Z(q0).with_tags("ignore")),
296+
cirq.Moment(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')),
297+
)
298+
299+
cirq.testing.assert_same_circuits(optimized, expected)
300+
301+
302+
def test_deep_transforms_sub_circuits():
303+
"""Test that deep=True applies transformation to sub-circuits in CircuitOperation.
304+
305+
Uses CZ gate to truly test deep support - a Z gate alone would be removed by eject_z.
306+
"""
307+
q0, q1 = cirq.LineQubit.range(2)
308+
309+
# Create a sub-circuit with CZ before measurements on both qubits
310+
sub_circuit = cirq.FrozenCircuit(
311+
cirq.CZ(q0, q1), cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')
312+
)
313+
circuit_op = cirq.CircuitOperation(sub_circuit)
314+
circuit = cirq.Circuit(circuit_op)
315+
316+
# Apply transformer with deep=True
317+
context = cirq.TransformerContext(deep=True)
318+
optimized = drop_diagonal_before_measurement(circuit, context=context)
319+
320+
# The sub-circuit should have the CZ removed (both qubits are measured)
321+
expected_sub_circuit = cirq.FrozenCircuit(
322+
cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')
323+
)
324+
expected = cirq.Circuit(cirq.CircuitOperation(expected_sub_circuit))
325+
326+
cirq.testing.assert_same_circuits(optimized, expected)
327+
328+
329+
def test_deep_false_preserves_sub_circuits():
330+
"""Test that deep=False (default) does not modify sub-circuits."""
331+
q0 = cirq.LineQubit(0)
332+
333+
# Create a sub-circuit with Z before measurement
334+
sub_circuit = cirq.FrozenCircuit(cirq.Z(q0), cirq.measure(q0, key='m'))
335+
circuit_op = cirq.CircuitOperation(sub_circuit)
336+
circuit = cirq.Circuit(circuit_op)
337+
338+
# Apply transformer without deep (default is False)
339+
optimized = drop_diagonal_before_measurement(circuit)
340+
341+
# The sub-circuit should be unchanged
342+
cirq.testing.assert_same_circuits(optimized, circuit)
343+
344+
345+
def test_deep_with_tags_to_ignore_in_sub_circuit():
346+
"""Test that tags_to_ignore is respected within sub-circuits when deep=True."""
347+
q0 = cirq.LineQubit(0)
348+
349+
# Create a sub-circuit with a tagged Z before measurement
350+
sub_circuit = cirq.FrozenCircuit(cirq.Z(q0).with_tags("ignore"), cirq.measure(q0, key='m'))
351+
circuit_op = cirq.CircuitOperation(sub_circuit)
352+
circuit = cirq.Circuit(circuit_op)
353+
354+
# Apply transformer with deep=True and tags_to_ignore
355+
context = cirq.TransformerContext(deep=True, tags_to_ignore=("ignore",))
356+
optimized = drop_diagonal_before_measurement(circuit, context=context)
357+
358+
# The sub-circuit should be unchanged (tagged Z preserved)
359+
cirq.testing.assert_same_circuits(optimized, circuit)

0 commit comments

Comments
 (0)