Skip to content

Commit af6e9cd

Browse files
committed
add initial impl
1 parent 920a62c commit af6e9cd

File tree

7 files changed

+31
-11
lines changed

7 files changed

+31
-11
lines changed

pennylane/compiler/python_compiler/impl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from .quantum_dialect import QuantumDialect as Quantum
1414

15+
from .transforms import ApplyTransformSequence
16+
1517
class Compiler:
1618

1719
def run(self, jmod: jaxModule) -> jaxModule:
@@ -33,7 +35,7 @@ def run(self, jmod: jaxModule) -> jaxModule:
3335
ctx.load_dialect(Quantum)
3436

3537
xmod: builtin.ModuleOp = Parser(ctx, gentxtmod).parse_module()
36-
pipeline = passes.PipelinePass((ApplyTransformSequence(),))
38+
pipeline = PipelinePass((ApplyTransformSequence(),))
3739
# xmod is modified in place
3840
pipeline.apply(ctx, xmod)
3941

@@ -43,6 +45,6 @@ def run(self, jmod: jaxModule) -> jaxModule:
4345
ctx.allow_unregistered_dialects = True
4446
ctx.append_dialect_registry(mlir.upstream_dialects)
4547
stablehlo.register_dialect(ctx)
46-
newmod: jaxModule = Module.parse(buffer.getvalue())
48+
newmod: jaxModule = jaxModule.parse(buffer.getvalue())
4749

4850
return newmod

pennylane/compiler/python_compiler/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from apply_transform_sequence import ApplyTransformSequence
2-
from transform_interpreter import TransformInterpreterPass
1+
from .apply_transform_sequence import ApplyTransformSequence
2+
from .transform_interpreter import TransformInterpreterPass
33

44
__all__ = [
55
"ApplyTransformSequence",

pennylane/compiler/python_compiler/transforms/apply_transform_sequence.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from dataclasses import dataclass
2+
13
from xdsl.dialects import builtin
24
from xdsl.context import MLContext
35
from xdsl.passes import ModulePass, PipelinePass
46

5-
from transform_interpreter import TransformInterpreterPass
7+
from .transform_interpreter import TransformInterpreterPass
68

79
@dataclass(frozen=True)
810
class ApplyTransformSequence(ModulePass):
@@ -16,7 +18,7 @@ def apply(self, ctx: MLContext, module: builtin.ModuleOp) -> None:
1618
if isinstance(op, builtin.ModuleOp):
1719
nested_modules.append(op)
1820

19-
pipeline = PipelinePass((TransformInterpreterPass({}),))
21+
pipeline = PipelinePass((TransformInterpreterPass(passes={}),))
2022
for op in nested_modules:
2123
pipeline.apply(ctx, op)
2224

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .transform_interpreter_catalyst import TransformInterpreterPass
2+
3+
__all__ = [
4+
"TransformInterpreterPass"
5+
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .impl import TransformFunctionsExt
2+
3+
__all__ = [
4+
"TransformFunctionsExt"
5+
]

pennylane/compiler/python_compiler/transforms/transform_interpreter/interpreter/impl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import io
22
from typing import Callable
33

4-
from xdsl.interpreter import Interpreter, PythonValues, register_impls
4+
from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls
55
from xdsl.interpreters.transform import TransformFunctions
6+
from xdsl.context import Context
67
from xdsl.dialects import transform
78
from xdsl.parser import Parser
89
from xdsl.passes import ModulePass, PipelinePass
910
from xdsl.printer import Printer
1011
from xdsl.rewriter import Rewriter
12+
from xdsl.utils import parse_pipeline
1113

1214
from catalyst.compiler import _quantum_opt
1315

pennylane/compiler/python_compiler/transforms/transform_interpreter/transform_interpreter_catalyst.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
from dataclasses import dataclass
2+
from typing import Callable
23

4+
from xdsl.context import Context
35
from xdsl.dialects import builtin, transform
46
from xdsl.interpreters import Interpreter
57
from xdsl.passes import ModulePass
68

7-
from interpreter import TransformFunctionsExt
9+
from .interpreter import TransformFunctionsExt
810

9-
@dataclass(frozen=True)
1011
class TransformInterpreterPass(ModulePass):
1112
"""Transform dialect interpreter"""
1213

14+
passes: dict[str, Callable[[], type[ModulePass]]]
1315
name = "transform-interpreter"
1416

1517
entry_point: str = "__transform_main"
16-
passes: dict[str, Callable[[], type[ModulePass]]]
18+
19+
def __init__(self, passes):
20+
self.passes = passes
1721

1822
@staticmethod
1923
def find_transform_entry_point(
@@ -34,6 +38,6 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
3438
op, self.entry_point
3539
)
3640
interpreter = Interpreter(op)
37-
interpreter.register_implementations(TransformFunctionsExt(ctx, passes))
41+
interpreter.register_implementations(TransformFunctionsExt(ctx, self.passes))
3842
schedule.parent_op().detach()
3943
interpreter.call_op(schedule, (op,))

0 commit comments

Comments
 (0)