Skip to content

[Decomposition] Custom decomposition rules for symbolic operators #7347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
925f8a2
[Decomposition] Custom decomposition rules for symbolic operators
astralcai Apr 29, 2025
fa09eac
fix doc?
astralcai Apr 29, 2025
95c0383
tests and bug fix
astralcai Apr 29, 2025
8f0a0de
update changelog
astralcai Apr 29, 2025
5b074ed
add more tests
astralcai Apr 29, 2025
ff90882
one more test case
astralcai Apr 29, 2025
9a57289
Merge branch 'master' into symbolic-rules-01
astralcai Apr 29, 2025
956b9a2
pylint
astralcai Apr 29, 2025
a8ccb91
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Apr 29, 2025
32dba0b
pylint
astralcai Apr 29, 2025
edff8d4
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Apr 29, 2025
26a314d
Apply suggestions from code review
astralcai May 1, 2025
02f08f2
Apply suggestions from code review
astralcai May 1, 2025
fa54ba2
lol
astralcai May 1, 2025
b6d6dbe
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai May 2, 2025
a85a40f
minor fix
astralcai May 2, 2025
c18b60d
Merge branch 'master' into symbolic-rules-01
astralcai May 2, 2025
5db0fc2
get rid of funny business
astralcai May 5, 2025
28e0f0d
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai May 5, 2025
8bd5e8f
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai May 5, 2025
74604ab
minor rename
astralcai May 7, 2025
220f671
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai May 7, 2025
d12873d
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,32 +91,66 @@
qml.RZ(omega, wires=wires[0])
qml.GlobalPhase(-phase)

