Skip to content

Commit e5d1469

Browse files
committed
Printing as a pass
1 parent 93f70e4 commit e5d1469

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def identity_transform(tape):
3030
# All passes inherit from passes.ModulePass
3131
class PrintModule(passes.ModulePass):
3232
# All passes require a name field
33-
name = "print"
33+
name = "remove-chained-self-inverse"
3434

3535
# All passes require an apply method with this signature.
3636
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None:

frontend/catalyst/compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,9 @@ def run_python_compiler(self, mlir_module):
488488
# TODO: check that xdsl is available to be imported
489489
import xdsl
490490
from xdsl.context import Context
491+
from xdsl import passes
491492
from xdsl.dialects import arith, builtin, func, scf, tensor, transform
492-
from .python_compiler import quantum
493+
from .python_compiler import quantum, PrintModule
493494
generic_assembly_format = mlir_module.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)
494495
ctx = Context(allow_unregistered=True)
495496
ctx.load_dialect(arith.Arith)
@@ -507,6 +508,8 @@ def run_python_compiler(self, mlir_module):
507508
# TODO: Load Catalyst
508509
# TODO: Load ion/ppm/mbqc/zne...
509510
module = xdsl.parser.Parser(ctx, generic_assembly_format).parse_module()
511+
pipeline = passes.PipelinePass((PrintModule(),))
512+
pipeline.apply(ctx, module)
510513

511514
from jax._src.interpreters import mlir
512515
from jaxlib.mlir.dialects import stablehlo
@@ -520,7 +523,6 @@ def run_python_compiler(self, mlir_module):
520523
ctx.allow_unregistered_dialects = True
521524
ctx.append_dialect_registry(mlir.upstream_dialects)
522525
stablehlo.register_dialect(ctx)
523-
524526
module = Module.parse(buffer.getvalue())
525527
return module
526528

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Callable
2+
3+
from dataclasses import dataclass
4+
from xdsl import context, passes
5+
from xdsl.utils import parse_pipeline
6+
from xdsl.dialects import builtin
7+
8+
@dataclass(frozen=True)
9+
# All passes inherit from passes.ModulePass
10+
class PrintModule(passes.ModulePass):
11+
# All passes require a name field
12+
name = "print"
13+
14+
# All passes require an apply method with this signature.
15+
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None:
16+
print("Hello from inside the pass\n", module)

0 commit comments

Comments
 (0)