-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add drop_diagonal_before_measurement transformer #7790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d5442fb
9f96503
26765bd
3910677
571d11c
432b4d3
ca3e796
8894ea1
3125991
8d125ec
06bb2ad
03aeb03
df4a51f
6eb1441
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| # Copyright 2025 The Cirq Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Transformer pass that removes diagonal gates before measurements.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| from cirq import circuits, ops, protocols, transformers | ||
| from cirq.transformers import transformer_api | ||
|
|
||
| if TYPE_CHECKING: | ||
| import cirq | ||
|
|
||
|
|
||
| def _is_z_or_cz_pow_gate(op: cirq.Operation) -> bool: | ||
| """Checks if an operation is a known diagonal gate (Z, CZ, etc.). | ||
|
|
||
| As suggested in review, we avoid computing the unitary matrix (which is expensive) | ||
| and instead strictly check for gates known to be diagonal in the computational basis. | ||
| """ | ||
| # ZPowGate covers Z, S, T, Rz. CZPowGate covers CZ. | ||
| return isinstance(op.gate, (ops.ZPowGate, ops.CZPowGate, ops.IdentityGate)) | ||
|
|
||
|
|
||
| @transformer_api.transformer | ||
| def drop_diagonal_before_measurement( | ||
| circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None | ||
| ) -> cirq.Circuit: | ||
| """Removes Z and CZ gates that appear immediately before measurements. | ||
|
|
||
| This transformer optimizes circuits by removing Z-type and CZ-type diagonal gates | ||
| (specifically ZPowGate instances like Z, S, T, Rz, and CZPowGate instances like CZ) | ||
| that appear immediately before measurement operations. Since measurements project onto | ||
| the computational basis, these diagonal gates applied immediately before a measurement | ||
| do not affect the measurement outcome and can be safely removed (when all their qubits | ||
| are measured). | ||
|
|
||
| To maximize the effectiveness of this optimization, the transformer first applies | ||
| the `eject_z` transformation, which pushes Z gates (and other diagonal phases) | ||
| later in the circuit. This handles cases where diagonal gates can commute past | ||
| other operations. For example: | ||
|
|
||
| Z(q0) - CZ(q0, q1) - measure(q0) - measure(q1) | ||
|
|
||
| After `eject_z`, the Z gate on the control qubit commutes through the CZ: | ||
|
|
||
| CZ(q0, q1) - Z(q1) - measure(q0) - measure(q1) | ||
|
|
||
| Then both the CZ and Z(q1) can be removed since all their qubits are measured: | ||
|
|
||
| measure(q0) - measure(q1) | ||
|
|
||
| Args: | ||
| circuit: Input circuit to transform. | ||
| context: `cirq.TransformerContext` storing common configurable options for transformers. | ||
|
|
||
| Returns: | ||
| Copy of the transformed input circuit with diagonal gates before measurements removed. | ||
|
|
||
| Examples: | ||
| >>> import cirq | ||
pavoljuhas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| >>> q0, q1 = cirq.LineQubit.range(2) | ||
| >>> | ||
| >>> # Simple case: Z before measurement | ||
| >>> circuit = cirq.Circuit(cirq.H(q0), cirq.Z(q0), cirq.measure(q0)) | ||
| >>> optimized = cirq.drop_diagonal_before_measurement(circuit) | ||
| >>> print(optimized) | ||
| 0: ───H───M─── | ||
|
|
||
| >>> # Complex case: Z-CZ commutation with both qubits measured | ||
| >>> circuit = cirq.Circuit( | ||
| ... cirq.Z(q0), | ||
| ... cirq.CZ(q0, q1), | ||
| ... cirq.measure(q0), | ||
| ... cirq.measure(q1) | ||
| ... ) | ||
| >>> optimized = cirq.drop_diagonal_before_measurement(circuit) | ||
| >>> print(optimized) | ||
| 0: ───M─── | ||
| <BLANKLINE> | ||
| 1: ───M─── | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should still include the CZ gate. |
||
| """ | ||
| if context is None: | ||
| context = transformer_api.TransformerContext() | ||
|
|
||
| # Phase 1: Push Z gates later in the circuit to maximize removal opportunities. | ||
| circuit = transformers.eject_z(circuit, context=context) | ||
|
|
||
| # Phase 2: Remove diagonal gates that appear before measurements. | ||
| # We iterate in reverse to identify which qubits will be measured. | ||
| # Track qubits that will be measured (set grows as we go backwards) | ||
| measured_qubits: set[ops.Qid] = set() | ||
|
|
||
| # Build new moments in reverse | ||
| new_moments = [] | ||
| for moment in reversed(circuit): | ||
| new_ops = [] | ||
|
|
||
| for op in moment: | ||
| # If this is a measurement, mark these qubits as measured | ||
| if protocols.is_measurement(op): | ||
| measured_qubits.update(op.qubits) | ||
| new_ops.append(op) | ||
| # If this is a diagonal gate and ALL of its qubits will be measured, remove it | ||
| # (diagonal gates only affect phase, which doesn't impact computational basis | ||
| # measurements) | ||
| elif _is_z_or_cz_pow_gate(op): | ||
| # CRITICAL: we can only remove if all qubits involved are measured. | ||
| # if even one qubit is NOT measured, the gate must stay to preserve | ||
| # the state of that unmeasured qubit (due to phase kickback/entanglement). | ||
| if measured_qubits.issuperset(op.qubits): | ||
| continue # Drop the operation | ||
|
|
||
| new_ops.append(op) | ||
| # Note: We do NOT remove qubits from measured_qubits here. | ||
| # Diagonal gates commute with other diagonal gates. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test for this? A CZ gate followed by a diagonal 4x4 unitary followed by measurement on both qubits. |
||
| else: | ||
| # Non-diagonal gate found. | ||
| new_ops.append(op) | ||
| # the chain is broken for these qubits. | ||
| measured_qubits.difference_update(op.qubits) | ||
|
|
||
| # Add the moment if it has any operations | ||
| if new_ops: | ||
| new_moments.append(circuits.Moment(new_ops)) | ||
|
|
||
| # Reverse back to original order | ||
| return circuits.Circuit(reversed(new_moments)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,238 @@ | ||
| # Copyright 2025 The Cirq Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Tests for diagonal_optimization transformer.""" | ||
|
|
||
|
|
||
| import numpy as np | ||
|
|
||
| import cirq | ||
| from cirq.transformers.diagonal_optimization import ( | ||
| _is_z_or_cz_pow_gate, | ||
| drop_diagonal_before_measurement, | ||
| ) | ||
|
|
||
|
|
||
| def test_removes_z_before_measure(): | ||
| """Tests that Z gates are removed before measurement.""" | ||
| q = cirq.NamedQubit('q') | ||
|
|
||
| # Original: H -> Z -> Measure | ||
| circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.measure(q, key='m')) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Expected: H -> Measure (Z is gone) | ||
| expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_removes_diagonal_chain(): | ||
| """Tests that a chain of diagonal gates is removed.""" | ||
| q = cirq.NamedQubit('q') | ||
|
|
||
| # Original: H -> Z -> S -> Measure | ||
| circuit = cirq.Circuit(cirq.H(q), cirq.Z(q), cirq.S(q), cirq.measure(q, key='m')) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Expected: H -> Measure (Both Z and S are gone) | ||
| expected = cirq.Circuit(cirq.H(q), cirq.measure(q, key='m')) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_keeps_z_blocked_by_x(): | ||
| """Tests that Z gates blocked by X gates are preserved.""" | ||
| q = cirq.NamedQubit('q') | ||
|
|
||
| # Original: Z -> X -> Measure | ||
| circuit = cirq.Circuit(cirq.Z(q), cirq.X(q), cirq.measure(q, key='m')) | ||
|
|
||
| # Z cannot commute past X, so it should be kept | ||
| # Note: eject_z will phase the X, so the circuit changes but Z is preserved | ||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # We use this helper to check mathematical equivalence | ||
| # instead of checking exact gate types (Y vs PhasedX) | ||
| cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, optimized) | ||
|
|
||
|
|
||
| def test_keeps_cz_if_only_one_qubit_measured(): | ||
| """Tests that CZ is kept if only one qubit is measured.""" | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # Original: CZ(0,1) -> Measure(0) | ||
| circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m')) | ||
|
|
||
| # CZ shouldn't be removed because q1 is not measured | ||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, circuit) | ||
|
|
||
|
|
||
| def test_removes_cz_if_both_measured(): | ||
| """Tests that CZ is removed if both qubits are measured.""" | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # Original: CZ(0,1) -> Measure(0), Measure(1) | ||
| circuit = cirq.Circuit(cirq.CZ(q0, q1), cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Expected: Measures only | ||
| expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_feature_request_z_cz_commutation(): | ||
| """Test the original feature request #4935: Z-CZ commutation before measurement. | ||
|
|
||
| The circuit Z(q0) - CZ(q0, q1) - Z(q1) - M(q1) should keep the CZ gate. | ||
| This is because: | ||
| 1. Z(q0) commutes through the CZ and Z(q1) is removed (via eject_z) | ||
| 2. After commutation: CZ(q0, q1) - Z(q0) - M(q1) | ||
| 3. CZ(q0, q1) and Z(q0) must be kept (q0 is not measured) | ||
|
|
||
| The optimized circuit is: CZ(q0, q1) - Z(q0) - M(q1) | ||
| """ | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # Original feature request circuit | ||
| circuit = cirq.Circuit(cirq.Z(q0), cirq.CZ(q0, q1), cirq.Z(q1), cirq.measure(q1, key='m1')) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Expected: CZ(q0, q1) - Z(q0) - M(q1) | ||
| expected = cirq.Circuit(cirq.CZ(q0, q1), cirq.Z(q0), cirq.Moment(cirq.measure(q1, key='m1'))) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_feature_request_full_example(): | ||
| """Test the full feature request #4935 with measurements on both qubits.""" | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # From feature request | ||
| circuit = cirq.Circuit( | ||
| cirq.Z(q0), | ||
| cirq.CZ(q0, q1), | ||
| cirq.Z(q1), | ||
| cirq.Moment(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')), | ||
| ) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Should simplify to just measurements | ||
| expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_preserves_non_diagonal_gates(): | ||
| """Test that non-diagonal gates are preserved.""" | ||
| q = cirq.NamedQubit('q') | ||
|
|
||
| circuit = cirq.Circuit(cirq.H(q), cirq.X(q), cirq.Z(q), cirq.measure(q, key='m')) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # Verify the physics hasn't changed (handles PhasedX vs Y differences) | ||
| cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, optimized) | ||
|
|
||
|
|
||
| def test_diagonal_gates_commute_before_measurement(): | ||
| """Test that multiple recognized diagonal gates are all removed when all qubits are measured. | ||
|
|
||
| This tests the property that recognized diagonal gates (Z, CZ) commute with each other, | ||
| so we don't remove qubits from measured_qubits when we encounter them. This allows | ||
| earlier diagonal gates in the circuit to also be removed. | ||
| """ | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # Circuit with multiple recognized diagonal gates before measurements | ||
| circuit = cirq.Circuit( | ||
| cirq.CZ(q0, q1), | ||
| cirq.Z(q0), | ||
| cirq.Z(q1), | ||
| cirq.measure(q0, key='m0'), | ||
| cirq.measure(q1, key='m1'), | ||
| ) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # All recognized diagonal gates should be removed since all qubits are measured | ||
| expected = cirq.Circuit(cirq.measure(q0, key='m0'), cirq.measure(q1, key='m1')) | ||
|
|
||
| cirq.testing.assert_same_circuits(optimized, expected) | ||
|
|
||
|
|
||
| def test_unrecognized_diagonal_breaks_chain(): | ||
| """Test that a CZ followed by an unrecognized diagonal 4x4 unitary is handled correctly. | ||
|
|
||
| Even if a gate is diagonal, if it's not a ZPowGate or CZPowGate, it won't be recognized | ||
| and will break the optimization chain. The earlier CZ gate cannot be removed because | ||
| the unrecognized diagonal gate blocks it. | ||
| """ | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
|
|
||
| # Create a custom diagonal 4x4 unitary (not a CZPowGate) | ||
| # This is diagonal but won't be recognized by _is_z_or_cz_pow_gate | ||
| diagonal_matrix = np.diag([1, 1j, -1, -1j]) | ||
| custom_diagonal_gate = cirq.MatrixGate(diagonal_matrix) | ||
|
|
||
| # Circuit: CZ -> custom diagonal -> measurements | ||
| circuit = cirq.Circuit( | ||
| cirq.CZ(q0, q1), | ||
| custom_diagonal_gate(q0, q1), | ||
| cirq.measure(q0, key='m0'), | ||
| cirq.measure(q1, key='m1'), | ||
| ) | ||
|
|
||
| optimized = drop_diagonal_before_measurement(circuit) | ||
|
|
||
| # The custom diagonal gate is not recognized, so it blocks the chain | ||
| # Only the custom diagonal gate can be removed... wait, no! It's not recognized | ||
| # so it won't be removed at all. And it breaks the chain for q0 and q1. | ||
| # So the CZ also cannot be removed. | ||
| cirq.testing.assert_same_circuits(optimized, circuit) | ||
|
|
||
|
|
||
| def test_is_z_or_cz_pow_gate_helper_edge_cases(): | ||
| """Test edge cases in _is_z_or_cz_pow_gate helper function for full coverage.""" | ||
|
|
||
| q = cirq.NamedQubit('q') | ||
|
|
||
| # Test Z gates (including variants like S and T) | ||
| assert _is_z_or_cz_pow_gate(cirq.Z(q)) | ||
| assert _is_z_or_cz_pow_gate(cirq.S(q)) # S is Z**0.5 | ||
| assert _is_z_or_cz_pow_gate(cirq.T(q)) # T is Z**0.25 | ||
|
|
||
| # Test identity gate | ||
| assert _is_z_or_cz_pow_gate(cirq.I(q)) | ||
|
|
||
| # Test non-diagonal gates | ||
| assert not _is_z_or_cz_pow_gate(cirq.H(q)) | ||
| assert not _is_z_or_cz_pow_gate(cirq.X(q)) | ||
| assert not _is_z_or_cz_pow_gate(cirq.Y(q)) | ||
|
|
||
| # Test two-qubit CZ gate | ||
| q0, q1 = cirq.LineQubit.range(2) | ||
| assert _is_z_or_cz_pow_gate(cirq.CZ(q0, q1)) | ||
|
|
||
| # Other diagonal gates (like CCZ) are not detected by the optimized version | ||
| # This is intentional - eject_z is only effective for Z and CZ anyway | ||
| assert not _is_z_or_cz_pow_gate(cirq.CCZ(q0, q1, q)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring should be updated to reflect the code. A diagonal gate is dropped only if all its qubits are measured. So the CZ(q0, q1) - Z(q1) - measure(q1) example is not correct, because the CZ gate can't be removed. If q0 is also measured then it can be removed.