Skip to content

🚧 Python Compiler #1677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pennylane as qml

Check notice on line 1 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L1

Missing module docstring (missing-module-docstring)
from pennylane.tape import QuantumScript, QuantumScriptBatch

Check notice on line 2 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L2

Unused QuantumScript imported from pennylane.tape (unused-import)

Check notice on line 2 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L2

Unused QuantumScriptBatch imported from pennylane.tape (unused-import)
from pennylane.transforms import cancel_inverses as pl_cancel_inverses

Check notice on line 3 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L3

Unused cancel_inverses imported from pennylane.transforms as pl_cancel_inverses (unused-import)
from pennylane.typing import PostprocessingFn

Check notice on line 4 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L4

Unused PostprocessingFn imported from pennylane.typing (unused-import)
from catalyst.tracing.contexts import EvaluationContext

Check notice on line 5 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L5

Unused EvaluationContext imported from catalyst.tracing.contexts (unused-import)
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

import catalyst
import functools

Check notice on line 9 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L9

Unused import functools (unused-import)

from dataclasses import dataclass

Check notice on line 11 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L11

Unused dataclass imported from dataclasses (unused-import)
from xdsl import context, passes
from xdsl.utils import parse_pipeline

Check notice on line 13 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L13

Unused parse_pipeline imported from xdsl.utils (unused-import)
from xdsl.dialects import builtin, func

def xdsl_transform(_klass):

Check notice on line 16 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L16

Missing function or method docstring (missing-function-docstring)

def identity_transform(tape):
return tape, lambda args: args[0]

identity_transform.__name__ = "xdsl_transform" + _klass.__name__
transform = qml.transform(identity_transform)
catalyst.from_plxpr.register_transform(transform, _klass.name, False)
from catalyst.python_compiler import register_pass

Check notice on line 24 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L24

Import outside toplevel (catalyst.python_compiler.register_pass) (import-outside-toplevel)
register_pass(_klass.name, lambda : _klass())

Check notice on line 25 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L25

Lambda may not be necessary (unnecessary-lambda)

return transform


from xdsl.rewriter import InsertPoint

Check notice on line 30 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L30

Unused InsertPoint imported from xdsl.rewriter (unused-import)

Check notice on line 30 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L30

Import "from xdsl.rewriter import InsertPoint" should be placed at the top of the module (wrong-import-position)
from xdsl import pattern_rewriter

Check notice on line 31 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L31

Import "from xdsl import pattern_rewriter" should be placed at the top of the module (wrong-import-position)

from catalyst.python_compiler import quantum

Check notice on line 33 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L33

Imports from package catalyst are not grouped (ungrouped-imports)

Check notice on line 33 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L33

Import "from catalyst.python_compiler import quantum" should be placed at the top of the module (wrong-import-position)


self_inverses = ("PauliZ", "PauliX", "PauliY", "Hadamard", "Identity")

Check notice on line 36 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L36

Missing function or method docstring (missing-function-docstring)

Check notice on line 36 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L36

Unused argument 'ctx' (unused-argument)


def cancel_ops(rewriter, op, next_op):

Check notice on line 39 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L39

Missing function or method docstring (missing-function-docstring)
rewriter.replace_all_uses_with(next_op.results[0], op.in_qubits[0])
rewriter.erase_op(next_op)
rewriter.erase_op(op)
owner = op.in_qubits[0].owner

if isinstance(owner, quantum.CustomOp) and owner.gate_name.data in self_inverses:

Check notice on line 45 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L45

Missing function or method docstring (missing-function-docstring)
next_user = None

for use in owner.results[0].uses:
user = use.operation
if isinstance(user, quantum.CustomOp) and user.gate_name.data == owner.gate_name.data:
next_user = user
break

if next_user is not None:
cancel_ops(rewriter, owner, next_user)


class DeepCancelInversesSingleQubitPattern(pattern_rewriter.RewritePattern):

Check notice on line 58 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L58

