From c7405ce7d26ea45cc64ea17c71ee6ddcf93fa1fa Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 1 May 2025 09:44:07 -0400 Subject: [PATCH 1/4] E.C. From 49b34400b0a172d00ef1a3f6277d4749405b1124 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 1 May 2025 09:59:27 -0400 Subject: [PATCH 2/4] Skeleton of implementation --- .../python_compiler/merge_rotations_xdsl.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 pennylane/compiler/python_compiler/merge_rotations_xdsl.py diff --git a/pennylane/compiler/python_compiler/merge_rotations_xdsl.py b/pennylane/compiler/python_compiler/merge_rotations_xdsl.py new file mode 100644 index 00000000000..ab3dfe91423 --- /dev/null +++ b/pennylane/compiler/python_compiler/merge_rotations_xdsl.py @@ -0,0 +1,40 @@ +# Copyright 2018-2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the merge_rotations transform, +written using xDSL.""" + +from dataclasses import dataclass + +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import builtin, func + +from .quantum_dialect import CustomOp + + +class MergeRotationsSingleQubitPattern(pattern_rewriter.RewritePattern): + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + + pass + + +@dataclass(frozen=True) +class MergeRotationsSingleQubitPass(passes.ModulePass): + name = "merge-rotations-single-qubit" + + def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None: + pattern_rewriter.PatternRewriteWalker( + pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsSingleQubitPattern()]) + ).rewrite_module(module) From 00ec3e26146def08671fbcaa11c286ec2ece70d3 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 1 May 2025 09:44:07 -0400 Subject: [PATCH 3/4] E.C. From 420bb25742c8169d4a7f44fcb0f2b7873881227f Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 9 May 2025 10:06:33 -0400 Subject: [PATCH 4/4] Add implementation for merge rotations --- .../python_compiler/merge_rotations_xdsl.py | 78 +++++++++++++++++-- 1 file changed, 71 insertions(+), 7 deletions(-) diff --git a/pennylane/compiler/python_compiler/merge_rotations_xdsl.py b/pennylane/compiler/python_compiler/merge_rotations_xdsl.py index ab3dfe91423..961c6323622 100644 --- a/pennylane/compiler/python_compiler/merge_rotations_xdsl.py +++ b/pennylane/compiler/python_compiler/merge_rotations_xdsl.py @@ -18,23 +18,87 @@ from dataclasses import dataclass from xdsl import context, passes, pattern_rewriter -from xdsl.dialects import builtin, func +from xdsl.dialects import arith, builtin, func +from xdsl.ir import Operation +from xdsl.rewriter import InsertPoint + +from pennylane.ops.qubit.attributes import composable_rotations from .quantum_dialect import CustomOp -class MergeRotationsSingleQubitPattern(pattern_rewriter.RewritePattern): +def _can_merge(op: CustomOp, next_op: Operation) -> bool: + if isinstance(next_op, CustomOp): + if op.gate_name.data == next_op.gate_name.data: + if op.out_qubits == next_op.in_qubits and op.out_ctrl_qubits == next_op.in_ctrl_qubits: + return True + + return False + + +class MergeRotationsPattern( + pattern_rewriter.RewritePattern +): # pylint: disable=too-few-public-methods + """RewritePattern for merging consecutive composable rotations.""" + @pattern_rewriter.op_type_rewrite_pattern - def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + def match_and_rewrite( + self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter + ): # pylint: disable=arguments-differ + """Implementation of rewriting FuncOps that may contain operations corresponding to + consecutive composable rotations.""" + for op in funcOp.body.walk(): + if not isinstance(op, CustomOp): + continue - pass + gate_name = op.gate_name.data + if gate_name not in composable_rotations or gate_name == "Rot": + # Can handle all composible rotations except Rot... for now + continue + + param = op.operands[0] + while True: + next_user = None + for use in op.results[0].uses: + user = use.operation + if _can_merge(op, user): + next_user = user + break + + if next_user is None: + break + + for q1, q2 in zip(op.in_qubits, op.out_qubits, strict=True): + rewriter.replace_all_uses_with(q2, q1) + for cq1, cq2 in zip(op.in_ctrl_qubits, op.out_ctrl_qubits, strict=True): + rewriter.replace_all_uses_with(cq2, cq1) + + rewriter.erase_op(op) + next_param = next_user.operands[0] + addOp = arith.AddfOp(param, next_param) + rewriter.insert_op(addOp, InsertPoint.before(next_user)) + param = addOp.result + new_op = CustomOp( + operands=(param, next_user.in_qubits[0], None, None), + properties=next_user.properties, + attributes=next_user.attributes, + successors=next_user.successors, + regions=next_user.regions, + result_types=(next_user.result_types, []), + ) + rewriter.replace_op(next_user, new_op) + op = new_op @dataclass(frozen=True) -class MergeRotationsSingleQubitPass(passes.ModulePass): - name = "merge-rotations-single-qubit" +class MergeRotationsPass(passes.ModulePass): + """Pass for merging consecutive composable rotation gates.""" + + name = "merge-rotations" + # pylint: disable=arguments-renamed def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None: + """Apply the merge rotations pass.""" pattern_rewriter.PatternRewriteWalker( - pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsSingleQubitPattern()]) + pattern_rewriter.GreedyRewritePatternApplier([MergeRotationsPattern()]) ).rewrite_module(module)