# This decomposition will be ignored for `QubitUnitary` on more than one wire.
qml.add_decomps(QubitUnitary, zyz_decomposition)
```

This decomposition will be ignored for `QubitUnitary` on more than one wire.

* The :func:`~.transforms.decompose` transform now supports symbolic operators (e.g., `Adjoint` and `Controlled`) specified as strings in the `gate_set` argument
when the new graph-based decomposition system is enabled.
[(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331)

```python
from functools import partial
import pennylane as qml
* Symbolic operator types (e.g., `Adjoint`, `Controlled`, and `Pow`) can now be specified as strings
in various parts of the new graph-based decomposition system, specifically:
* The `gate_set` argument of the :func:`~.transforms.decompose` transform now supports adding symbolic
operators in the target gate set.
[(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331)
```python
from functools import partial
import pennylane as qml

qml.decomposition.enable_graph()
qml.decomposition.enable_graph()

@partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"})
@qml.qnode(qml.device("default.qubit"))
def circuit():
qml.Toffoli(wires=[0, 1, 2])
```
```pycon
>>> print(qml.draw(circuit)())
0: ───────────╭●───────────╭●────╭●──T──╭●─┤
1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤
2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤
```
@partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"})
@qml.qnode(qml.device("default.qubit"))
def circuit():
qml.Toffoli(wires=[0, 1, 2])
```
```pycon
>>> print(qml.draw(circuit)())
0: ───────────╭●───────────╭●────╭●──T──╭●─┤
1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤
2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤
```
* Symbolic operator types can now be given as strings to the `op_type` argument of :func:`~.decomposition.add_decomps`,
or as keys of the dictionaries passed to the `alt_decomps` and `fixed_decomps` arguments of the
:func:`~.transforms.decompose` transform, allowing custom decomposition rules to be defined and
registered for symbolic operators.
[(#7347)](https://github.com/PennyLaneAI/pennylane/pull/7347)
```python
@qml.register_resources({qml.RY: 1})
def my_adjoint_ry(phi, wires, **_):
qml.RY(-phi, wires=wires)

@qml.register_resources({qml.RX: 1})
def my_adjoint_rx(phi, wires, **__):
qml.RX(-phi, wires)

# Registers a decomposition rule for the adjoint of RY globally
qml.add_decomps("Adjoint(RY)", my_adjoint_ry)

@partial(
qml.transforms.decompose,
gate_set={"RX", "RY", "CNOT"},
fixed_decomps={"Adjoint(RX)": my_adjoint_rx}
)
@qml.qnode(qml.device("default.qubit"))
def circuit():
qml.adjoint(qml.RX(0.5, wires=[0]))
qml.CNOT(wires=[0, 1])
qml.adjoint(qml.RY(0.5, wires=[1]))
return qml.expval(qml.Z(0))
```
```pycon
>>> print(qml.draw(circuit)())
0: ──RX(-0.50)─╭●────────────┤ <Z>
1: ────────────╰X──RY(-0.50)─┤
```

<h3>Improvements 🛠</h3>

Expand Down
141 changes: 69 additions & 72 deletions pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@
from .resources import CompressedResourceOp, Resources, resource_rep
from .symbolic_decomposition import (
AdjointDecomp,
adjoint_adjoint_decomp,
adjoint_controlled_decomp,
adjoint_pow_decomp,
pow_decomp,
pow_pow_decomp,
cancel_adjoint,
merge_powers,
repeat_pow_base,
same_type_adjoint_decomp,
same_type_adjoint_ops,
)
Expand Down Expand Up @@ -136,25 +136,42 @@ def __init__(

# Tracks the node indices of various operators.
self._original_ops_indices: set[int] = set()
self._target_ops_indices: set[int] = set()
self._all_op_indices: dict[CompressedResourceOp, int] = {}

# Stores the library of custom decomposition rules
self._fixed_decomps = fixed_decomps or {}
self._alt_decomps = alt_decomps or {}
fixed_decomps = fixed_decomps or {}
alt_decomps = alt_decomps or {}
self._fixed_decomps = {_to_name(k): v for k, v in fixed_decomps.items()}
self._alt_decomps = {_to_name(k): v for k, v in alt_decomps.items()}

# Initializes the graph.
self._graph = rx.PyDiGraph()
self._visitor = None

# Construct the decomposition graph
self._start = self._graph.add_node(None)
self._construct_graph(operations)

def _get_decompositions(self, op_type) -> list[DecompositionRule]:
def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Helper function to get a list of decomposition rules."""
if op_type in self._fixed_decomps:
return [self._fixed_decomps[op_type]]
return self._alt_decomps.get(op_type, []) + list_decomps(op_type)

op_name = _to_name(op_node)

if op_name in self._fixed_decomps:
return [self._fixed_decomps[op_name]]

decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name)

if issubclass(op_node.op_type, qml.ops.Adjoint):
decomps.extend(self._get_adjoint_decompositions(op_node))

elif issubclass(op_node.op_type, qml.ops.Pow):
decomps.extend(self._get_pow_decompositions(op_node))

elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
decomps.extend(self._get_controlled_decompositions(op_node))

return decomps

def _construct_graph(self, operations):
"""Constructs the decomposition graph."""
Expand All @@ -179,20 +196,10 @@ def _recursively_add_op_node(self, op_node: CompressedResourceOp) -> int:
self._all_op_indices[op_node] = op_node_idx

if op_node.name in self._gate_set:
self._target_ops_indices.add(op_node_idx)
self._graph.add_edge(self._start, op_node_idx, 1)
return op_node_idx

if op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
# This branch only applies to general controlled operators
return self._add_controlled_decomp_node(op_node, op_node_idx)

if issubclass(op_node.op_type, qml.ops.Adjoint):
return self._add_adjoint_decomp_node(op_node, op_node_idx)

if issubclass(op_node.op_type, qml.ops.Pow):
return self._add_pow_decomp_node(op_node, op_node_idx)

for decomposition in self._get_decompositions(op_node.op_type):
for decomposition in self._get_decompositions(op_node):
self._add_decomp_rule_to_op(decomposition, op_node, op_node_idx)

return op_node_idx
Expand All @@ -208,88 +215,67 @@ def _add_decomp_rule_to_op(
except DecompositionNotApplicable:
pass # ignore decompositions that are not applicable to the given op params.

def _add_adjoint_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int:
"""Adds an adjoint decomposition node."""
def _get_adjoint_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Retrieves a list of decomposition rules for an adjoint operator."""

base_class, base_params = op_node.params["base_class"], op_node.params["base_params"]
base_class, base_params = (op_node.params["base_class"], op_node.params["base_params"])

if issubclass(base_class, qml.ops.Adjoint):
rule = adjoint_adjoint_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [cancel_adjoint]

if (
issubclass(base_class, qml.ops.Pow)
and base_params["base_class"] in same_type_adjoint_ops()
):
rule = adjoint_pow_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [adjoint_pow_decomp]

if base_class in same_type_adjoint_ops():
rule = same_type_adjoint_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [same_type_adjoint_decomp]

if (
issubclass(base_class, qml.ops.Controlled)
and base_params["base_class"] in same_type_adjoint_ops()
):
rule = adjoint_controlled_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [adjoint_controlled_decomp]

for base_decomposition in self._get_decompositions(base_class):
rule = AdjointDecomp(base_decomposition)
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)

return op_node_idx
base_rep = resource_rep(base_class, **base_params)
return [AdjointDecomp(base_rule) for base_rule in self._get_decompositions(base_rep)]

def _add_pow_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int:
"""Adds a power decomposition node to the graph."""
@staticmethod
def _get_pow_decompositions(op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Retrieves a list of decomposition rules for a power operator."""

base_class = op_node.params["base_class"]

if issubclass(base_class, qml.ops.Pow):
rule = pow_pow_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [merge_powers]

rule = pow_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [repeat_pow_base]

def _add_controlled_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int:
def _get_controlled_decompositions(
self, op_node: CompressedResourceOp
) -> list[DecompositionRule]:
"""Adds a controlled decomposition node to the graph."""

base_class = op_node.params["base_class"]
num_control_wires = op_node.params["num_control_wires"]

# Handle controlled global phase
if base_class is qml.GlobalPhase:
rule = controlled_global_phase_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [controlled_global_phase_decomp]

# Handle controlled-X gates
if base_class is qml.X:
rule = controlled_x_decomp
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [controlled_x_decomp]

# Handle custom controlled ops
if (base_class, num_control_wires) in base_to_custom_ctrl_op():
custom_op_type = base_to_custom_ctrl_op()[(base_class, num_control_wires)]
rule = CustomControlledDecomposition(custom_op_type)
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)
return op_node_idx
return [CustomControlledDecomposition(custom_op_type)]

# General case
for base_decomposition in self._get_decompositions(base_class):
rule = ControlledBaseDecomposition(base_decomposition)
self._add_decomp_rule_to_op(rule, op_node, op_node_idx)

return op_node_idx
base_rep = resource_rep(base_class, **op_node.params["base_params"])
return [ControlledBaseDecomposition(rule) for rule in self._get_decompositions(base_rep)]

def _recursively_add_decomposition_node(
self, rule: DecompositionRule, decomp_resource: Resources
Expand All @@ -303,6 +289,11 @@ def _recursively_add_decomposition_node(

d_node = _DecompositionNode(rule, decomp_resource)
d_node_idx = self._graph.add_node(d_node)
if not decomp_resource.gate_counts:
# If an operator decomposes to nothing (e.g., a Hadamard raised to a
# power of 2), we must still connect something to this decomposition
# node so that it is accounted for.
self._graph.add_edge(self._start, d_node_idx, 0)
for op in decomp_resource.gate_counts:
op_node_idx = self._recursively_add_op_node(op)
self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx))
Expand All @@ -318,17 +309,12 @@ def solve(self, lazy=True):

"""
self._visitor = _DecompositionSearchVisitor(self._graph, self._original_ops_indices, lazy)
start = self._graph.add_node("dummy")
self._graph.add_edges_from(
[(start, op_node_idx, 1) for op_node_idx in self._target_ops_indices]
)
rx.dijkstra_search(
self._graph,
source=[start],
source=[self._start],
weight_fn=self._visitor.edge_weight,
visitor=self._visitor,
)
self._graph.remove_node(start)
if self._visitor.unsolved_op_indices:
unsolved_ops = [self._graph[op_idx] for op_idx in self._visitor.unsolved_op_indices]
op_names = set(op.name for op in unsolved_ops)
Expand Down Expand Up @@ -466,6 +452,8 @@ def examine_edge(self, edge):
return # nothing is to be done for edges leading to an operator node
if target_idx not in self.distances:
self.distances[target_idx] = Resources() # initialize with empty resource
if src_node is None:
return # special case for when the decomposition produces nothing
self.distances[target_idx] += self.distances[src_idx] * target_node.count(src_node)
if target_idx not in self._num_edges_examined:
self._num_edges_examined[target_idx] = 0
Expand All @@ -481,7 +469,7 @@ def edge_relaxed(self, edge):
"""Triggered when an edge is relaxed during the Dijkstra search."""
src_idx, target_idx, _ = edge
target_node = self._graph[target_idx]
if self._graph[src_idx] == "dummy":
if self._graph[src_idx] is None and not isinstance(target_node, _DecompositionNode):
self.distances[target_idx] = Resources({target_node: 1})
elif isinstance(target_node, CompressedResourceOp):
self.predecessors[target_idx] = src_idx
Expand All @@ -498,3 +486,12 @@ class _DecompositionNode:
def count(self, op: CompressedResourceOp):
"""Find the number of occurrences of an operator in the decomposition."""
return self.decomp_resource.gate_counts.get(op, 0)


def _to_name(op):
if isinstance(op, type):
return op.__name__
if isinstance(op, CompressedResourceOp):
return op.name
assert isinstance(op, str)
return translate_op_alias(op)
Loading