Missing class docstring (missing-class-docstring)
@pattern_rewriter.op_type_rewrite_pattern
def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the implementation a bit better by replacing recursion with a while loop. cancel_ops is not needed anymore:

    @pattern_rewriter.op_type_rewrite_pattern
    def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
        """Deep Cancel for Self Inverses"""
        for op in funcOp.body.walk():

            while isinstance(op, CustomOp) and op.gate_name.data in self_inverses:

                next_user = None
                for use in op.results[0].uses:
                    user = use.operation
                    if isinstance(user, CustomOp) and user.gate_name.data == op.gate_name.data:
                        next_user = user
                        break

                if next_user is None:
                    break

                rewriter._replace_all_uses_with(next_user.results[0], op.in_qubits[0])
                rewriter.erase_op(next_user)
                rewriter.erase_op(op)

                op = op.in_qubits[0].owner

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Deep Cancel for Self Inverses"""
print(funcOp)
for op in funcOp.body.walk():
if not isinstance(op, quantum.CustomOp):
continue

if op.gate_name.data not in self_inverses:
continue

next_user = None
for use in op.results[0].uses:
user = use.operation
if isinstance(user, quantum.CustomOp) and user.gate_name.data == op.gate_name.data:
next_user = user
break

if next_user is not None:
cancel_ops(rewriter, op, next_user)


@xdsl_transform
class DeepCancelInversesSingleQubitPass(passes.ModulePass):

Check notice on line 82 in a.py

View check run for this annotation

codefactor.io / CodeFactor

a.py#L82

Missing class docstring (missing-class-docstring)
name = "deep-cancel-inverses-single-qubit"

def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None:
pattern_rewriter.PatternRewriteWalker(
pattern_rewriter.GreedyRewritePatternApplier([DeepCancelInversesSingleQubitPattern()])
).rewrite_module(module)

qml.capture.enable()

@catalyst.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@DeepCancelInversesSingleQubitPass
@qml.qnode(qml.device("lightning.qubit", wires=1))
def captured_circuit(x: float):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.state()


qml.capture.disable()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note that if capture is disabled before executing, the tracing will happen without the plxpr pipeline. I'll share more details on Slack.



captured_circuit(0.0)
65 changes: 65 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,67 @@

return output_object_name, out_IR

@debug_logger
def is_using_python_compiler(self):

Check notice on line 471 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L471

Missing function or method docstring (missing-function-docstring)
xdsl_path = pathlib.Path("xdsl-does-not-use-a-real-path")
using_xdsl = False
if xdsl_path in self.options.pass_plugins:
plugins = self.options.pass_plugins
self.options.pass_plugins = tuple(elem for elem in plugins if elem != xdsl_path)
using_xdsl = True

if xdsl_path in self.options.dialect_plugins:
plugins = self.options.dialect_plugins
self.options.dialect_plugins = tuple(elem for elem in plugins if elem != xdsl_path)
using_xdsl = True

return using_xdsl

def run_python_compiler(self, mlir_module):

Check notice on line 486 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L486

Missing function or method docstring (missing-function-docstring)

# TODO: check that xdsl is available to be imported
import xdsl

Check notice on line 489 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L489

Import outside toplevel (xdsl) (import-outside-toplevel)
from xdsl.context import Context

Check notice on line 490 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L490

Import outside toplevel (xdsl.context.Context) (import-outside-toplevel)
from xdsl import passes

Check notice on line 491 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L491

Import outside toplevel (xdsl.passes) (import-outside-toplevel)
from xdsl.dialects import arith, builtin, func, scf, tensor, transform

Check notice on line 492 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L492

Import outside toplevel (xdsl.dialects.arith, xdsl.dialects.builtin, xdsl.dialects.func, xdsl.dialects.scf, xdsl.dialects.tensor, xdsl.dialects.transform) (import-outside-toplevel)
from .python_compiler import quantum, ApplyTransformSequence

Check notice on line 493 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L493

Import outside toplevel (python_compiler.quantum, python_compiler.ApplyTransformSequence) (import-outside-toplevel)
generic_assembly_format = mlir_module.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)

Check notice on line 494 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L494

Line too long (127/100) (line-too-long)
ctx = Context(allow_unregistered=True)
ctx.load_dialect(arith.Arith)
ctx.load_dialect(builtin.Builtin)
ctx.load_dialect(func.Func)
ctx.load_dialect(scf.Scf)
ctx.load_dialect(tensor.Tensor)
ctx.load_dialect(transform.Transform)
ctx.load_dialect(quantum.QuantumDialect)
# TODO: In order of importance
# TODO: Load gradient
# TODO: Load the stablehlo dialect? I kind of want to wait until all operations
# are properly represented. I am not sure how an incompletely known dialect would be handled
# but it is likely worse than an unregistered dialect.
# TODO: Load Catalyst
# TODO: Load ion/ppm/mbqc/zne...
module = xdsl.parser.Parser(ctx, generic_assembly_format).parse_module()
pipeline = passes.PipelinePass((ApplyTransformSequence(),))
pipeline.apply(ctx, module)

from jax._src.interpreters import mlir

Check notice on line 514 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L514

Import outside toplevel (jax._src.interpreters.mlir) (import-outside-toplevel)
from jaxlib.mlir.dialects import stablehlo

Check notice on line 515 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L515

Import outside toplevel (jaxlib.mlir.dialects.stablehlo) (import-outside-toplevel)
from jaxlib.mlir.ir import Context, Module

Check notice on line 516 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L516

Import outside toplevel (jaxlib.mlir.ir.Context, jaxlib.mlir.ir.Module) (import-outside-toplevel)
from xdsl.printer import Printer

Check notice on line 517 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L517

Import outside toplevel (xdsl.printer.Printer) (import-outside-toplevel)
import io

Check notice on line 518 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L518

Import outside toplevel (io) (import-outside-toplevel)

buffer = io.StringIO()
Printer(stream=buffer, print_generic_format=True).print(module)
with Context() as ctx:
ctx.allow_unregistered_dialects = True
ctx.append_dialect_registry(mlir.upstream_dialects)
stablehlo.register_dialect(ctx)
module = Module.parse(buffer.getvalue())
return module

# TODO: transform the program based on the transform dialect

@debug_logger
def run(self, mlir_module, *args, **kwargs):
"""Compile an MLIR module to a shared object.
Expand All @@ -483,6 +544,10 @@
(str): filename of shared object
"""

python_compiler = self.is_using_python_compiler()
if python_compiler:
mlir_module = self.run_python_compiler(mlir_module)

return self.run_from_ir(
mlir_module.operation.get_asm(
binary=False, print_generic_op_form=False, assume_verified=True
Expand Down
17 changes: 12 additions & 5 deletions frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,10 @@ def handle_qnode(
}


# This is our registration factory for PL transforms. The loop below iterates
# across the map above and generates a custom handler for each transform.
# In order to ensure early binding, we pass the PL plxpr transform and the
# Catalyst pass as arguments whose default values are set by the loop.
for pl_transform, (pass_name, decomposition) in transforms_to_passes.items():
# pylint: disable-next=redefined-outer-name
def register_transform(pl_transform, pass_name, decomposition):
"""Register pennylane transforms and their conversion to Catalyst transforms"""

# pylint: disable=unused-argument, too-many-arguments, cell-var-from-loop
@WorkflowInterpreter.register_primitive(pl_transform._primitive)
def handle_transform(
Expand Down Expand Up @@ -251,6 +250,14 @@ def wrapper(*args):
return self.eval(inner_jaxpr, consts, *non_const_args)


# This is our registration factory for PL transforms. The loop below iterates
# across the map above and generates a custom handler for each transform.
# In order to ensure early binding, we pass the PL plxpr transform and the
# Catalyst pass as arguments whose default values are set by the loop.
for pl_transform, (pass_name, decomposition) in transforms_to_passes.items():
register_transform(pl_transform, pass_name, decomposition)


class QFuncPlxprInterpreter(PlxprInterpreter):
"""An interpreter that converts plxpr into catalyst-variant jaxpr.

