diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index e3d1c9a0d35..a6e37eb0882 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -150,6 +150,7 @@ SpinInversionGaugeTransformer as SpinInversionGaugeTransformer, SqrtCZGaugeTransformer as SqrtCZGaugeTransformer, SqrtISWAPGaugeTransformer as SqrtISWAPGaugeTransformer, + CPhaseGaugeTransformerMM as CPhaseGaugeTransformerMM, ) from cirq.transformers.randomized_measurements import ( diff --git a/cirq-core/cirq/transformers/gauge_compiling/__init__.py b/cirq-core/cirq/transformers/gauge_compiling/__init__.py index 8a032be778b..7eb589df32f 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/__init__.py +++ b/cirq-core/cirq/transformers/gauge_compiling/__init__.py @@ -47,3 +47,11 @@ from cirq.transformers.gauge_compiling.idle_moments_gauge import ( IdleMomentsGauge as IdleMomentsGauge, ) + +from cirq.transformers.gauge_compiling.multi_moment_gauge_compiling import ( + MultiMomentGaugeTransformer as MultiMomentGaugeTransformer, +) + +from cirq.transformers.gauge_compiling.multi_moment_cphase_gauge import ( + CPhaseGaugeTransformerMM as CPhaseGaugeTransformerMM, +) diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py new file mode 100644 index 00000000000..fe800037b9f --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py @@ -0,0 +1,245 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Multi-Moment Gauge Transformer for the cphase gate.""" + +from __future__ import annotations + +from typing import cast + +import numpy as np +from attrs import field, frozen + +from cirq import circuits, ops +from cirq.transformers.gauge_compiling.multi_moment_gauge_compiling import ( + MultiMomentGaugeTransformer, +) + +_PAULIS: np.ndarray = np.array((ops.I, ops.X, ops.Y, ops.Z), dtype=object) +_COMMUTING_GATES = {ops.I, ops.Z} # I,Z Commute with ZPowGate and CZPowGate; X,Y anti-commute. + + +def _merge_pauliandzpow(left: _PauliAndZPow, right: _PauliAndZPow) -> _PauliAndZPow: + # 1. Commute left.zpow and right.pauli: + # ─left.pauli─left.zpow─right.pauli─right.zpow─ + # ==> ─left.pauli─right.pauli─(+/-left.zpow─right.zpow)─ + if right.pauli in _COMMUTING_GATES: + new_zpow_exp = left.zpow.exponent + right.zpow.exponent + else: + new_zpow_exp = -left.zpow.exponent + right.zpow.exponent + + # 2. Merge left.pauli and right.pauli + new_pauli = left.pauli + if right.pauli is not ops.I: + if new_pauli is ops.I: + new_pauli = right.pauli + else: + # left.pauli * right.pauli + new_pauli = cast(ops.Pauli, new_pauli).phased_pauli_product( + cast(ops.Pauli, right.pauli) + )[1] + + return _PauliAndZPow(pauli=new_pauli, zpow=ops.ZPowGate(exponent=new_zpow_exp)) + + +@frozen +class _PauliAndZPow: + """A gate represented by a Pauli followed by a ZPowGate. + + The order is ─Pauli──ZPowGate─. + + Attributes: + pauli: The Pauli gate. + zpow: The ZPowGate. + """ + + pauli: ops.Pauli | ops.IdentityGate = ops.I + zpow: ops.ZPowGate = ops.ZPowGate(exponent=0) + + def merge_left(self, left: _PauliAndZPow) -> _PauliAndZPow: + """Merges another `_PauliAndZPow` from the left. + + Calculates `─left─self─` and returns a new `_PauliAndZPow` instance. + """ + return _merge_pauliandzpow(left, self) + + def merge_right(self, right: _PauliAndZPow) -> _PauliAndZPow: + """Merges another `_PauliAndZPow` from the right. + + Calculates `─self─right─` and returns a new `_PauliAndZPow` instance. + """ + return _merge_pauliandzpow(self, right) + + def after_cphase( + self, cphase: ops.CZPowGate + ) -> tuple[ops.CZPowGate, _PauliAndZPow, _PauliAndZPow]: + """Pull self through cphase. + + Returns: + A tuple of + (updated cphase gate, pull_through of this qubit, pull_through of the other qubit). + """ + if self.pauli in _COMMUTING_GATES: + return cphase, _PauliAndZPow(self.pauli, self.zpow), _PauliAndZPow() + else: + # Taking self.pauli==X gate as an example: + # 0: ─X─Z^t──@────── 0: ─X──@─────Z^t─ 0: ─@──────X──Z^t── + # │ ==> │ ==> │ + # 1: ────────@^exp── 1: ────@^exp───── 1: ─@^-exp─Z^exp─── + # Similarly for X|Y on qubit 0/1, the result is always flipping cphase and + # add an extra Rz rotation on the other qubit. + return ( + cast(ops.CZPowGate, cphase**-1), + _PauliAndZPow(self.pauli, self.zpow), + _PauliAndZPow(zpow=ops.ZPowGate(exponent=cphase.exponent)), + ) + + def after_pauli(self, pauli: ops.Pauli | ops.IdentityGate) -> _PauliAndZPow: + """Calculates ─self─pauli─ ==> ─pauli─output─.""" + if pauli in _COMMUTING_GATES: + return _PauliAndZPow(self.pauli, self.zpow) + else: + return _PauliAndZPow(self.pauli, ops.ZPowGate(exponent=-self.zpow.exponent)) + + def after_zpow(self, zpow: ops.ZPowGate) -> tuple[ops.ZPowGate, _PauliAndZPow]: + """Calculates ─self─zpow─ ==> ─+/-zpow─output─.""" + if self.pauli in _COMMUTING_GATES: + return zpow, _PauliAndZPow(self.pauli, self.zpow) + else: + return ops.ZPowGate(exponent=-zpow.exponent), self + + def __str__(self) -> str: + return f"─{self.pauli}──{self.zpow}─" + + def to_single_qubit_gate(self) -> ops.PhasedXZGate | ops.ZPowGate | ops.IdentityGate: + """Converts the _PauliAndZPow to a single-qubit gate.""" + exp = self.zpow.exponent + match self.pauli: + case ops.I: + if exp % 2 == 0: + return ops.I + return self.zpow + case ops.X: + return ops.PhasedXZGate(x_exponent=1, z_exponent=exp, axis_phase_exponent=0) + case ops.Y: + return ops.PhasedXZGate(x_exponent=1, z_exponent=exp - 1, axis_phase_exponent=0) + case _: # ops.Z + return ops.ZPowGate(exponent=1 + exp) + + +def _pull_through_single_cphase( + cphase: ops.CZPowGate, input0: _PauliAndZPow, input1: _PauliAndZPow +) -> tuple[ops.CZPowGate, _PauliAndZPow, _PauliAndZPow]: + """Pulls input0 and input1 through a CZPowGate. + Input: Output: + 0: ─(input0)─@───── 0: ─@────────(output0)─ + │ ==> │ + 1: ─(input1)─@^exp─ 1: ─@^+/-exp─(output1)─ + """ + + # Step 1; pull input0 through CZPowGate. + # 0: ─input0─@───── 0: ────────@─────────output0─ + # │ ==> │ + # 1: ─input1─@^exp─ 1: ─input1─@^+/-exp──output1─ + output_cphase, output0, output1 = input0.after_cphase(cphase) + + # Step 2; similar to step 1, pull input1 through CZPowGate. + # 0: ─@──────────pulled0────output0─ 0: ─@────────output0─ + # ==> │ ==> │ + # 1: ─@^+/-exp───pulled1────output1─ 1: ─@^+/-exp─output1─ + output_cphase, pulled1, pulled0 = input1.after_cphase(output_cphase) + output0 = output0.merge_left(pulled0) + output1 = output1.merge_left(pulled1) + + return output_cphase, output0, output1 + + +_TARGET_GATESET: ops.Gateset = ops.Gateset(ops.CZPowGate) +_SUPPORTED_GATESET: ops.Gateset = ops.Gateset(ops.Pauli, ops.IdentityGate, ops.ZPowGate) + + +@frozen +class CPhaseGaugeTransformerMM(MultiMomentGaugeTransformer): + """A gauge transformer for the cphase gate.""" + + target: ops.GateFamily | ops.Gateset = field(default=_TARGET_GATESET, init=False) + supported_gates: ops.GateFamily | ops.Gateset = field(default=_SUPPORTED_GATESET) + + def sample_left_moment( + self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator + ) -> circuits.Moment: + """Samples a random single-qubit moment to be inserted before the target block.""" + return circuits.Moment([cast(ops.Gate, rng.choice(_PAULIS)).on(q) for q in active_qubits]) + + def gauge_on_moments( + self, + moments_to_gauge: list[circuits.Moment], + prng: np.random.Generator = np.random.default_rng(), + ) -> list[circuits.Moment]: + """Gauges a block of moments that contains at least a cphase gate in each of the moment. + + Args: + moments_to_gauge: A list of moments to be gauged. + prng: A pseudorandom number generator. + + Returns: + A list of moments after gauging. + """ + active_qubits = circuits.Circuit.from_moments(*moments_to_gauge).all_qubits() + left_moment = self.sample_left_moment(active_qubits, prng) + pulled: dict[ops.Qid, _PauliAndZPow] = { + op.qubits[0]: _PauliAndZPow(pauli=cast(ops.Pauli | ops.IdentityGate, op.gate)) + for op in left_moment + if op.gate + } + ret: list[circuits.Moment] = [left_moment] + # The loop iterates through each moment of the target block, propagating + # the `pulled` gauge from left to right. In each iteration, `prev` holds + # the gauge to the left of the current `moment`, and the loop computes + # the transformed `moment` and the new `pulled` gauge to its right. + for moment in moments_to_gauge: + # Calculate --prev--moment-- ==> --updated_momment--pulled-- + prev = pulled + pulled = {} + ops_at_updated_moment: list[ops.Operation] = [] + for op in moment: + # Pull prev through ops at the moment. + if op.gate: + match op.gate: + case ops.CZPowGate(): + q0, q1 = op.qubits + new_gate, pulled[q0], pulled[q1] = _pull_through_single_cphase( + op.gate, prev[q0], prev[q1] + ) + ops_at_updated_moment.append(new_gate.on(q0, q1)) + case ops.Pauli() | ops.IdentityGate(): + q = op.qubits[0] + ops_at_updated_moment.append(op) + pulled[q] = prev[q].after_pauli(op.gate) + case ops.ZPowGate(): + q = op.qubits[0] + new_zpow, pulled[q] = prev[q].after_zpow(op.gate) + ops_at_updated_moment.append(new_zpow.on(q)) + case _: + raise ValueError(f"Gate type {type(op.gate)} is not supported.") + # Keep the other ops of prev + for q, gate in prev.items(): + if q not in pulled: + pulled[q] = gate + ret.append(circuits.Moment(ops_at_updated_moment)) + last_moment = circuits.Moment( + [gate.to_single_qubit_gate().on(q) for q, gate in pulled.items()] + ) + ret.append(last_moment) + return ret diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py new file mode 100644 index 00000000000..321c93846fd --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py @@ -0,0 +1,243 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np +import pytest + +import cirq +from cirq import I, X, Y, Z, ZPowGate +from cirq.transformers.gauge_compiling.multi_moment_cphase_gauge import ( + _PauliAndZPow, + CPhaseGaugeTransformerMM, +) + + +def test_gauge_on_single_cphase(): + """Test case. + Input: + 0: ───@─────── + │ + 1: ───@^0.2─── + Example output: + 0: ───X───@────────PhXZ(a=0,x=1,z=0)─── + │ + 1: ───I───@^-0.2───Z^0.2─────────────── + """ + q0, q1 = cirq.LineQubit.range(2) + + input_circuit = cirq.Circuit(cirq.Moment(cirq.CZ(q0, q1) ** 0.2)) + + class _TestCPhaseGaugeTransformerMM(CPhaseGaugeTransformerMM): + def sample_left_moment( + self, + active_qubits: frozenset[cirq.Qid], + rng: np.random.Generator = np.random.default_rng(), + ) -> cirq.Moment: + return cirq.Moment(g1(q0), g2(q1)) + + for g1 in [X, Y, Z, I]: + for g2 in [X, Y, Z, I]: # Test with all possible samples of the left moment. + cphase_transformer = _TestCPhaseGaugeTransformerMM() + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_cz_moments(): + """Test case. + Input: + ┌──┐ + 0: ───@────@─────H───────@───@─── + │ │ │ │ + 1: ───@────┼@────────────@───@─── + ││ + 2: ───@────@┼────────@───@───@─── + │ │ │ │ │ + 3: ───@─────@────────@───@───@─── + └──┘ + Example output: + ┌──┐ + 0: ───X───@────@─────PhXZ(a=0,x=1,z=1)──────H───X───────@───@───PhXZ(a=0,x=1,z=2)──── + │ │ │ │ + 1: ───I───@────┼@────Z──────────────────────────X───────@───@───PhXZ(a=2,x=1,z=-2)─── + ││ + 2: ───Y───@────@┼────PhXZ(a=1.5,x=1,z=-1)───────Z───@───@───@───Z──────────────────── + │ │ │ │ │ + 3: ───Z───@─────@────Z^0────────────────────────I───@───@───@───Z^0────────────────── + └──┘ + """ + q0, q1, q2, q3 = cirq.LineQubit.range(4) + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)), + cirq.Moment(cirq.H(q0)), + cirq.Moment(cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + ) + transformer = CPhaseGaugeTransformerMM() + + output_circuit = transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_is_target_moment(): + q0, q1, q2 = cirq.LineQubit.range(3) + + target_moments = [ + cirq.Moment(cirq.CZ(q0, q1) ** 0.2), + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.X(q2)), + ] + non_target_moments = [ + cirq.Moment(cirq.X(q0), cirq.Y(q1)), + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.Rz(rads=-0.8).on(q2)), + cirq.Moment(cirq.CZ(q0, q1).with_tags("ignore")), + cirq.Moment(cirq.CZ(q0, q1)).with_tags("ignore"), + ] + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset(cirq.Pauli)) + for m in target_moments: + assert cphase_transformer.is_target_moment(m) + for m in non_target_moments: + assert not cphase_transformer.is_target_moment( + m, cirq.TransformerContext(tags_to_ignore={'ignore'}) + ) + + +def test_gauge_on_cphase_moments(): + """Test case. + Input: + ┌──┐ + 0: ───@────────@─────H───Rz(-0.255π)───────────@───────@─────── + │ │ │ │ + 1: ───@^0.2────┼@──────────────────────────────@^0.1───@─────── + ││ + 2: ───@────────@┼────────@─────────────@───────@───────@─────── + │ │ │ │ │ │ + 3: ───@─────────@────────@^0.2─────────@^0.2───@───────@^0.2─── + └──┘ + Example output: + ┌──┐ + 0: ───Y───@─────────@─────PhXZ(a=0,x=1,z=0)───H───X───Rz(0.255π)────────────@───────@────────PhXZ(a=0,x=1,z=1.1)─── + │ │ │ │ + 1: ───I───@^-0.2────┼@────Z^0.2───────────────────Y─────────────────────────@^0.1───@────────PhXZ(a=0,x=1,z=0.1)─── + ││ + 2: ───X───@─────────@┼────PhXZ(a=0,x=1,z=1)───────X───@────────────@────────@───────@────────PhXZ(a=0,x=1,z=0)───── + │ │ │ │ │ │ + 3: ───Z───@──────────@────I───────────────────────I───@^-0.2───────@^-0.2───@───────@^-0.2───Z^-0.4──────────────── + └──┘ + """ # noqa: E501 + q0, q1, q2, q3 = cirq.LineQubit.range(4) + cphase_transformer = CPhaseGaugeTransformerMM() + for _ in range(5): + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)), + cirq.Moment(cirq.H(q0)), + cirq.Moment(cirq.CZ(q2, q3) ** 0.2, cirq.Rz(rads=-0.8).on(q0)), + cirq.Moment(cirq.CZ(q2, q3) ** 0.2), + cirq.Moment(cirq.CZ(q0, q1) ** 0.1, cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3) ** 0.2), + ) + + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_czpow_only_moments(): + q0, q1, q2 = cirq.LineQubit.range(3) + + input_circuit = cirq.Circuit(cirq.Moment(cirq.CZ(q0, q1) ** 0.2, X(q2))) + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset()) + output_circuit = cphase_transformer(input_circuit) + + # Since X isn't in supported_gates, the moment won't be gauged. + assert input_circuit == output_circuit + + +def test_gauge_on_supported_gates(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + cphase_transformer = CPhaseGaugeTransformerMM() + for g1 in [X, Z**0.6, I, Z]: + for g2 in [Y, cirq.Rz(rads=0.2), Z**0.7]: + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, g1(q2), g2(q3)), + cirq.Moment(cirq.CZ(q0, q2), g2(q1), g1(q3)), + ) + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_unsupported_gates(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset(cirq.CNOT)) + with pytest.raises(ValueError, match="Gate type .* is not supported."): + cphase_transformer(cirq.Circuit(cirq.CNOT(q0, q1), cirq.CZ(q2, q3))) + + +def test_pauli_and_phxz_util_str(): + assert str(_PauliAndZPow(pauli=X)) == '─X──Z**0─' + assert str(_PauliAndZPow(pauli=X, zpow=Z**0.1)) == '─X──Z**0.1─' + + +def test_pauli_and_phxz_util_gate_merges(): + """Tests _PauliAndZPow's merge_left() and merge_right().""" + for left_pauli in [X, Y, Z, I]: + for right_pauli in [X, Y, Z, I]: + left = _PauliAndZPow(pauli=left_pauli, zpow=ZPowGate(exponent=0.2)) + right = _PauliAndZPow(pauli=right_pauli, zpow=ZPowGate(exponent=0.6)) + merge1 = right.merge_left(left) + merge2 = left.merge_right(right) + + assert np.allclose( + cirq.unitary(merge1.to_single_qubit_gate()), + cirq.unitary(merge2.to_single_qubit_gate()), + ) + q = cirq.LineQubit(0) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary( + cirq.Circuit( + left.to_single_qubit_gate().on(q), right.to_single_qubit_gate().on(q) + ) + ), + cirq.unitary(merge1.to_single_qubit_gate()), + atol=1e-6, + ) + + +def test_pauli_and_phxz_util_to_1q_gate(): + """Tests _PauliAndZPow.to_single_qubit_gate().""" + q = cirq.LineQubit(0) + for pauli in [cirq.X, cirq.Y, cirq.Z, cirq.I]: + for zpow in [cirq.ZPowGate(exponent=exp) for exp in [0, 0.1, 0.5, 1, 10.2]]: + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + cirq.Circuit(pauli(q), zpow(q)), + cirq.Circuit(_PauliAndZPow(pauli=pauli, zpow=zpow).to_single_qubit_gate().on(q)), + {q: q}, + ) + + +def test_deep_not_supported(): + with pytest.raises(ValueError, match="GaugeTransformer cannot be used with deep=True"): + t = CPhaseGaugeTransformerMM() + t(cirq.Circuit(), context=cirq.TransformerContext(deep=True)) diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py new file mode 100644 index 00000000000..8bfc6cc9434 --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py @@ -0,0 +1,127 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the abstraction for multi-moment gauge compiling as a cirq transformer.""" + +import abc + +import attrs +import numpy as np + +from cirq import circuits, ops +from cirq.transformers import transformer_api + + +@transformer_api.transformer +@attrs.frozen +class MultiMomentGaugeTransformer(abc.ABC): + """A gauge transformer that wraps target blocks of moments with single-qubit gates. + + In detail, a "gauging moment" of single-qubit gates is inserted before a target block of + moments. These gates are then commuted through the block, resulting in a corresponding + moment of gates after it. + + q₀: ... ───LG0───╭───────────╮────RG0───... + │ │ + q₁: ... ───LG1───┤ moments ├────RG1───... + │ to be │ + q₂: ... ───LG2───┤ gauged on ├────RG2───... + │ │ + q₃: ... ───LG3───╰───────────╯────RG3───... + + Attributes: + target: The target gate, gate family or gateset, must exist in each of the moment in + the "moments to be gauged". + supported_gates: The gates that are supported in the "moments to be gauged". + """ + + target: ops.GateFamily | ops.Gateset + supported_gates: ops.GateFamily | ops.Gateset + + @abc.abstractmethod + def gauge_on_moments(self, moments_to_gauge: list[circuits.Moment]) -> list[circuits.Moment]: + """Gauges a block of moments. + + Args: + moments_to_gauge: A list of moments to be gauged. + + Returns: + A list of moments after gauging. + """ + + @abc.abstractmethod + def sample_left_moment( + self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator + ) -> circuits.Moment: + """Samples a random single-qubit moment to be inserted before the target block. + + Args: + active_qubits: The qubits on which the sampled gates should be applied. + rng: A pseudorandom number generator. + + Returns: + The sampled moment. + """ + + def is_target_moment( + self, moment: circuits.Moment, context: transformer_api.TransformerContext | None = None + ) -> bool: + """Checks if a moment is a target for gauging. + + A moment is a target moment if it contains at least one target op and + all its operations are supported by this transformer. + """ + # skip the moment if the moment is tagged to be ignored + if context and set(moment.tags).intersection(context.tags_to_ignore): + return False + + has_target_gates: bool = False + for op in moment: + if ( + context + and isinstance(op, ops.TaggedOperation) + and set(op.tags).intersection(context.tags_to_ignore) + ): # skip the moment if the op is tagged to be ignored + return False + if op.gate: + if op in self.target: + has_target_gates = True + elif op not in self.supported_gates: + return False + return has_target_gates + + def __call__( + self, + circuit: circuits.AbstractCircuit, + *, + context: transformer_api.TransformerContext | None = None, + ) -> circuits.AbstractCircuit: + if context is None: + context = transformer_api.TransformerContext(deep=False) + if context.deep: + raise ValueError('GaugeTransformer cannot be used with deep=True') + output_moments: list[circuits.Moment] = [] + moments_to_gauge: list[circuits.Moment] = [] + for moment in circuit: + if self.is_target_moment(moment, context): + moments_to_gauge.append(moment) + else: + if moments_to_gauge: + output_moments.extend(self.gauge_on_moments(moments_to_gauge)) + moments_to_gauge.clear() + output_moments.append(moment) + if moments_to_gauge: + output_moments.extend(self.gauge_on_moments(moments_to_gauge)) + + return circuits.Circuit.from_moments(*output_moments)