-
Notifications
You must be signed in to change notification settings - Fork 658
[Python Compiler] Prototype for the merge_rotations
transform implemented with xDSL
#7364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: xdsl-cancel-inverses
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,104 @@ | ||||||
# Copyright 2018-2025 Xanadu Quantum Technologies Inc. | ||||||
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|
||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
"""This file contains the implementation of the merge_rotations transform, | ||||||
written using xDSL.""" | ||||||
|
||||||
from dataclasses import dataclass | ||||||
|
||||||
from xdsl import context, passes, pattern_rewriter | ||||||
from xdsl.dialects import arith, builtin, func | ||||||
from xdsl.ir import Operation | ||||||
from xdsl.rewriter import InsertPoint | ||||||
|
||||||
from pennylane.ops.qubit.attributes import composable_rotations | ||||||
|
||||||
from .quantum_dialect import CustomOp | ||||||
|
||||||
|
||||||
def _can_merge(op: CustomOp, next_op: Operation) -> bool: | ||||||
if isinstance(next_op, CustomOp): | ||||||
if op.gate_name.data == next_op.gate_name.data: | ||||||
if op.out_qubits == next_op.in_qubits and op.out_ctrl_qubits == next_op.in_ctrl_qubits: | ||||||
return True | ||||||
|
||||||
return False | ||||||
|
||||||
|
||||||
class MergeRotationsPattern( | ||||||
pattern_rewriter.RewritePattern | ||||||
): # pylint: disable=too-few-public-methods | ||||||
"""RewritePattern for merging consecutive composable rotations.""" | ||||||
|
||||||
@pattern_rewriter.op_type_rewrite_pattern | ||||||
def match_and_rewrite( | ||||||
self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter | ||||||
): # pylint: disable=arguments-differ | ||||||
"""Implementation of rewriting FuncOps that may contain operations corresponding to | ||||||
consecutive composable rotations.""" | ||||||
for op in funcOp.body.walk(): | ||||||
if not isinstance(op, CustomOp): | ||||||
continue | ||||||
|
||||||
gate_name = op.gate_name.data | ||||||
if gate_name not in composable_rotations or gate_name == "Rot": | ||||||
# Can handle all composible rotations except Rot... for now | ||||||
continue | ||||||
|
||||||
param = op.operands[0] | ||||||
while True: | ||||||
next_user = None | ||||||
for use in op.results[0].uses: | ||||||
user = use.operation | ||||||
if _can_merge(op, user): | ||||||
next_user = user | ||||||
break | ||||||
|
||||||
if next_user is None: | ||||||
break | ||||||
|
||||||
for q1, q2 in zip(op.in_qubits, op.out_qubits, strict=True): | ||||||
rewriter.replace_all_uses_with(q2, q1) | ||||||
for cq1, cq2 in zip(op.in_ctrl_qubits, op.out_ctrl_qubits, strict=True): | ||||||
rewriter.replace_all_uses_with(cq2, cq1) | ||||||
|
||||||
rewriter.erase_op(op) | ||||||
next_param = next_user.operands[0] | ||||||
addOp = arith.AddfOp(param, next_param) | ||||||
rewriter.insert_op(addOp, InsertPoint.before(next_user)) | ||||||
param = addOp.result | ||||||
new_op = CustomOp( | ||||||
operands=(param, next_user.in_qubits[0], None, None), | ||||||
properties=next_user.properties, | ||||||
attributes=next_user.attributes, | ||||||
successors=next_user.successors, | ||||||
regions=next_user.regions, | ||||||
result_types=(next_user.result_types, []), | ||||||
) | ||||||
Comment on lines
+81
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to, either in this PR or in a future one, we should create a custom initialization procedure for CustomOp. |
||||||
rewriter.replace_op(next_user, new_op) | ||||||
op = new_op | ||||||
|
||||||
|
||||||
@dataclass(frozen=True) | ||||||
class MergeRotationsPass(passes.ModulePass): | ||||||
"""Pass for merging consecutive composable rotation gates.""" | ||||||
|
||||||
name = "merge-rotations" | ||||||
|
||||||
# pylint: disable=arguments-renamed | ||||||
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None: | ||||||
"""Apply the merge rotations pass.""" | ||||||
pattern_rewriter.PatternRewriteWalker( | ||||||
pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()]) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Once you have the pattern to match only for Rot. If you think this would be a good design, but I have the freedom to implement it inside the current class. |
||||||
).rewrite_module(module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could write a different pattern for Rot if it is difficult to fit here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pattern for
Rot
requires non-trivial numerical computations, which I didn't want to bother with yet 😅We can come back to it later, but I think a cool way to implement it would be to generate MLIR for the function that we use to fuse the rotation angles using
jax.jit
for mergingRot
gates and use that in the pattern.