Skip to content

[Python Compiler] Prototype for the merge_rotations transform implemented with xDSL #7364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: xdsl-cancel-inverses
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions pennylane/compiler/python_compiler/merge_rotations_xdsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains the implementation of the merge_rotations transform,
written using xDSL."""

from dataclasses import dataclass

from xdsl import context, passes, pattern_rewriter
from xdsl.dialects import arith, builtin, func
from xdsl.ir import Operation
from xdsl.rewriter import InsertPoint

from pennylane.ops.qubit.attributes import composable_rotations

from .quantum_dialect import CustomOp


def _can_merge(op: CustomOp, next_op: Operation) -> bool:
if isinstance(next_op, CustomOp):
if op.gate_name.data == next_op.gate_name.data:
if op.out_qubits == next_op.in_qubits and op.out_ctrl_qubits == next_op.in_ctrl_qubits:
return True

return False


class MergeRotationsPattern(
pattern_rewriter.RewritePattern
): # pylint: disable=too-few-public-methods
"""RewritePattern for merging consecutive composable rotations."""

@pattern_rewriter.op_type_rewrite_pattern
def match_and_rewrite(
self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter
): # pylint: disable=arguments-differ
"""Implementation of rewriting FuncOps that may contain operations corresponding to
consecutive composable rotations."""
for op in funcOp.body.walk():
if not isinstance(op, CustomOp):
continue

gate_name = op.gate_name.data
if gate_name not in composable_rotations or gate_name == "Rot":
# Can handle all composible rotations except Rot... for now
Comment on lines +55 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could write a different pattern for Rot if it is difficult to fit here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern for Rot requires non-trivial numerical computations, which I didn't want to bother with yet 😅

We can come back to it later, but I think a cool way to implement it would be to generate MLIR for the function that we use to fuse the rotation angles using jax.jit for merging Rot gates and use that in the pattern.

continue

param = op.operands[0]
while True:
next_user = None
for use in op.results[0].uses:
user = use.operation
if _can_merge(op, user):
next_user = user
break

if next_user is None:
break

for q1, q2 in zip(op.in_qubits, op.out_qubits, strict=True):
rewriter.replace_all_uses_with(q2, q1)
for cq1, cq2 in zip(op.in_ctrl_qubits, op.out_ctrl_qubits, strict=True):
rewriter.replace_all_uses_with(cq2, cq1)

rewriter.erase_op(op)
next_param = next_user.operands[0]
addOp = arith.AddfOp(param, next_param)
rewriter.insert_op(addOp, InsertPoint.before(next_user))
param = addOp.result
new_op = CustomOp(
operands=(param, next_user.in_qubits[0], None, None),
properties=next_user.properties,
attributes=next_user.attributes,
successors=next_user.successors,
regions=next_user.regions,
result_types=(next_user.result_types, []),
)
Comment on lines +81 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to, either in this PR or in a future one, we should create a custom initialization procedure for CustomOp.

rewriter.replace_op(next_user, new_op)
op = new_op


@dataclass(frozen=True)
class MergeRotationsPass(passes.ModulePass):
"""Pass for merging consecutive composable rotation gates."""

name = "merge-rotations"

# pylint: disable=arguments-renamed
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None:
"""Apply the merge rotations pass."""
pattern_rewriter.PatternRewriteWalker(
pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()])
pattern_rewriter.GreedyRewritePatternApplier([MergeRotXRotYRotZPattern(), MergeRotPattern()])

Once you have the pattern to match only for Rot. If you think this would be a good design, but I have the freedom to implement it inside the current class.

).rewrite_module(module)