Skip to content

Multi-layer CZGaugeTransformer #7330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cirq-core/cirq/transformers/gauge_compiling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
SpinInversionGaugeTransformer as SpinInversionGaugeTransformer,
)

from cirq.transformers.gauge_compiling.cz_gauge import CZGaugeTransformer as CZGaugeTransformer
from cirq.transformers.gauge_compiling.cz_gauge import (
CZGaugeTransformer as CZGaugeTransformer,
CZGaugeTransformerML as CZGaugeTransformerML,
)

from cirq.transformers.gauge_compiling.iswap_gauge import (
ISWAPGaugeTransformer as ISWAPGaugeTransformer,
Expand Down
35 changes: 34 additions & 1 deletion cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

"""A Gauge Transformer for the CZ gate."""

from cirq import ops
from typing import List

import numpy as np

from cirq import circuits, ops
from cirq.ops.common_gates import CZ
from cirq.transformers.gauge_compiling.gauge_compiling import (
ConstantGauge,
Expand Down Expand Up @@ -43,4 +47,33 @@
]
)


def _multi_layer_pull_through_cz(
moments: List[circuits.Moment], rng: np.random.Generator
) -> List[circuits.Moment]:
# Check all the ops are CZ first
if not all(op.gate == CZ for moment in moments for op in moment):
raise ValueError(f"Input moments must only contain CZ gates:\nmoments = {moments}.")

left: List[ops.Operation] = [
rng.choice([ops.I, ops.X, ops.Y, ops.Z]).on(q)
for q in circuits.Circuit(moments).all_qubits()
]
if not left:
return moments

ps: ops.PauliString = ops.PauliString(left)
pulled_through: ops.PauliString = ps.after(moments)
ret = [circuits.Moment(left)] + moments
ret.append(circuits.Moment([pauli_gate(q) for q, pauli_gate in pulled_through.items()]))
return ret


CZGaugeTransformer = GaugeTransformer(target=CZ, gauge_selector=CZGaugeSelector)

# Multi-layer pull through version of CZGaugeTransformer
CZGaugeTransformerML = GaugeTransformer(
target=CZ,
gauge_selector=CZGaugeSelector,
multi_layer_pull_thourgh_fn=_multi_layer_pull_through_cz,
)
43 changes: 42 additions & 1 deletion cirq-core/cirq/transformers/gauge_compiling/cz_gauge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import cirq
from cirq.transformers.gauge_compiling import CZGaugeTransformer
from cirq.transformers.gauge_compiling import CZGaugeTransformer, CZGaugeTransformerML
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester


class TestCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ
gauge_transformer = CZGaugeTransformer


def test_multi_layer_pull_through():
"""Test case.
Input:
┌──┐
0: ───@────@─────H───────@───@───
│ │ │ │
1: ───@────┼@────────────@───@───
││
2: ───@────@┼────────@───@───@───
│ │ │ │ │
3: ───@─────@────────@───@───@───
└──┘
An example output:
┌──┐
0: ───Z───@────@─────Z───H───X───────@───@───X───
│ │ │ │
1: ───Y───@────┼@────X───────I───────@───@───────
││
2: ───Y───@────@┼────X───────I───@───@───@───────
│ │ │ │ │
3: ───X───@─────@────X───────I───@───@───@───────
└──┘
"""
q0, q1, q2, q3 = cirq.LineQubit.range(4)
input_circuit = cirq.Circuit(
cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)),
cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)),
cirq.Moment(cirq.H(q0)),
cirq.Moment(cirq.CZ(q2, q3)),
cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)),
cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)),
)
transformer = CZGaugeTransformerML

output_circuit = transformer(input_circuit, prng=np.random.default_rng())
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(
input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()}
)
20 changes: 20 additions & 0 deletions cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def __init__(
target: Union[ops.Gate, ops.Gateset, ops.GateFamily],
gauge_selector: Callable[[np.random.Generator], Gauge],
two_qubit_gate_symbolizer: Optional[TwoQubitGateSymbolizer] = None,
multi_layer_pull_thourgh_fn: Optional[
Callable[[List[circuits.Moment], List[circuits.Moment]], List[circuits.Moment]]
] = None,
) -> None:
"""Constructs a GaugeTransformer.

Expand All @@ -208,6 +211,7 @@ def __init__(
self.target = ops.GateFamily(target) if isinstance(target, ops.Gate) else target
self.gauge_selector = gauge_selector
self.two_qubit_gate_symbolizer = two_qubit_gate_symbolizer
self.multi_layer_pull_thourgh_fn = multi_layer_pull_thourgh_fn

def __call__(
self,
Expand All @@ -224,7 +228,21 @@ def __call__(
new_moments = []
left: List[List[ops.Operation]] = []
right: List[List[ops.Operation]] = []
all_target_moments: List[circuits.Moment] = []

for moment in circuit:
if self.multi_layer_pull_thourgh_fn and all(
[
op in self.target and not set(op.tags).intersection(context.tags_to_ignore)
for op in moment
]
): # all ops are target 2-qubit gates
all_target_moments.append(moment)
continue
if all_target_moments:
new_moments.extend(self.multi_layer_pull_thourgh_fn(all_target_moments, rng))
all_target_moments.clear()

left.clear()
right.clear()
center: List[ops.Operation] = []
Expand All @@ -247,6 +265,8 @@ def __call__(
new_moments.append(center)
if right:
new_moments.extend(_build_moments(right))
if all_target_moments:
new_moments.extend(self.multi_layer_pull_thourgh_fn(all_target_moments, rng))
return circuits.Circuit.from_moments(*new_moments)

def as_sweep(
Expand Down
Loading