-
Notifications
You must be signed in to change notification settings - Fork 47
🚧 Python Compiler #1677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🚧 Python Compiler #1677
Changes from 19 commits
1e1deb3
2ffc717
7052374
3259167
66f0dfe
3cc2b83
1ae56d6
ebb4f07
fddf3ee
9fc48cd
a217847
88b9d49
75c89c9
f0c9810
93f70e4
e5d1469
f9e5475
855e8c5
cd276de
d3dac88
4cd81ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import pennylane as qml | ||
from pennylane.tape import QuantumScript, QuantumScriptBatch | ||
Check notice on line 2 in a.py
|
||
from pennylane.transforms import cancel_inverses as pl_cancel_inverses | ||
from pennylane.typing import PostprocessingFn | ||
from catalyst.tracing.contexts import EvaluationContext | ||
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath | ||
|
||
import catalyst | ||
import functools | ||
|
||
from dataclasses import dataclass | ||
from xdsl import context, passes | ||
from xdsl.utils import parse_pipeline | ||
from xdsl.dialects import builtin, func | ||
|
||
def xdsl_transform(_klass): | ||
|
||
def identity_transform(tape): | ||
return tape, lambda args: args[0] | ||
|
||
identity_transform.__name__ = "xdsl_transform" + _klass.__name__ | ||
transform = qml.transform(identity_transform) | ||
catalyst.from_plxpr.register_transform(transform, _klass.name, False) | ||
from catalyst.python_compiler import register_pass | ||
register_pass(_klass.name, lambda : _klass()) | ||
|
||
return transform | ||
|
||
|
||
from xdsl.rewriter import InsertPoint | ||
Check notice on line 30 in a.py
|
||
from xdsl import pattern_rewriter | ||
|
||
from catalyst.python_compiler import quantum | ||
Check notice on line 33 in a.py
|
||
|
||
|
||
self_inverses = ("PauliZ", "PauliX", "PauliY", "Hadamard", "Identity") | ||
Check notice on line 36 in a.py
|
||
|
||
|
||
def cancel_ops(rewriter, op, next_op): | ||
rewriter.replace_all_uses_with(next_op.results[0], op.in_qubits[0]) | ||
rewriter.erase_op(next_op) | ||
rewriter.erase_op(op) | ||
owner = op.in_qubits[0].owner | ||
|
||
if isinstance(owner, quantum.CustomOp) and owner.gate_name.data in self_inverses: | ||
next_user = None | ||
|
||
for use in owner.results[0].uses: | ||
user = use.operation | ||
if isinstance(user, quantum.CustomOp) and user.gate_name.data == owner.gate_name.data: | ||
next_user = user | ||
break | ||
|
||
if next_user is not None: | ||
cancel_ops(rewriter, owner, next_user) | ||
|
||
|
||
class DeepCancelInversesSingleQubitPattern(pattern_rewriter.RewritePattern): | ||
@pattern_rewriter.op_type_rewrite_pattern | ||
def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): | ||
"""Deep Cancel for Self Inverses""" | ||
print(funcOp) | ||
for op in funcOp.body.walk(): | ||
if not isinstance(op, quantum.CustomOp): | ||
continue | ||
|
||
if op.gate_name.data not in self_inverses: | ||
continue | ||
|
||
next_user = None | ||
for use in op.results[0].uses: | ||
user = use.operation | ||
if isinstance(user, quantum.CustomOp) and user.gate_name.data == op.gate_name.data: | ||
next_user = user | ||
break | ||
|
||
if next_user is not None: | ||
cancel_ops(rewriter, op, next_user) | ||
|
||
|
||
@xdsl_transform | ||
class DeepCancelInversesSingleQubitPass(passes.ModulePass): | ||
name = "deep-cancel-inverses-single-qubit" | ||
|
||
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None: | ||
pattern_rewriter.PatternRewriteWalker( | ||
pattern_rewriter.GreedyRewritePatternApplier([DeepCancelInversesSingleQubitPattern()]) | ||
).rewrite_module(module) | ||
|
||
qml.capture.enable() | ||
|
||
@catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) | ||
@DeepCancelInversesSingleQubitPass | ||
@qml.qnode(qml.device("lightning.qubit", wires=1)) | ||
def captured_circuit(x: float): | ||
qml.RX(x, wires=0) | ||
qml.Hadamard(wires=0) | ||
qml.Hadamard(wires=0) | ||
return qml.state() | ||
|
||
|
||
qml.capture.disable() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a note that if capture is disabled before executing, the tracing will happen without the plxpr pipeline. I'll share more details on Slack. |
||
|
||
|
||
captured_circuit(0.0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""XDSL Plugin Interface""" | ||
|
||
# This file contains what looks like a plugin | ||
# but in reality it is just not a "real" MLIR plugin. | ||
# It just follows the same convention to be able to add passes | ||
# that are implemented in xDSL | ||
|
||
from pathlib import Path | ||
|
||
|
||
def getXDSLPluginAbsolutePath(): | ||
"""Returns a fake path""" | ||
return Path("xdsl-does-not-use-a-real-path") | ||
|
||
|
||
def name2pass(name): | ||
return getXDSLPluginAbsolutePath(), name |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from typing import Callable | ||
|
||
from dataclasses import dataclass | ||
from xdsl import context, passes | ||
from xdsl.utils import parse_pipeline | ||
from xdsl.dialects import builtin | ||
|
||
from dataclasses import dataclass | ||
|
||
from xdsl.dialects import builtin, transform | ||
from xdsl.interpreter import Interpreter | ||
from xdsl.interpreters.transform import TransformFunctions | ||
from xdsl.passes import Context, ModulePass | ||
from xdsl.utils.exceptions import PassFailedException | ||
|
||
from xdsl.interpreter import ( | ||
Interpreter, | ||
InterpreterFunctions, | ||
PythonValues, | ||
ReturnedValues, | ||
TerminatorValue, | ||
impl, | ||
impl_callable, | ||
impl_terminator, | ||
register_impls, | ||
) | ||
|
||
@register_impls | ||
class TransformFunctionsExt(TransformFunctions): | ||
ctx: Context | ||
passes: dict[str, Callable[[], type[ModulePass]]] | ||
|
||
def __init__( | ||
self, ctx: Context, available_passes: dict[str, Callable[[], type[ModulePass]]] | ||
): | ||
self.ctx = ctx | ||
self.passes = available_passes | ||
|
||
@impl(transform.ApplyRegisteredPassOp) | ||
def run_apply_registered_pass_op( | ||
self, | ||
interpreter: Interpreter, | ||
op: transform.ApplyRegisteredPassOp, | ||
args: PythonValues, | ||
) -> PythonValues: | ||
pass_name = op.pass_name.data | ||
requested_by_user = passes.PipelinePass.build_pipeline_tuples( | ||
self.passes, parse_pipeline.parse_pipeline(pass_name) | ||
) | ||
# TODO: Switch between catalyst and xDSL | ||
|
||
schedule = tuple( | ||
pass_type.from_pass_spec(spec) for pass_type, spec in requested_by_user | ||
) | ||
pipeline = passes.PipelinePass(schedule) | ||
pipeline.apply(self.ctx, args[0]) | ||
return (args[0],) | ||
|
||
from xdsl.transforms import get_all_passes | ||
Check notice on line 59 in frontend/catalyst/python_compiler/__init__.py
|
||
|
||
updated_passes = get_all_passes() | ||
|
||
def register_pass(name, _callable): | ||
global updated_passes | ||
updated_passes[name] = _callable | ||
|
||
@dataclass(frozen=True) | ||
class TransformInterpreterPass(ModulePass): | ||
"""Transform dialect interpreter""" | ||
|
||
name = "transform-interpreter" | ||
|
||
entry_point: str = "__transform_main" | ||
|
||
@staticmethod | ||
def find_transform_entry_point( | ||
root: builtin.ModuleOp, entry_point: str | ||
) -> transform.NamedSequenceOp: | ||
for op in root.walk(): | ||
if ( | ||
isinstance(op, transform.NamedSequenceOp) | ||
and op.sym_name.data == entry_point | ||
): | ||
return op | ||
raise PassFailedException( | ||
f"{root} could not find a nested named sequence with name: {entry_point}" | ||
) | ||
|
||
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: | ||
schedule = TransformInterpreterPass.find_transform_entry_point( | ||
op, self.entry_point | ||
) | ||
interpreter = Interpreter(op) | ||
global updated_passes | ||
interpreter.register_implementations(TransformFunctionsExt(ctx, updated_passes)) | ||
interpreter.call_op(schedule, (op,)) | ||
|
||
@dataclass(frozen=True) | ||
class ApplyTransformSequence(passes.ModulePass): | ||
name = "apply-transform-sequence" | ||
|
||
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None: | ||
nested_modules = [] | ||
for region in module.regions: | ||
for block in region.blocks: | ||
for op in block.ops: | ||
if isinstance(op, builtin.ModuleOp): | ||
nested_modules.append(op) | ||
|
||
pipeline = passes.PipelinePass((TransformInterpreterPass(),)) | ||
for op in nested_modules: | ||
pipeline.apply(ctx, op) | ||
|
||
for op in nested_modules: | ||
for region in op.regions: | ||
for block in region.blocks: | ||
for op in block.ops: | ||
if isinstance(op, builtin.ModuleOp) and op.get_attr_or_prop("transform.with_named_sequence"): | ||
block.erase_op(op) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made the implementation a bit better by replacing recursion with a while loop.
cancel_ops
is not needed anymore:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d3dac88