-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Support merging 1-qubit gates in transformers for parameterized circuits #7149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
11104cc
f45f90e
a48d4bb
1ddf7ce
3c46507
ecb6a83
a61b1e3
830d2eb
b09d33c
dac4e32
2982730
37163c0
b680e50
875179b
414adda
f197a89
807f35f
69ad9eb
e67e661
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,22 @@ | |
|
||
"""Transformer passes to combine adjacent single-qubit rotations.""" | ||
|
||
from typing import Optional, TYPE_CHECKING | ||
from typing import Callable, cast, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING | ||
|
||
import sympy | ||
|
||
from cirq import circuits, ops, protocols | ||
from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives | ||
from cirq.study.resolver import ParamResolver | ||
from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip | ||
from cirq.transformers import ( | ||
align, | ||
merge_k_qubit_gates, | ||
symbolize, | ||
transformer_api, | ||
transformer_primitives, | ||
) | ||
from cirq.transformers.analytical_decompositions import single_qubit_decompositions | ||
from cirq.transformers.tag_transformers import index_tags, remove_tags | ||
|
||
if TYPE_CHECKING: | ||
import cirq | ||
|
@@ -65,6 +76,7 @@ def merge_single_qubit_gates_to_phxz( | |
circuit: 'cirq.AbstractCircuit', | ||
*, | ||
context: Optional['cirq.TransformerContext'] = None, | ||
merge_tags_fn: Optional[Callable[['cirq.CircuitOperation'], List[Hashable]]] = None, | ||
atol: float = 1e-8, | ||
) -> 'cirq.Circuit': | ||
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`. | ||
|
@@ -75,19 +87,21 @@ def merge_single_qubit_gates_to_phxz( | |
Args: | ||
circuit: Input circuit to transform. It will not be modified. | ||
context: `cirq.TransformerContext` storing common configurable options for transformers. | ||
merge_tags_fn: A callable returns the tags to be added to the merged operation. | ||
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be | ||
dropped, smaller values increase accuracy. | ||
|
||
Returns: | ||
Copy of the transformed input circuit. | ||
""" | ||
|
||
def rewriter(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': | ||
u = protocols.unitary(op) | ||
if protocols.num_qubits(op) == 0: | ||
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': | ||
u = protocols.unitary(circuit_op) | ||
if protocols.num_qubits(circuit_op) == 0: | ||
return ops.GlobalPhaseGate(u[0, 0]).on() | ||
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) | ||
return gate(op.qubits[0]) if gate else [] | ||
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I | ||
phxz_op = gate.on(circuit_op.qubits[0]) | ||
return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op | ||
|
||
return merge_k_qubit_gates.merge_k_qubit_unitaries( | ||
circuit, k=1, context=context, rewriter=rewriter | ||
|
@@ -152,3 +166,191 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: | |
deep=context.deep if context else False, | ||
tags_to_ignore=tuple(tags_to_ignore), | ||
).unfreeze(copy=False) | ||
|
||
|
||
def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep: | ||
new_resolvers: List['cirq.ParamResolver'] = [] | ||
for resolver in sweep: | ||
param_dict: 'cirq.ParamMappingType' = {s: resolver.value_of(s) for s in symbols} | ||
new_resolvers.append(ParamResolver(param_dict)) | ||
return ListSweep(new_resolvers) | ||
|
||
|
||
def _parameterize_phxz_in_circuits( | ||
circuit_list: List['cirq.Circuit'], | ||
merge_tag_prefix: str, | ||
phxz_symbols: set[sympy.Symbol], | ||
remaining_symbols: set[sympy.Symbol], | ||
sweep: Sweep, | ||
) -> Sweep: | ||
"""Parameterizes the circuits and returns a new sweep.""" | ||
values_by_params: Dict[str, List[float]] = {**{str(s): [] for s in phxz_symbols}} | ||
|
||
for circuit in circuit_list: | ||
for op in circuit.all_operations(): | ||
the_merge_tag: Optional[str] = None | ||
for tag in op.tags: | ||
if str(tag).startswith(merge_tag_prefix): | ||
the_merge_tag = str(tag) | ||
if not the_merge_tag: | ||
continue | ||
sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1] | ||
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters | ||
if isinstance(op.gate, ops.PhasedXZGate): | ||
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent | ||
elif op.gate is not ops.I: | ||
raise RuntimeError( | ||
f"Expected the merged gate to be a PhasedXZGate or IdentityGate," | ||
f" but got {op.gate}." | ||
) | ||
values_by_params[f"x{sid}"].append(x) | ||
values_by_params[f"z{sid}"].append(z) | ||
values_by_params[f"a{sid}"].append(a) | ||
|
||
return Zip( | ||
dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)), | ||
_sweep_on_symbols(sweep, remaining_symbols), | ||
) | ||
|
||
|
||
def _all_tags_startswith(circuit: 'cirq.AbstractCircuit', startswith: str): | ||
tag_set: set[Hashable] = set() | ||
for op in circuit.all_operations(): | ||
for tag in op.tags: | ||
if str(tag).startswith(startswith): | ||
tag_set.add(tag) | ||
return tag_set | ||
|
||
|
||
def merge_single_qubit_gates_to_phxz_symbolized( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be folded under the existing Also on a quick read - is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here are some caveats that we need to take into consideration before making the decision of whether we want to fold symbolized function into the existing one,
will need to be transformed into
in which While if sweep is supplied we may do { To summarize, there are 2 options, Option 1, fold symbolized version into existing functions, the interface will be something like
Option 2, separate 2 functions, the symbolized version doesn't necessary need to follow the transformer decorator, though we can still modify the general transformer definition cc @eliottrosenberg @NoureldinYosri for thoughts and suggestions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pavoljuhas , the impl in this PR doesn't change the transformers' structure (which is option 2 above). If we want to fold the implementations into the existing function, we can always do that later in another PR to modify the transformer interface, wdyt? |
||
circuit: 'cirq.AbstractCircuit', | ||
*, | ||
context: Optional['cirq.TransformerContext'] = None, | ||
sweep: Sweep, | ||
atol: float = 1e-8, | ||
) -> Tuple['cirq.Circuit', Sweep]: | ||
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of | ||
the consecutive gates is symbolized. | ||
|
||
Example: | ||
>>> q0, q1 = cirq.LineQubit.range(2) | ||
>>> c = cirq.Circuit(\ | ||
cirq.X(q0),\ | ||
cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\ | ||
cirq.Y(q0)**sympy.Symbol("y_exp"),\ | ||
cirq.X(q0)) | ||
>>> print(c) | ||
0: ───X───@──────────Y^y_exp───X─── | ||
│ | ||
1: ───────@^cz_exp───────────────── | ||
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\ | ||
c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\ | ||
cirq.Points(key="y_exp", points=[0, 1]))) | ||
>>> print(new_circuit) | ||
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)─── | ||
│ | ||
1: ────────────────────────@^cz_exp────────────────────────── | ||
>>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0}) | ||
>>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1}) | ||
|
||
Args: | ||
circuit: Input circuit to transform. It will not be modified. | ||
context: `cirq.TransformerContext` storing common configurable options for transformers. | ||
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned | ||
based on the transformation. | ||
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be | ||
dropped, smaller values increase accuracy. | ||
|
||
Returns: | ||
Copy of the transformed input circuit. | ||
""" | ||
deep = context.deep if context else False | ||
|
||
# Tag symbolized single-qubit op. | ||
symbolized_single_tag = "TMP-TAG-symbolized-single" | ||
|
||
circuit_tagged = transformer_primitives.map_operations( | ||
circuit, | ||
lambda op, _: ( | ||
op.with_tags(symbolized_single_tag) | ||
if protocols.is_parameterized(op) and len(op.qubits) == 1 | ||
else op | ||
), | ||
deep=deep, | ||
) | ||
|
||
# Step 0, isolate single qubit symbols and resolve the circuit on them. | ||
single_qubit_gate_symbols: set[sympy.Symbol] = set().union( | ||
*[ | ||
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() | ||
for op in circuit_tagged.all_operations() | ||
] | ||
) | ||
# If all single qubit gates are not parameterized, call the nonparamerized version of | ||
# the transformer. | ||
if not single_qubit_gate_symbols: | ||
return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) | ||
sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols) | ||
# Get all resolved circuits from all sets of resolvers in sweep_of_single. | ||
resolved_circuits = [ | ||
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single | ||
] | ||
|
||
# Step 1, merge single qubit gates per resolved circuit, preserving | ||
# the symbolized_single_tag with indexes. | ||
merged_circuits: List['cirq.Circuit'] = [] | ||
for resolved_circuit in resolved_circuits: | ||
merged_circuit = index_tags( | ||
merge_single_qubit_gates_to_phxz( | ||
resolved_circuit, | ||
context=context, | ||
merge_tags_fn=lambda circuit_op: ( | ||
[symbolized_single_tag] | ||
if any( | ||
symbolized_single_tag in set(op.tags) | ||
for op in circuit_op.circuit.all_operations() | ||
) | ||
else [] | ||
), | ||
atol=atol, | ||
), | ||
target_tags={symbolized_single_tag}, | ||
context=context, | ||
) | ||
merged_circuits.append(merged_circuit) | ||
|
||
if not all( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this should happen here ... this should be a test for the correctness of the transformer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's to validate the input parameters If we don't check here, Step 2 of parameterization would possibly crash with different kind of errors. I should probably update the error message as |
||
_all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag) | ||
== _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag) | ||
for merged_circuit in merged_circuits | ||
): | ||
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") | ||
|
||
# Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag. | ||
new_circuit = align.align_right( | ||
remove_tags( | ||
symbolize.symbolize_single_qubit_gates_by_indexed_tags( | ||
merged_circuits[0], tag_prefix=symbolized_single_tag | ||
), | ||
remove_if=lambda tag: str(tag).startswith(symbolized_single_tag), | ||
) | ||
) | ||
|
||
# Step 3, get N sets of parameterizations as new_sweep. | ||
phxz_symbols: set[sympy.Symbol] = set().union( | ||
*[ | ||
set( | ||
[sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) for s in ["x", "z", "a"]] | ||
) | ||
for tag in _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag) | ||
] | ||
) | ||
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. | ||
remaining_symbols: set[sympy.Symbol] = set( | ||
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols | ||
) | ||
new_sweep = _parameterize_phxz_in_circuits( | ||
merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep | ||
) | ||
|
||
return new_circuit, new_sweep |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like this is a bit too general ... what is the problem with recieving a list of tags ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kind of need a function here as what I need to do is given a circuit op, set the output tags with more rule based tag setters and I believe the flexibility of the function can help users in different use cases.
rewriter
CircuitOperation(X['tag_needed'] -- Y['tag1']) --> phxz(x,z,a) with no tags
case 1: CircuitOperation(X['tag_needed'] -- Y['tag1']) --'tag_needed' presented--> phxz(...)['phxz_{iter}']
case 2: CircuitOperation(X -- Z) --no 'tags_needed' found --> phxz(...) with no tags
CircuitOperation(X['tag0'] -- Y['tag1']) --> phxz(...)['tag0', 'tag1']