Expand Down
31 changes: 31 additions & 0 deletions frontend/catalyst/passes/xdsl_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""XDSL Plugin Interface"""

# This file contains what looks like a plugin
# but in reality it is just not a "real" MLIR plugin.
# It just follows the same convention to be able to add passes
# that are implemented in xDSL

from pathlib import Path


def getXDSLPluginAbsolutePath():
"""Returns a fake path"""
return Path("xdsl-does-not-use-a-real-path")


def name2pass(name):

Check notice on line 30 in frontend/catalyst/passes/xdsl_plugin/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes/xdsl_plugin/__init__.py#L30

Missing function or method docstring (missing-function-docstring)
return getXDSLPluginAbsolutePath(), name
120 changes: 120 additions & 0 deletions frontend/catalyst/python_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Callable

Check notice on line 1 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L1

Missing module docstring (missing-module-docstring)

from dataclasses import dataclass
from xdsl import context, passes
from xdsl.utils import parse_pipeline
from xdsl.dialects import builtin

from dataclasses import dataclass

from xdsl.dialects import builtin, transform
from xdsl.interpreter import Interpreter
from xdsl.interpreters.transform import TransformFunctions
from xdsl.passes import Context, ModulePass
from xdsl.utils.exceptions import PassFailedException

Check notice on line 15 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L15

Missing function or method docstring (missing-function-docstring)
from xdsl.interpreter import (
Interpreter,
InterpreterFunctions,
PythonValues,
ReturnedValues,
TerminatorValue,
impl,
impl_callable,
impl_terminator,
register_impls,
)

@register_impls
class TransformFunctionsExt(TransformFunctions):

Check notice on line 29 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L29

Missing class docstring (missing-class-docstring)
ctx: Context
passes: dict[str, Callable[[], type[ModulePass]]]

def __init__(
self, ctx: Context, available_passes: dict[str, Callable[[], type[ModulePass]]]
):
self.ctx = ctx
self.passes = available_passes

@impl(transform.ApplyRegisteredPassOp)
def run_apply_registered_pass_op(

Check notice on line 40 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L40

Missing function or method docstring (missing-function-docstring)
self,
interpreter: Interpreter,

Check notice on line 42 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L42

Unused argument 'interpreter' (unused-argument)
op: transform.ApplyRegisteredPassOp,
args: PythonValues,
) -> PythonValues:
pass_name = op.pass_name.data
requested_by_user = passes.PipelinePass.build_pipeline_tuples(
self.passes, parse_pipeline.parse_pipeline(pass_name)
)
# TODO: Switch between catalyst and xDSL

schedule = tuple(
pass_type.from_pass_spec(spec) for pass_type, spec in requested_by_user
)
pipeline = passes.PipelinePass(schedule)
pipeline.apply(self.ctx, args[0])
return (args[0],)

from xdsl.transforms import get_all_passes

Check notice on line 59 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L59

Import "from xdsl.transforms import get_all_passes" should be placed at the top of the module (wrong-import-position)

updated_passes = get_all_passes()

def register_pass(name, _callable):

Check notice on line 63 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L63

Missing function or method docstring (missing-function-docstring)
global updated_passes

Check notice on line 64 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L64

Using global for 'updated_passes' but no assignment is done (global-variable-not-assigned)
updated_passes[name] = _callable

@dataclass(frozen=True)
class TransformInterpreterPass(ModulePass):
"""Transform dialect interpreter"""

Check notice on line 70 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L70

Missing function or method docstring (missing-function-docstring)
name = "transform-interpreter"

entry_point: str = "__transform_main"

@staticmethod
def find_transform_entry_point(
root: builtin.ModuleOp, entry_point: str
) -> transform.NamedSequenceOp:
for op in root.walk():
if (
isinstance(op, transform.NamedSequenceOp)
and op.sym_name.data == entry_point
):

Check notice on line 83 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L83

Missing function or method docstring (missing-function-docstring)
return op
raise PassFailedException(
f"{root} could not find a nested named sequence with name: {entry_point}"
)

def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
schedule = TransformInterpreterPass.find_transform_entry_point(
op, self.entry_point
)

Check notice on line 92 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L92

Missing class docstring (missing-class-docstring)
interpreter = Interpreter(op)
global updated_passes

Check notice on line 94 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L94

Using global for 'updated_passes' but no assignment is done (global-variable-not-assigned)
interpreter.register_implementations(TransformFunctionsExt(ctx, updated_passes))
interpreter.call_op(schedule, (op,))

@dataclass(frozen=True)

Check notice on line 98 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L98

Bad indentation. Found 10 spaces, expected 12 (bad-indentation)
class ApplyTransformSequence(passes.ModulePass):

Check notice on line 99 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L99

Bad indentation. Found 14 spaces, expected 16 (bad-indentation)
name = "apply-transform-sequence"

Check notice on line 100 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L100

Bad indentation. Found 18 spaces, expected 20 (bad-indentation)

Check notice on line 101 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L101

Bad indentation. Found 22 spaces, expected 24 (bad-indentation)
def apply(self, ctx: context.MLContext, module: builtin.ModuleOp) -> None:
nested_modules = []
for region in module.regions:
for block in region.blocks:
for op in block.ops:
if isinstance(op, builtin.ModuleOp):
nested_modules.append(op)

pipeline = passes.PipelinePass((TransformInterpreterPass(),))
for op in nested_modules:

Check notice on line 111 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L111

Line too long (117/100) (line-too-long)
pipeline.apply(ctx, op)

Check notice on line 113 in frontend/catalyst/python_compiler/__init__.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/python_compiler/__init__.py#L113

Trailing newlines (trailing-newlines)
for op in nested_modules:
for region in op.regions:
for block in region.blocks:
for op in block.ops:
if isinstance(op, builtin.ModuleOp) and op.get_attr_or_prop("transform.with_named_sequence"):
block.erase_op(op)

Loading
Loading