Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3c35dc8
print-parse-print-parse
erick-xanadu May 1, 2025
6c2e364
starting point
erick-xanadu May 1, 2025
f5546cb
add initial impl
erick-xanadu May 1, 2025
9dec689
add register pass
erick-xanadu May 2, 2025
0cc27bc
Use context instead of MLContext
erick-xanadu May 2, 2025
d487995
Apply suggestions from code review
erick-xanadu May 13, 2025
074b564
style
erick-xanadu May 14, 2025
5a688db
style
erick-xanadu May 14, 2025
ac551bd
no frozen
erick-xanadu May 14, 2025
e6fbbc0
add stablehlo
erick-xanadu May 14, 2025
ddd8363
changelog
erick-xanadu May 14, 2025
0960cec
Merge branch 'master' into eochoa/2025-05-01/python-compiler
erick-xanadu May 14, 2025
852f379
nocover
erick-xanadu May 14, 2025
b834451
style
erick-xanadu May 14, 2025
59ce1c6
f
erick-xanadu May 14, 2025
adb0b52
too-few-public-methods
erick-xanadu May 14, 2025
0264101
exclude file
erick-xanadu May 14, 2025
fc5fe01
changelog
erick-xanadu May 14, 2025
3b275f5
Add pragma: exclude file
erick-xanadu May 14, 2025
05b4366
jax utils
erick-xanadu May 15, 2025
b48aaef
tests
erick-xanadu May 15, 2025
ae3a76a
Merge branch 'master' into eochoa/2025-05-01/python-compiler
erick-xanadu May 15, 2025
5ba37dd
Revert "Add pragma: exclude file"
erick-xanadu May 15, 2025
8edbd93
codecov ignore for now
erick-xanadu May 15, 2025
37f63c7
skip import
erick-xanadu May 15, 2025
84dee73
import skip jax
erick-xanadu May 15, 2025
3ef0ec3
isinstance fix
erick-xanadu May 15, 2025
98a81d4
importor
erick-xanadu May 15, 2025
3f57c75
rename file
erick-xanadu May 15, 2025
650b6a3
more coverage
erick-xanadu May 15, 2025
d319fb9
more coverage
erick-xanadu May 15, 2025
21c81ea
no cover
erick-xanadu May 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@
decomposition [Theorem 7 in Shende et al.](https://arxiv.org/abs/quant-ph/0406176)
that contains fewer gates than the previous decomposition.
[(#7370)](https://github.com/PennyLaneAI/pennylane/pull/7370)

* An experimental integration for a Python compiler using [xDSL](https://xdsl.dev/index) has been introduced.
This is similar to [Catalyst's MLIR dialects](https://docs.pennylane.ai/projects/catalyst/en/stable/dev/dialects.html#mlir-dialects-in-catalyst),
but it is coded in Python instead of C++.
[(#7357)](https://github.com/PennyLaneAI/pennylane/pull/7357)
[(#7367)](https://github.com/PennyLaneAI/pennylane/pull/7367)

* PennyLane supports `JAX` version 0.6.0.
[(#7299)](https://github.com/PennyLaneAI/pennylane/pull/7299)
Expand All @@ -166,11 +172,6 @@
and :class:`~.SelectPauliRot`, now takes much less computational effort and memory.
[(#7377)](https://github.com/PennyLaneAI/pennylane/pull/7377)

* An experimental quantum dialect written in [xDSL](https://xdsl.dev/index) has been introduced.
This is similar to [Catalyst's MLIR dialects](https://docs.pennylane.ai/projects/catalyst/en/stable/dev/dialects.html#mlir-dialects-in-catalyst),
but it is coded in Python instead of C++.
[(#7357)](https://github.com/PennyLaneAI/pennylane/pull/7357)

* The :func:`~.transforms.cancel_inverses` transform no longer changes the order of operations that don't have shared wires, providing a deterministic output.
[(#7328)](https://github.com/PennyLaneAI/pennylane/pull/7328)

Expand Down
Empty file.
74 changes: 74 additions & 0 deletions pennylane/compiler/python_compiler/impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the implementation of the PennyLane-xDSL integration API."""


import io

from jax._src.interpreters import mlir
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.ir import Context as jaxContext # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Module as jaxModule # pylint: disable=no-name-in-module
from xdsl.context import Context as xdslContext
from xdsl.dialects import arith, builtin, func, scf
from xdsl.dialects import stablehlo as xdslstablehlo
from xdsl.dialects import tensor, transform
from xdsl.parser import Parser
from xdsl.passes import PipelinePass
from xdsl.printer import Printer

from .quantum_dialect import QuantumDialect as Quantum
from .transforms import ApplyTransformSequence


# pylint: disable=too-few-public-methods
class Compiler:
"""Compiler namespace"""

@staticmethod
def run(jmod: jaxModule) -> jaxModule:
"""Runs the apply-transform-sequence pass.

The apply-transform-sequence pass is a "meta-pass". In other words,
it is a pass that runs other passes.
"""

gentxtmod: str = jmod.operation.get_asm(
binary=False, print_generic_op_form=True, assume_verified=True
)

ctx = xdslContext(allow_unregistered=True)
ctx.load_dialect(arith.Arith)
ctx.load_dialect(builtin.Builtin)
ctx.load_dialect(func.Func)
ctx.load_dialect(scf.Scf)
ctx.load_dialect(xdslstablehlo.StableHLO)
ctx.load_dialect(tensor.Tensor)
ctx.load_dialect(transform.Transform)
ctx.load_dialect(Quantum)

xmod: builtin.ModuleOp = Parser(ctx, gentxtmod).parse_module()
pipeline = PipelinePass((ApplyTransformSequence(),))
# xmod is modified in place
pipeline.apply(ctx, xmod)

buffer = io.StringIO()
Printer(stream=buffer, print_generic_format=True).print(xmod)
with jaxContext() as ctx:
ctx.allow_unregistered_dialects = True
ctx.append_dialect_registry(mlir.upstream_dialects)
stablehlo.register_dialect(ctx) # pylint: disable=no-member
newmod: jaxModule = jaxModule.parse(buffer.getvalue())

return newmod
139 changes: 139 additions & 0 deletions pennylane/compiler/python_compiler/jax_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for translating JAX to xDSL"""

from functools import wraps
from typing import Any, Callable, TypeAlias

import jax
import jaxlib


from jax.extend import mlir as jmlir # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Context as jContext # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Module as jModule # pylint: disable=no-name-in-module
from jaxlib.mlir.dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module

from xdsl.dialects import arith as xarith
from xdsl.dialects import builtin as xbuiltin
from xdsl.dialects import func as xfunc
from xdsl.dialects import scf as xscf
from xdsl.dialects import stablehlo as xstablehlo
from xdsl.dialects import tensor as xtensor
from xdsl.dialects import transform as xtransform

from xdsl.parser import Parser as xParser
from xdsl.context import Context as xContext

JaxJittedFunction: TypeAlias = jaxlib.xla_extension.PjitFunction


def _module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:
"""Get the module from the jax.jitted function"""
return func.lower(*args, **kwargs).compiler_ir()


def module(func: JaxJittedFunction) -> Callable[Any, jModule]:
"""
Decorator for _module_inline
"""

@wraps(func)
def wrapper(*args, **kwargs) -> jModule:
return _module_inline(func, *args, **kwargs)

return wrapper


def _generic_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover
"""
Create the generic textual representation for the jax.jit'ed function
"""
lowered = func.lower(*args, **kwargs)
mod = lowered.compiler_ir()
return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)


def generic(func: JaxJittedFunction) -> Callable[Any, str]: # pragma: no cover
"""
Decorator for _generic_inline.
"""

@wraps(func)
def wrapper(*args, **kwargs) -> str:
return _generic_inline(func, *args, **kwargs)

return wrapper


def parse_generic_to_xdsl_module(program: str) -> xbuiltin.ModuleOp: # pragma: no cover
"""Parses generic MLIR program to xDSL module"""
ctx = xContext(allow_unregistered=True)
ctx.load_dialect(xarith.Arith)
ctx.load_dialect(xbuiltin.Builtin)
ctx.load_dialect(xfunc.Func)
ctx.load_dialect(xscf.Scf)
ctx.load_dialect(xstablehlo.StableHLO)
ctx.load_dialect(xtensor.Tensor)
ctx.load_dialect(xtransform.Transform)
moduleOp: xbuiltin.ModuleOp = xParser(ctx, program).parse_module()
return moduleOp


def parse_generic_to_jax_module(program: str) -> jModule: # pragma: no cover
"""Parses an MLIR program in string representation to a jax Module"""
with jContext() as ctx:
ctx.allow_unregistered_dialects = True
jstablehlo.register_dialect(ctx) # pylint: disable=no-member
return jModule.parse(program)


def jax_from_docstring(func: Callable) -> jModule: # pragma: no cover
"""Parses an MLIR program in string representation located in the docstring."""

@wraps(func)
def wrapper(*_, **__):
return parse_generic_to_jax_module(func.__doc__)

return wrapper


def _xdsl_module_inline(
func: JaxJittedFunction, *args, **kwargs
) -> xbuiltin.ModuleOp: # pragma: no cover
generic_repr = _generic_inline(func, *args, **kwargs)
return parse_generic_to_xdsl_module(generic_repr)


def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover
"""Parses a docstring into an xdsl module"""

@wraps(func)
def wrapper(*_, **__):
return parse_generic_to_xdsl_module(func.__doc__)

return wrapper


def xdsl_module(func: JaxJittedFunction) -> Callable[Any, xbuiltin.ModuleOp]: # pragma: no cover
"""
Decorator for _xdsl_module_inline
"""

@wraps(func)
def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp:
return _xdsl_module_inline(func, *args, **kwargs)

return wrapper
24 changes: 24 additions & 0 deletions pennylane/compiler/python_compiler/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PennyLane-xDSL transformations API."""

from xdsl.transforms.transform_interpreter import TransformInterpreterPass
from .apply_transform_sequence import ApplyTransformSequence, register_pass


__all__ = [
"ApplyTransformSequence",
"TransformInterpreterPass",
"register_pass",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the pass that applies all passes present in the program representation."""


from dataclasses import dataclass

from xdsl.context import Context
from xdsl.dialects import builtin
from xdsl.passes import ModulePass, PipelinePass

from .transform_interpreter import TransformInterpreterPass # pylint: disable=no-name-in-module

available_passes = {}


def register_pass(name, _callable):
"""Registers the passes available in the dictionary"""
available_passes[name] = _callable # pragma: no cover


# pylint: disable=too-few-public-methods
@dataclass(frozen=True)
class ApplyTransformSequence(ModulePass):
"""
Looks for nested modules. Nested modules in this context are guaranteed to correspond
to qnodes. These modules are already annotated with which passes are to be executed.
The pass ApplyTransformSequence will run passes annotated in the qnode modules.

At the end, we delete the list of passes as they have already been applied.
"""

name = "apply-transform-sequence"

def apply( # pylint: disable=arguments-renamed,no-self-use
self, ctx: Context, module: builtin.ModuleOp
) -> None:
"""Applies the transformation"""
nested_modules = []
for region in module.regions:
for block in region.blocks:
for op in block.ops:
if isinstance(op, builtin.ModuleOp):
nested_modules.append(op)

pipeline = PipelinePass(
# pylint: disable-next=unexpected-keyword-arg
(TransformInterpreterPass(passes=available_passes),)
)
for op in nested_modules:
pipeline.apply(ctx, op)

for mod in nested_modules:
for region in mod.regions:
for block in region.blocks:
for op in block.ops:
if isinstance(op, builtin.ModuleOp) and op.get_attr_or_prop(
"transform.with_named_sequence"
):
block.erase_op(op) # pragma: no cover
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API"""

from .transform_interpreter_catalyst import TransformInterpreterPass

__all__ = ["TransformInterpreterPass"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Public API"""

from .impl import TransformFunctionsExt

__all__ = ["TransformFunctionsExt"]
Loading