From 925f8a27f69c5bcae1d2163400eeb17431a4978a Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 10:42:38 -0400 Subject: [PATCH 01/14] [Decomposition] Custom decomposition rules for symbolic operators --- doc/releases/changelog-dev.md | 75 ++++++++--- .../decomposition/decomposition_graph.py | 125 +++++++++--------- pennylane/decomposition/decomposition_rule.py | 42 ++++-- .../decomposition/symbolic_decomposition.py | 6 +- pennylane/decomposition/utils.py | 2 +- tests/decomposition/conftest.py | 22 +-- tests/decomposition/test_decomp_utils.py | 1 + .../decomposition/test_decomposition_graph.py | 14 +- .../decomposition/test_decomposition_rule.py | 2 +- .../test_symbolic_decomposition.py | 28 ++-- 10 files changed, 183 insertions(+), 134 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3e5e44d6d36..59e7d50e3ee 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -92,32 +92,65 @@ qml.RZ(omega, wires=wires[0]) qml.GlobalPhase(-phase) + # This decomposition will be ignored for `QubitUnitary` on more than one wire. qml.add_decomps(QubitUnitary, zyz_decomposition) ``` - - This decomposition will be ignored for `QubitUnitary` on more than one wire. - -* The :func:`~.transforms.decompose` transform now supports symbolic operators (e.g., `Adjoint` and `Controlled`) specified as strings in the `gate_set` argument - when the new graph-based decomposition system is enabled. - [(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331) - ```python - from functools import partial - import pennylane as qml +* Symbolic operator types (e.g., `Adjoint`, `Controlled`, and `Pow`) can now be specified as strings + in various parts of the new graph-based decomposition system, specifically: + * The `gate_set` argument of the :func:`~.transforms.decompose` transform now supports adding symbolic + operators to the target gate set. + [(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331) + ```python + from functools import partial + import pennylane as qml - qml.decomposition.enable_graph() + qml.decomposition.enable_graph() - @partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"}) - @qml.qnode(qml.device("default.qubit")) - def circuit(): - qml.Toffoli(wires=[0, 1, 2]) - ``` - ```pycon - >>> print(qml.draw(circuit)()) - 0: ───────────╭●───────────╭●────╭●──T──╭●─┤ - 1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤ - 2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤ - ``` + @partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"}) + @qml.qnode(qml.device("default.qubit")) + def circuit(): + qml.Toffoli(wires=[0, 1, 2]) + ``` + ```pycon + >>> print(qml.draw(circuit)()) + 0: ───────────╭●───────────╭●────╭●──T──╭●─┤ + 1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤ + 2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤ + ``` + * Symbolic operator types can now be given as strings to the `op_type` argument of :func:`~.add_decomps`, + or as keys of the dictionaries passed to the `alt_decomps` and `fixed_decomps` arguments of the + :func:`~.transforms.decompose` transform, allowing custom decomposition rules to be defined and + registered for symbolic operators. + ```python + @register_resources({qml.RY: 1}) + def my_adjoint_ry(phi, wires, **_): + qml.RY(-phi, wires=wires) + + @qml.register_resources({qml.RX: 1}) + def my_adjoint_rx(phi, wires, **__): + qml.RX(-phi, wires) + + # Registers a decomposition rule for the adjoint of RY globally + qml.add_decomps("Adjoint(RY)", my_adjoint_ry) + + @partial( + qml.transforms.decompose, + gate_set={"RX", "CNOT"}, + fixed_decomps={"Adjoint(RX)": my_adjoint_rx} + ) + @qml.qnode(qml.device("default.qubit")) + def circuit(): + qml.adjoint(qml.RX(0.5), wires=[0]) + qml.CNOT(wires=[0, 1]) + qml.adjoint(qml.RY(0.5), wires=[1]) + return qml.expval(qml.Z(0)) + ``` + ```pycon + >>> print(qml.draw(circuit)()) + 0: ──RX(-0.50)─╭●────────────┤ + 1: ────────────╰X──RY(-0.50)─┤ + ```

Improvements 🛠

diff --git a/pennylane/decomposition/decomposition_graph.py b/pennylane/decomposition/decomposition_graph.py index cdbbba327f7..2a28b40d804 100644 --- a/pennylane/decomposition/decomposition_graph.py +++ b/pennylane/decomposition/decomposition_graph.py @@ -44,11 +44,11 @@ from .resources import CompressedResourceOp, Resources, resource_rep from .symbolic_decomposition import ( AdjointDecomp, - adjoint_adjoint_decomp, adjoint_controlled_decomp, adjoint_pow_decomp, - pow_decomp, - pow_pow_decomp, + cancel_adjoint, + merge_powers, + repeat_pow_base, same_type_adjoint_decomp, same_type_adjoint_ops, ) @@ -140,8 +140,10 @@ def __init__( self._all_op_indices: dict[CompressedResourceOp, int] = {} # Stores the library of custom decomposition rules - self._fixed_decomps = fixed_decomps or {} - self._alt_decomps = alt_decomps or {} + fixed_decomps = fixed_decomps or {} + alt_decomps = alt_decomps or {} + self._fixed_decomps = {_to_name(k): v for k, v in fixed_decomps.items()} + self._alt_decomps = {_to_name(k): v for k, v in alt_decomps.items()} # Initializes the graph. self._graph = rx.PyDiGraph() @@ -150,11 +152,26 @@ def __init__( # Construct the decomposition graph self._construct_graph(operations) - def _get_decompositions(self, op_type) -> list[DecompositionRule]: + def _get_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: """Helper function to get a list of decomposition rules.""" - if op_type in self._fixed_decomps: - return [self._fixed_decomps[op_type]] - return self._alt_decomps.get(op_type, []) + list_decomps(op_type) + + op_name = _to_name(op) + + if op_name in self._fixed_decomps: + return [self._fixed_decomps[op_name]] + + decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name) + + if issubclass(op.op_type, qml.ops.Adjoint): + decomps.extend(self._get_adjoint_decompositions(op)) + + elif issubclass(op.op_type, qml.ops.Pow): + decomps.extend(self._get_pow_decompositions(op)) + + elif op.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): + decomps.extend(self._get_controlled_decompositions(op)) + + return decomps def _construct_graph(self, operations): """Constructs the decomposition graph.""" @@ -182,17 +199,7 @@ def _recursively_add_op_node(self, op_node: CompressedResourceOp) -> int: self._target_ops_indices.add(op_node_idx) return op_node_idx - if op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): - # This branch only applies to general controlled operators - return self._add_controlled_decomp_node(op_node, op_node_idx) - - if issubclass(op_node.op_type, qml.ops.Adjoint): - return self._add_adjoint_decomp_node(op_node, op_node_idx) - - if issubclass(op_node.op_type, qml.ops.Pow): - return self._add_pow_decomp_node(op_node, op_node_idx) - - for decomposition in self._get_decompositions(op_node.op_type): + for decomposition in self._get_decompositions(op_node): self._add_decomp_rule_to_op(decomposition, op_node, op_node_idx) return op_node_idx @@ -208,88 +215,65 @@ def _add_decomp_rule_to_op( except DecompositionNotApplicable: pass # ignore decompositions that are not applicable to the given op params. - def _add_adjoint_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: - """Adds an adjoint decomposition node.""" + def _get_adjoint_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: + """Retrieves a list of decomposition rules for an adjoint operator.""" - base_class, base_params = op_node.params["base_class"], op_node.params["base_params"] + base_class, base_params = op.params["base_class"], op.params["base_params"] if issubclass(base_class, qml.ops.Adjoint): - rule = adjoint_adjoint_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [cancel_adjoint] if ( issubclass(base_class, qml.ops.Pow) and base_params["base_class"] in same_type_adjoint_ops() ): - rule = adjoint_pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [adjoint_pow_decomp] if base_class in same_type_adjoint_ops(): - rule = same_type_adjoint_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [same_type_adjoint_decomp] if ( issubclass(base_class, qml.ops.Controlled) and base_params["base_class"] in same_type_adjoint_ops() ): - rule = adjoint_controlled_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx - - for base_decomposition in self._get_decompositions(base_class): - rule = AdjointDecomp(base_decomposition) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) + return [adjoint_controlled_decomp] - return op_node_idx + base_rep = resource_rep(base_class, **base_params) + return [AdjointDecomp(base_rule) for base_rule in self._get_decompositions(base_rep)] - def _add_pow_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: - """Adds a power decomposition node to the graph.""" + @staticmethod + def _get_pow_decompositions(op: CompressedResourceOp) -> list[DecompositionRule]: + """Retrieves a list of decomposition rules for a power operator.""" - base_class = op_node.params["base_class"] + base_class = op.params["base_class"] if issubclass(base_class, qml.ops.Pow): - rule = pow_pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [merge_powers] - rule = pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [repeat_pow_base] - def _add_controlled_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: + def _get_controlled_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: """Adds a controlled decomposition node to the graph.""" - base_class = op_node.params["base_class"] - num_control_wires = op_node.params["num_control_wires"] + base_class = op.params["base_class"] + num_control_wires = op.params["num_control_wires"] # Handle controlled global phase if base_class is qml.GlobalPhase: - rule = controlled_global_phase_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [controlled_global_phase_decomp] # Handle controlled-X gates if base_class is qml.X: - rule = controlled_x_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [controlled_x_decomp] # Handle custom controlled ops if (base_class, num_control_wires) in base_to_custom_ctrl_op(): custom_op_type = base_to_custom_ctrl_op()[(base_class, num_control_wires)] - rule = CustomControlledDecomposition(custom_op_type) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [CustomControlledDecomposition(custom_op_type)] # General case - for base_decomposition in self._get_decompositions(base_class): - rule = ControlledBaseDecomposition(base_decomposition) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - - return op_node_idx + base_rep = resource_rep(base_class, **op.params["base_params"]) + return [ControlledBaseDecomposition(rule) for rule in self._get_decompositions(base_rep)] def _recursively_add_decomposition_node( self, rule: DecompositionRule, decomp_resource: Resources @@ -498,3 +482,12 @@ class _DecompositionNode: def count(self, op: CompressedResourceOp): """Find the number of occurrences of an operator in the decomposition.""" return self.decomp_resource.gate_counts.get(op, 0) + + +def _to_name(op): + if isinstance(op, type): + return op.__name__ + if isinstance(op, CompressedResourceOp): + return op.name + assert isinstance(op, str) + return translate_op_alias(op) diff --git a/pennylane/decomposition/decomposition_rule.py b/pennylane/decomposition/decomposition_rule.py index 9d1f79b24a3..9625c4973de 100644 --- a/pennylane/decomposition/decomposition_rule.py +++ b/pennylane/decomposition/decomposition_rule.py @@ -24,6 +24,7 @@ from pennylane.operation import Operator from .resources import CompressedResourceOp, Resources, resource_rep +from .utils import translate_op_alias @overload @@ -191,7 +192,7 @@ def __init__(self, func: Callable, resources: Callable | dict): self._source = inspect.getsource(func) if isinstance(resources, dict): - def resource_fn(): + def resource_fn(*_, **__): return resources self._compute_resources = resource_fn @@ -233,10 +234,10 @@ def _auto_wrap(op_type): _decompositions = defaultdict(list) -"""dict[type, list[DecompositionRule]]: A dictionary mapping operator types to decomposition rules.""" +"""dict[str, list[DecompositionRule]]: A dictionary mapping operator names to decomposition rules.""" -def add_decomps(op_type: Type[Operator], *decomps: DecompositionRule) -> None: +def add_decomps(op_type: Type[Operator] | str, *decomps: DecompositionRule) -> None: """Globally registers new decomposition rules with an operator class. .. note:: @@ -253,7 +254,8 @@ def add_decomps(op_type: Type[Operator], *decomps: DecompositionRule) -> None: decomposition rules that may be chosen if they lead to a more resource-efficient decomposition. Args: - op_type: the operator type for which new decomposition rules are specified. + op_type (type or str): the operator type for which new decomposition rules are specified. + For symbolic operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. decomps (DecompositionRule): new decomposition rules to add to the given ``op_type``. A decomposition is a quantum function registered with a resource estimate using ``qml.register_resources``. @@ -289,6 +291,17 @@ def my_hadamard2(wires): for the duration of the session. To add alternative decompositions for a particular circuit as opposed to globally, use the ``alt_decomps`` argument of the :func:`~pennylane.transforms.decompose` transform. + Custom decomposition rules can also be specified for symbolic operators. In this case, the + operator type can be specified as a string. For example, + + .. code-block:: python + + @register_resources({qml.RY: 1}) + def adjoint_ry(phi, wires, **_): + qml.RY(-phi, wires=wires) + + qml.add_decomps("Adjoint(RY)", adjoint_ry) + .. seealso:: :func:`~pennylane.transforms.decompose` """ @@ -297,10 +310,12 @@ def my_hadamard2(wires): "A decomposition rule must be a qfunc with a resource estimate " "registered using qml.register_resources" ) - _decompositions[op_type].extend(decomps) + if isinstance(op_type, type): + op_type = op_type.__name__ + _decompositions[translate_op_alias(op_type)].extend(decomps) -def list_decomps(op_type: Type[Operator]) -> list[DecompositionRule]: +def list_decomps(op_type: Type[Operator] | str) -> list[DecompositionRule]: """Lists all stored decomposition rules for an operator class. .. note:: @@ -311,7 +326,8 @@ def list_decomps(op_type: Type[Operator]) -> list[DecompositionRule]: decomposition rules for an operator. Args: - op_type: the operator class to retrieve decomposition rules for. + op_type (type or str): the operator class to retrieve decomposition rules for. For symbolic + operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. Returns: list[DecompositionRule]: a list of decomposition rules registered for the given operator. @@ -338,10 +354,12 @@ def _crx_to_rx_cz(phi, wires, **__): 1: ──RX(0.25)─╰Z──RX(-0.25)─╰Z─┤ """ - return _decompositions[op_type][:] + if isinstance(op_type, type): + op_type = op_type.__name__ + return _decompositions[translate_op_alias(op_type)][:] -def has_decomp(op_type: Type[Operator]) -> bool: +def has_decomp(op_type: Type[Operator] | str) -> bool: """Checks whether an operator has decomposition rules defined. .. note:: @@ -352,10 +370,14 @@ def has_decomp(op_type: Type[Operator]) -> bool: decomposition rules for an operator. Args: - op_type: the operator class to check for decomposition rules. + op_type (type or str): the operator class to check for decomposition rules. For symbolic + operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. Returns: bool: whether decomposition rules are defined for the given operator. """ + if isinstance(op_type, type): + op_type = op_type.__name__ + op_type = translate_op_alias(op_type) return op_type in _decompositions and len(_decompositions[op_type]) > 0 diff --git a/pennylane/decomposition/symbolic_decomposition.py b/pennylane/decomposition/symbolic_decomposition.py index e9d6ada4997..dc4acf40084 100644 --- a/pennylane/decomposition/symbolic_decomposition.py +++ b/pennylane/decomposition/symbolic_decomposition.py @@ -77,7 +77,7 @@ def _adjoint_adjoint_resource(*_, base_params, **__): @register_resources(_adjoint_adjoint_resource) -def adjoint_adjoint_decomp(*params, wires, base): # pylint: disable=unused-argument +def cancel_adjoint(*params, wires, base): # pylint: disable=unused-argument """Decompose the adjoint of the adjoint of a gate.""" _, [_, metadata] = base.base._flatten() # pylint: disable=protected-access new_struct = wires, metadata @@ -157,7 +157,7 @@ def _pow_resource(base_class, base_params, z): @register_resources(_pow_resource) -def pow_decomp(*_, base, z, **__): +def repeat_pow_base(*_, base, z, **__): """Decompose the power of a gate.""" assert isinstance(z, int) and z >= 0 for _ in range(z): @@ -175,7 +175,7 @@ def _pow_pow_resource(base_class, base_params, z): # pylint: disable=unused-arg @register_resources(_pow_pow_resource) -def pow_pow_decomp(*_, base, z, **__): +def merge_powers(*_, base, z, **__): """Decompose the power of the power of a gate.""" qml.pow(base.base, z=z * base.z) diff --git a/pennylane/decomposition/utils.py b/pennylane/decomposition/utils.py index 3203ee217d2..b911aacd90f 100644 --- a/pennylane/decomposition/utils.py +++ b/pennylane/decomposition/utils.py @@ -41,7 +41,7 @@ def translate_op_alias(op_alias): """Translates an operator alias to its proper name.""" if op_alias in OP_NAME_ALIASES: return OP_NAME_ALIASES[op_alias] - if match := re.match(r"C\((\w+)\)", op_alias): + if match := re.match(r"(?:C|Controlled)\((\w+)\)", op_alias): base_op_name = match.group(1) return f"C({translate_op_alias(base_op_name)})" if match := re.match(r"Adjoint\((\w+)\)", op_alias): diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index b148d292144..540b0a14b71 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -35,7 +35,7 @@ def _cz_to_cnot(*_, **__): raise NotImplementedError -decompositions[qml.CZ] = [_cz_to_cnot] +decompositions["CZ"] = [_cz_to_cnot] @qml.register_resources({qml.Hadamard: 2, qml.CZ: 1}) @@ -43,7 +43,7 @@ def _cnot_to_cz(*_, **__): raise NotImplementedError -decompositions[qml.CNOT] = [_cnot_to_cz] +decompositions["CNOT"] = [_cnot_to_cz] def _multi_rz_decomposition_resources(num_wires): @@ -55,7 +55,7 @@ def _multi_rz_decomposition(*_, **__): raise NotImplementedError -decompositions[qml.MultiRZ] = [_multi_rz_decomposition] +decompositions["MultiRZ"] = [_multi_rz_decomposition] @qml.register_resources({qml.RZ: 2, qml.RX: 1, qml.GlobalPhase: 1}) @@ -68,7 +68,7 @@ def _hadamard_to_rz_ry(*_, **__): raise NotImplementedError -decompositions[qml.Hadamard] = [_hadamard_to_rz_rx, _hadamard_to_rz_ry] +decompositions["Hadamard"] = [_hadamard_to_rz_rx, _hadamard_to_rz_ry] @qml.register_resources({qml.RX: 1, qml.RZ: 2}) @@ -76,7 +76,7 @@ def _ry_to_rx_rz(*_, **__): raise NotImplementedError -decompositions[qml.RY] = [_ry_to_rx_rz] +decompositions["RY"] = [_ry_to_rx_rz] @qml.register_resources({qml.RX: 2, qml.CZ: 2}) @@ -84,7 +84,7 @@ def _crx_to_rx_cz(*_, **__): raise NotImplementedError -decompositions[qml.CRX] = [_crx_to_rx_cz] +decompositions["CRX"] = [_crx_to_rx_cz] @qml.register_resources({qml.RZ: 3, qml.CNOT: 2, qml.GlobalPhase: 1}) @@ -92,7 +92,7 @@ def _cphase_to_rz_cnot(*_, **__): raise NotImplementedError -decompositions[qml.ControlledPhaseShift] = [_cphase_to_rz_cnot] +decompositions["ControlledPhaseShift"] = [_cphase_to_rz_cnot] @qml.register_resources({qml.RZ: 1, qml.GlobalPhase: 1}) @@ -100,7 +100,7 @@ def _phase_shift_to_rz_gp(*_, **__): raise NotImplementedError -decompositions[qml.PhaseShift] = [_phase_shift_to_rz_gp] +decompositions["PhaseShift"] = [_phase_shift_to_rz_gp] @qml.register_resources({qml.RX: 1, qml.GlobalPhase: 1}) @@ -108,7 +108,7 @@ def _x_to_rx(*_, **__): raise NotImplementedError -decompositions[qml.X] = [_x_to_rx] +decompositions["PauliX"] = [_x_to_rx] @qml.register_resources({qml.PhaseShift: 1}) @@ -116,7 +116,7 @@ def _u1_ps(phi, wires, **__): qml.PhaseShift(phi, wires=wires) -decompositions[qml.U1] = [_u1_ps] +decompositions["U1"] = [_u1_ps] @qml.register_resources({qml.PhaseShift: 1}) @@ -124,4 +124,4 @@ def _t_ps(wires, **__): raise NotImplementedError -decompositions[qml.T] = [_t_ps] +decompositions["T"] = [_t_ps] diff --git a/tests/decomposition/test_decomp_utils.py b/tests/decomposition/test_decomp_utils.py index 22ef50a4b2d..48b02d7a908 100644 --- a/tests/decomposition/test_decomp_utils.py +++ b/tests/decomposition/test_decomp_utils.py @@ -63,5 +63,6 @@ def test_translate_op_alias(base_op_alias, expected_op_name): assert translate_op_alias(base_op_alias) == expected_op_name assert translate_op_alias(f"C({base_op_alias})") == f"C({expected_op_name})" + assert translate_op_alias(f"Controlled({base_op_alias})") == f"C({expected_op_name})" assert translate_op_alias(f"Adjoint({base_op_alias})") == f"Adjoint({expected_op_name})" assert translate_op_alias(f"Pow({base_op_alias})") == f"Pow({expected_op_name})" diff --git a/tests/decomposition/test_decomposition_graph.py b/tests/decomposition/test_decomposition_graph.py index 67c777464df..3b14c55c5e8 100644 --- a/tests/decomposition/test_decomposition_graph.py +++ b/tests/decomposition/test_decomposition_graph.py @@ -30,13 +30,15 @@ adjoint_resource_rep, controlled_resource_rep, pow_resource_rep, + resource_rep, ) +from pennylane.decomposition.decomposition_graph import _to_name @pytest.mark.unit @patch( "pennylane.decomposition.decomposition_graph.list_decomps", - side_effect=lambda x: decompositions[x], + side_effect=lambda x: decompositions[_to_name(x)], ) class TestDecompositionGraph: @@ -55,14 +57,14 @@ def custom_hadamard_2(wires): qml.RY(np.pi / 2, wires=wires) graph = DecompositionGraph(operations=[qml.Hadamard(0)], gate_set={"RX", "RY", "RZ"}) - assert graph._get_decompositions(qml.Hadamard) == decompositions[qml.Hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == decompositions["Hadamard"] graph = DecompositionGraph( operations=[qml.Hadamard(0)], gate_set={"RX", "RY", "RZ"}, fixed_decomps={qml.Hadamard: custom_hadamard}, ) - assert graph._get_decompositions(qml.Hadamard) == [custom_hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == [custom_hadamard] alt_dec = [custom_hadamard, custom_hadamard_2] graph = DecompositionGraph( @@ -70,8 +72,8 @@ def custom_hadamard_2(wires): gate_set={"RX", "RY", "RZ"}, alt_decomps={qml.Hadamard: alt_dec}, ) - exp_dec = alt_dec + decompositions[qml.Hadamard] - assert graph._get_decompositions(qml.Hadamard) == exp_dec + exp_dec = alt_dec + decompositions["Hadamard"] + assert graph._get_decompositions(resource_rep(qml.H)) == exp_dec graph = DecompositionGraph( operations=[qml.Hadamard(0)], @@ -79,7 +81,7 @@ def custom_hadamard_2(wires): alt_decomps={qml.Hadamard: alt_dec}, fixed_decomps={qml.Hadamard: custom_hadamard}, ) - assert graph._get_decompositions(qml.Hadamard) == [custom_hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == [custom_hadamard] def test_graph_construction(self, _): """Tests constructing a graph from a single Hadamard.""" diff --git a/tests/decomposition/test_decomposition_rule.py b/tests/decomposition/test_decomposition_rule.py index 22b251fe097..9c9c25007b3 100644 --- a/tests/decomposition/test_decomposition_rule.py +++ b/tests/decomposition/test_decomposition_rule.py @@ -177,7 +177,7 @@ def custom_decomp4(theta, wires, **__): with pytest.raises(TypeError, match="decomposition rule must be a qfunc with a resource"): qml.add_decomps(CustomOp, custom_decomp4) - _decompositions.pop(CustomOp) # cleanup + _decompositions.pop("CustomOp") # cleanup def test_auto_wrap_in_resource_op(self): """Tests that simply classes can be auto-wrapped in a ``CompressionResourceOp``.""" diff --git a/tests/decomposition/test_symbolic_decomposition.py b/tests/decomposition/test_symbolic_decomposition.py index 54ceea53cc6..b704fb0a6f1 100644 --- a/tests/decomposition/test_symbolic_decomposition.py +++ b/tests/decomposition/test_symbolic_decomposition.py @@ -20,11 +20,11 @@ from pennylane.decomposition.resources import Resources, pow_resource_rep from pennylane.decomposition.symbolic_decomposition import ( AdjointDecomp, - adjoint_adjoint_decomp, adjoint_controlled_decomp, adjoint_pow_decomp, - pow_decomp, - pow_pow_decomp, + cancel_adjoint, + merge_powers, + repeat_pow_base, same_type_adjoint_decomp, ) from tests.decomposition.conftest import to_resources @@ -39,12 +39,10 @@ def test_adjoint_adjoint(self): op = qml.adjoint(qml.adjoint(qml.RX(0.5, wires=0))) with qml.queuing.AnnotatedQueue() as q: - adjoint_adjoint_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + cancel_adjoint(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.RX(0.5, wires=0)] - assert adjoint_adjoint_decomp.compute_resources(**op.resource_params) == to_resources( - {qml.RX: 1} - ) + assert cancel_adjoint.compute_resources(**op.resource_params) == to_resources({qml.RX: 1}) @pytest.mark.jax def test_adjoint_adjoint_capture(self): @@ -58,7 +56,7 @@ def test_adjoint_adjoint_capture(self): qml.capture.enable() def circuit(): - adjoint_adjoint_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + cancel_adjoint(*op.parameters, wires=op.wires, **op.hyperparameters) plxpr = qml.capture.make_plxpr(circuit)() collector = CollectOpsandMeas() @@ -207,10 +205,10 @@ def test_pow_pow(self): op = qml.pow(qml.pow(qml.H(0), 3), 2) with qml.queuing.AnnotatedQueue() as q: - pow_pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + merge_powers(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), 6)] - assert pow_pow_decomp.compute_resources(**op.resource_params) == to_resources( + assert merge_powers.compute_resources(**op.resource_params) == to_resources( {pow_resource_rep(qml.H, {}, 6): 1} ) @@ -219,10 +217,10 @@ def test_pow_general(self): op = qml.pow(qml.H(0), 3) with qml.queuing.AnnotatedQueue() as q: - pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + repeat_pow_base(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.H(0), qml.H(0), qml.H(0)] - assert pow_decomp.compute_resources(**op.resource_params) == to_resources({qml.H: 3}) + assert repeat_pow_base.compute_resources(**op.resource_params) == to_resources({qml.H: 3}) @pytest.mark.jax def test_pow_general_capture(self): @@ -236,7 +234,7 @@ def test_pow_general_capture(self): qml.capture.enable() def circuit(): - pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + repeat_pow_base(*op.parameters, wires=op.wires, **op.hyperparameters) plxpr = qml.capture.make_plxpr(circuit)() collector = CollectOpsandMeas() @@ -251,7 +249,7 @@ def test_non_integer_pow_not_implemented(self): op = qml.pow(qml.H(0), 0.5) with pytest.raises(NotImplementedError, match="Non-integer or negative powers"): - pow_decomp.compute_resources(**op.resource_params) + repeat_pow_base.compute_resources(**op.resource_params) op = qml.pow(qml.H(0), -1) with pytest.raises(NotImplementedError, match="Non-integer or negative powers"): - pow_decomp.compute_resources(**op.resource_params) + repeat_pow_base.compute_resources(**op.resource_params) From fa09eac48f45474db1f7bea8f297e4ce04430a08 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 11:29:02 -0400 Subject: [PATCH 02/14] fix doc? --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 59e7d50e3ee..52b8fe5ff2a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -118,7 +118,7 @@ 1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤ 2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤ ``` - * Symbolic operator types can now be given as strings to the `op_type` argument of :func:`~.add_decomps`, + * Symbolic operator types can now be given as strings to the `op_type` argument of :func:`~.decomposition.add_decomps`, or as keys of the dictionaries passed to the `alt_decomps` and `fixed_decomps` arguments of the :func:`~.transforms.decompose` transform, allowing custom decomposition rules to be defined and registered for symbolic operators. From 95c03833e570d784f2c9d94a14767ca9e2675140 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 12:17:19 -0400 Subject: [PATCH 03/14] tests and bug fix --- .../decomposition/decomposition_graph.py | 22 +-- tests/decomposition/conftest.py | 33 ++++- .../decomposition/test_decomposition_graph.py | 131 ++++++++++-------- .../decomposition/test_decomposition_rule.py | 15 ++ 4 files changed, 135 insertions(+), 66 deletions(-) diff --git a/pennylane/decomposition/decomposition_graph.py b/pennylane/decomposition/decomposition_graph.py index 2a28b40d804..66a63a42b23 100644 --- a/pennylane/decomposition/decomposition_graph.py +++ b/pennylane/decomposition/decomposition_graph.py @@ -54,6 +54,8 @@ ) from .utils import DecompositionError, DecompositionNotApplicable, translate_op_alias +NULL = "null" # sentinel value for the start node in the graph + class DecompositionGraph: # pylint: disable=too-many-instance-attributes """A graph that models a decomposition problem. @@ -136,7 +138,6 @@ def __init__( # Tracks the node indices of various operators. self._original_ops_indices: set[int] = set() - self._target_ops_indices: set[int] = set() self._all_op_indices: dict[CompressedResourceOp, int] = {} # Stores the library of custom decomposition rules @@ -150,6 +151,7 @@ def __init__( self._visitor = None # Construct the decomposition graph + self._start = self._graph.add_node(NULL) self._construct_graph(operations) def _get_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: @@ -196,7 +198,7 @@ def _recursively_add_op_node(self, op_node: CompressedResourceOp) -> int: self._all_op_indices[op_node] = op_node_idx if op_node.name in self._gate_set: - self._target_ops_indices.add(op_node_idx) + self._graph.add_edge(self._start, op_node_idx, 1) return op_node_idx for decomposition in self._get_decompositions(op_node): @@ -287,6 +289,11 @@ def _recursively_add_decomposition_node( d_node = _DecompositionNode(rule, decomp_resource) d_node_idx = self._graph.add_node(d_node) + if not decomp_resource.gate_counts: + # If an operator decomposes to nothing (e.g., a Hadamard raised to a + # power of 2), we must still connect something to this decomposition + # node so that it is accounted for. + self._graph.add_edge(self._start, d_node_idx, 0) for op in decomp_resource.gate_counts: op_node_idx = self._recursively_add_op_node(op) self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx)) @@ -302,17 +309,12 @@ def solve(self, lazy=True): """ self._visitor = _DecompositionSearchVisitor(self._graph, self._original_ops_indices, lazy) - start = self._graph.add_node("dummy") - self._graph.add_edges_from( - [(start, op_node_idx, 1) for op_node_idx in self._target_ops_indices] - ) rx.dijkstra_search( self._graph, - source=[start], + source=[self._start], weight_fn=self._visitor.edge_weight, visitor=self._visitor, ) - self._graph.remove_node(start) if self._visitor.unsolved_op_indices: unsolved_ops = [self._graph[op_idx] for op_idx in self._visitor.unsolved_op_indices] op_names = set(op.name for op in unsolved_ops) @@ -450,6 +452,8 @@ def examine_edge(self, edge): return # nothing is to be done for edges leading to an operator node if target_idx not in self.distances: self.distances[target_idx] = Resources() # initialize with empty resource + if src_node == NULL: + return # special case for when the decomposition produces nothing self.distances[target_idx] += self.distances[src_idx] * target_node.count(src_node) if target_idx not in self._num_edges_examined: self._num_edges_examined[target_idx] = 0 @@ -465,7 +469,7 @@ def edge_relaxed(self, edge): """Triggered when an edge is relaxed during the Dijkstra search.""" src_idx, target_idx, _ = edge target_node = self._graph[target_idx] - if self._graph[src_idx] == "dummy": + if self._graph[src_idx] == NULL and not isinstance(target_node, _DecompositionNode): self.distances[target_idx] = Resources({target_node: 1}) elif isinstance(target_node, CompressedResourceOp): self.predecessors[target_idx] = src_idx diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index 540b0a14b71..c9440af424c 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -19,7 +19,7 @@ from collections import defaultdict import pennylane as qml -from pennylane.decomposition import Resources +from pennylane.decomposition import DecompositionNotApplicable, Resources from pennylane.decomposition.decomposition_rule import _auto_wrap decompositions = defaultdict(list) @@ -125,3 +125,34 @@ def _t_ps(wires, **__): decompositions["T"] = [_t_ps] + + +@qml.register_resources({qml.H: 1}) +def _adjoint_hadamard(*_, **__): + raise NotImplementedError + + +decompositions["Adjoint(Hadamard)"] = [_adjoint_hadamard] + + +def _pow_hadamard_resource(z, **__): + return {qml.H: z % 2} + + +@qml.register_resources(_pow_hadamard_resource) +def _pow_hadamard(*_, wires, z, **__): + qml.cond(z % 2 == 1, qml.H)(wires=wires) + + +decompositions["Pow(Hadamard)"] = [_pow_hadamard] + + +def _controlled_hadamard_resource(num_control_wires, num_zero_control_values, **__): + if num_control_wires > 1: + raise DecompositionNotApplicable + return {qml.CH: 1, qml.X: num_zero_control_values * 2} + + +@qml.register_resources(_controlled_hadamard_resource) +def _controlled_hadamard(*_, **__): + raise NotImplementedError diff --git a/tests/decomposition/test_decomposition_graph.py b/tests/decomposition/test_decomposition_graph.py index 3b14c55c5e8..a4c9fc08fcf 100644 --- a/tests/decomposition/test_decomposition_graph.py +++ b/tests/decomposition/test_decomposition_graph.py @@ -88,17 +88,19 @@ def test_graph_construction(self, _): op = qml.Hadamard(wires=[0]) graph = DecompositionGraph(operations=[op], gate_set={"RX", "RZ", "GlobalPhase"}) - # 5 ops and 3 decompositions (2 for Hadamard and 1 for RY) - assert len(graph._graph.nodes()) == 8 - # 8 edges from ops to decompositions and 3 from decompositions to ops - assert len(graph._graph.edges()) == 11 + # 5 ops and 3 decompositions (2 for Hadamard and 1 for RY) and 1 dummy starting node + assert len(graph._graph.nodes()) == 9 + # 8 edges from ops to decompositions, 3 from decompositions to ops, and 3 from the + # dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 14 # Check that graph construction stops at gates in the target gate set. graph2 = DecompositionGraph(operations=[op], gate_set={"RY", "RZ", "GlobalPhase"}) - # 5 ops and 2 decompositions (RY is in the target gate set now) - assert len(graph2._graph.nodes()) == 7 - # 6 edges from ops to decompositions and 2 from decompositions to ops - assert len(graph2._graph.edges()) == 8 + # 5 ops and 2 decompositions (RY is in the target gate set now), and the dummy starting node + assert len(graph2._graph.nodes()) == 8 + # 6 edges from ops to decompositions and 2 from decompositions to ops, + # and 3 from the dummy starting node to the target gate set. + assert len(graph2._graph.edges()) == 11 def test_graph_construction_non_applicable_rules(self, _): """Tests rules that raise DecompositionNotApplicable are skipped.""" @@ -135,10 +137,12 @@ def some_other_rule(*_, **__): gate_set={"CNOT", "RZ"}, alt_decomps={CustomOp: [some_rule, some_other_rule]}, ) - # 3 ops (CustomOp, CNOT, RZ) and 1 decompositions (only some_other_rule) - assert len(graph._graph.nodes()) == 4 - # 2 edges from ops to decompositions and 1 from decompositions to ops - assert len(graph._graph.edges()) == 3 + # 3 ops (CustomOp, CNOT, RZ) and 1 decompositions (only some_other_rule), + # and the dummy starting node + assert len(graph._graph.nodes()) == 5 + # 2 edges from ops to decompositions, 1 from decompositions to ops, + # and 2 from the dummy starting node to the target gate set + assert len(graph._graph.edges()) == 5 def test_gate_set(self, _): """Tests that graph construction stops at the target gate set.""" @@ -175,10 +179,12 @@ def custom_decomp(wires): fixed_decomps={CustomOp: custom_decomp}, ) - # 1 node for CustomOp, 1 decomposition node, and 5 for the ops in the decomposition - assert len(graph._graph.nodes()) == 7 - # 5 edges from ops to decompositions and 1 edge from decompositions to ops - assert len(graph._graph.edges()) == 6 + # 1 node for CustomOp, 1 decomposition node, 5 for the ops in the decomposition, + # and the dummy starting node. + assert len(graph._graph.nodes()) == 8 + # 5 edges from ops to decompositions, 1 edge from decompositions to ops, and 5 + # edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 11 def test_graph_solve(self, _): """Tests solving a simple graph for the optimal decompositions.""" @@ -286,9 +292,11 @@ def _custom_decomp(*_, **__): ) # 10 ops (CustomOp, MultiRZ(4), MultiRZ(3), CNOT, CZ, RX, RY, RZ, Hadamard, GlobalPhase) # 7 decompositions (1 for CustomOp, 1 for each of the two MultiRZs, 1 for CNOT, 2 for Hadamard, and 1 for RY) - assert len(graph._graph.nodes()) == 17 - # 16 edges from ops to decompositions and 7 from decompositions to ops - assert len(graph._graph.edges()) == 23 + # and the dummy starting node + assert len(graph._graph.nodes()) == 18 + # 16 edges from ops to decompositions and 7 from decompositions to ops, + # and 4 edges from the dummy starting node to the target gate set + assert len(graph._graph.edges()) == 27 graph.solve() assert graph.resource_estimate(op) == to_resources( @@ -319,10 +327,11 @@ def test_controlled_global_phase(self, _): op1 = qml.ctrl(qml.GlobalPhase(0.5), control=[1]) op2 = qml.ctrl(qml.GlobalPhase(0.5), control=[1, 2]) graph = DecompositionGraph([op1, op2], gate_set={"ControlledPhaseShift", "PhaseShift"}) - # 4 op nodes and 2 decomposition nodes. - assert len(graph._graph.nodes()) == 6 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 4 op nodes and 2 decomposition nodes, and 1 dummy starting node. + assert len(graph._graph.nodes()) == 7 + # 2 edges from decompositions to ops and 2 edges from ops to decompositions, + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 6 # Verify the decompositions graph.solve() @@ -344,10 +353,11 @@ def test_custom_controlled_op(self, _): operations=[op1, op2], gate_set={"CNOT", "CH"}, ) - # 4 op nodes and 2 decomposition nodes. - assert len(graph._graph.nodes()) == 6 + # 4 op nodes and 2 decomposition nodes, and the dummy starting node + assert len(graph._graph.nodes()) == 7 # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 6 # Verify the decompositions graph.solve() @@ -415,10 +425,11 @@ def custom_controlled_decomp(wires): CustomControlledOp: [custom_controlled_decomp], }, ) - # 18 op nodes and 16 decomposition nodes. - assert len(graph._graph.nodes()) == 34 + # 18 op nodes and 16 decomposition nodes, and the dummy starting node + assert len(graph._graph.nodes()) == 35 # 16 edges from decompositions to ops and 36 edges from ops to decompositions - assert len(graph._graph.edges()) == 52 + # and 6 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 58 graph.solve() @@ -445,8 +456,9 @@ def test_adjoint_adjoint(self, _): graph = DecompositionGraph(operations=[op], gate_set={"RX"}) # 2 operator nodes (Adjoint(Adjoint(RX)) and RX), and 1 decomposition node. - assert len(graph._graph.nodes()) == 3 - assert len(graph._graph.edges()) == 2 + # and the dummy starting node + assert len(graph._graph.nodes()) == 4 + assert len(graph._graph.edges()) == 3 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -461,19 +473,19 @@ def test_adjoint_pow(self, _): op = qml.adjoint(qml.pow(qml.H(0), z=3)) graph = DecompositionGraph(operations=[op], gate_set={"H"}) - # 3 operator nodes: Adjoint(Pow(H)), Pow(H), and H - # 2 decomposition nodes for Adjoint(Pow(H)) and Pow(H) - assert len(graph._graph.nodes()) == 5 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 3 operator nodes: Adjoint(Pow(H)), Pow(H), and H, and 1 dummy starting node + # 1 decomposition nodes for Adjoint(Pow(H)) and two decomposition nodes for Pow(H) + assert len(graph._graph.nodes()) == 7 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions + # and 1 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 7 graph.solve() with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op)(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), z=3)] - # TODO: There should just be a single `H` after we have full support of Pow decompositions. - assert graph.resource_estimate(op) == to_resources({qml.H: 3}) + assert graph.resource_estimate(op) == to_resources({qml.H: 1}) def test_adjoint_custom(self, _): """Tests adjoint of an operator that defines its own adjoint.""" @@ -481,9 +493,9 @@ def test_adjoint_custom(self, _): op = qml.adjoint(qml.RX(0.5, wires=[0])) graph = DecompositionGraph(operations=[op], gate_set={"RX"}) - # 2 operator nodes (Adjoint(RX) and RX), and 1 decomposition node. - assert len(graph._graph.nodes()) == 3 - assert len(graph._graph.edges()) == 2 + # 2 operator nodes (Adjoint(RX) and RX), and 1 decomposition node, and 1 dummy starting node + assert len(graph._graph.nodes()) == 4 + assert len(graph._graph.edges()) == 3 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -501,9 +513,11 @@ def test_adjoint_controlled(self, _): graph = DecompositionGraph(operations=[op, op2], gate_set={"ControlledPhaseShift", "CRX"}) # 5 operator nodes: Adjoint(C(RX)), Adjoint(C(U1)), CRX, C(U1), ControlledPhaseShift # 3 decomposition nodes leading into Adjoint(C(RX)), Adjoint(C(U1)), and C(U1) - assert len(graph._graph.nodes()) == 8 - # 3 edges from decompositions to ops and 3 edges from ops to decompositions - assert len(graph._graph.edges()) == 6 + # and the dummy starting node + assert len(graph._graph.nodes()) == 9 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions, + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 8 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -546,10 +560,13 @@ def custom_decomp(phi, wires): alt_decomps={CustomOp: [custom_decomp]}, ) # 10 operator nodes: A(CustomOp), A(H), A(CNOT), A(RX), A(T), H, CNOT, RX, A(PhaseShift), PhaseShift - # 6 decomposition nodes for: A(CustomOp), A(H), A(CNOT), A(RX), A(T), A(PhaseShift) - assert len(graph._graph.nodes()) == 16 - # 9 edges from ops to decompositions and 6 edges from decompositions to ops. - assert len(graph._graph.edges()) == 15 + # 5 decomposition nodes for: A(CustomOp), A(CNOT), A(RX), A(T), A(PhaseShift) + # 2 decomposition nodes for A(H) + # 1 dummy starting node + assert len(graph._graph.nodes()) == 18 + # 10 edges from ops to decompositions and 7 edges from decompositions to ops. + # and 4 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 21 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -566,27 +583,29 @@ def custom_decomp(phi, wires): {qml.H: 1, qml.CNOT: 2, qml.RX: 1, qml.PhaseShift: 1} ) - def test_pow_pow(self, _): + def test_nested_powers(self, _): """Tests nested power decompositions.""" op = qml.pow(qml.pow(qml.H(0), 3), 2) graph = DecompositionGraph(operations=[op], gate_set={"H"}) # 3 operator nodes: Pow(Pow(H)), Pow(H), and H - # 2 decomposition nodes for Pow(Pow(H)) and Pow(H) - assert len(graph._graph.nodes()) == 5 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 1 decomposition nodes for Pow(Pow(H)) and 2 nodes for Pow(H) + # and the dummy starting node + assert len(graph._graph.nodes()) == 7 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions + # and 1 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 7 graph.solve() with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op)(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), 6)] - assert graph.resource_estimate(op) == to_resources({qml.H: 6}) + assert graph.resource_estimate(op) == to_resources({}) op2 = qml.pow(qml.H(0), 6) with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op2)(*op2.parameters, wires=op2.wires, **op2.hyperparameters) - assert q.queue == [qml.H(0), qml.H(0), qml.H(0), qml.H(0), qml.H(0), qml.H(0)] - assert graph.resource_estimate(op2) == to_resources({qml.H: 6}) + assert q.queue == [] + assert graph.resource_estimate(op2) == to_resources({}) diff --git a/tests/decomposition/test_decomposition_rule.py b/tests/decomposition/test_decomposition_rule.py index 9c9c25007b3..26247100296 100644 --- a/tests/decomposition/test_decomposition_rule.py +++ b/tests/decomposition/test_decomposition_rule.py @@ -179,6 +179,21 @@ def custom_decomp4(theta, wires, **__): _decompositions.pop("CustomOp") # cleanup + def test_custom_symbolic_decomposition(self): + """Tests that custom decomposition rules for symbolic operators can be registered.""" + + class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-methods + pass + + @qml.register_resources({qml.RX: 1, qml.RZ: 1}) + def my_adjoint_custom_op(theta, wires, **__): + qml.RX(theta, wires=wires[0]) + qml.RZ(theta, wires=wires[1]) + + qml.add_decomps("Adjoint(CustomOp)", my_adjoint_custom_op) + assert qml.decomposition.has_decomp("Adjoint(CustomOp)") + assert qml.list_decomps("Adjoint(CustomOp)") == [my_adjoint_custom_op] + def test_auto_wrap_in_resource_op(self): """Tests that simply classes can be auto-wrapped in a ``CompressionResourceOp``.""" From 8f0a0de39cd6336c114cf838cc8d6a443ba76ce9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 12:21:13 -0400 Subject: [PATCH 04/14] update changelog --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 52b8fe5ff2a..ad9db07e711 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -122,6 +122,7 @@ or as keys of the dictionaries passed to the `alt_decomps` and `fixed_decomps` arguments of the :func:`~.transforms.decompose` transform, allowing custom decomposition rules to be defined and registered for symbolic operators. + [(#7347)](https://github.com/PennyLaneAI/pennylane/pull/7347) ```python @register_resources({qml.RY: 1}) def my_adjoint_ry(phi, wires, **_): From 5b074ed794ee1cc4c8ed3d74b0262207a7f1c1f9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 12:33:14 -0400 Subject: [PATCH 05/14] add more tests --- tests/decomposition/conftest.py | 12 ++++++--- .../decomposition/test_decomposition_graph.py | 27 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index c9440af424c..111eaa5ab30 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -128,8 +128,8 @@ def _t_ps(wires, **__): @qml.register_resources({qml.H: 1}) -def _adjoint_hadamard(*_, **__): - raise NotImplementedError +def _adjoint_hadamard(*_, wires, **__): + qml.H(wires) decompositions["Adjoint(Hadamard)"] = [_adjoint_hadamard] @@ -154,5 +154,9 @@ def _controlled_hadamard_resource(num_control_wires, num_zero_control_values, ** @qml.register_resources(_controlled_hadamard_resource) -def _controlled_hadamard(*_, **__): - raise NotImplementedError +def _controlled_hadamard(*_, wires, control_values, **__): + if not control_values[0]: + qml.PauliX(wires=wires[0]) + qml.CH(wires=wires) + if not control_values[0]: + qml.PauliX(wires=wires[0]) diff --git a/tests/decomposition/test_decomposition_graph.py b/tests/decomposition/test_decomposition_graph.py index a4c9fc08fcf..609df6660b0 100644 --- a/tests/decomposition/test_decomposition_graph.py +++ b/tests/decomposition/test_decomposition_graph.py @@ -609,3 +609,30 @@ def test_nested_powers(self, _): assert q.queue == [] assert graph.resource_estimate(op2) == to_resources({}) + + def test_custom_symbolic_decompositions(self, _): + """Tests that custom symbolic decompositions are used.""" + + graph = DecompositionGraph( + operations=[ + qml.adjoint(qml.H(0)), + qml.pow(qml.H(1), 3), + qml.ops.Controlled(qml.H(0), control_wires=1), + ], + gate_set={"H", "CH"}, + ) + + op1 = qml.adjoint(qml.H(0)) + op2 = qml.pow(qml.H(1), 3) + op3 = qml.ops.Controlled(qml.H(0), control_wires=1) + + graph.solve() + with qml.queuing.AnnotatedQueue() as q: + graph.decomposition(op1)(*op1.parameters, wires=op1.wires, **op1.hyperparameters) + graph.decomposition(op2)(*op2.parameters, wires=op2.wires, **op2.hyperparameters) + graph.decomposition(op3)(*op3.parameters, wires=op3.wires, **op3.hyperparameters) + + assert q.queue == [qml.H(0), qml.H(1), qml.CH(wires=[1, 0])] + assert graph.resource_estimate(op1) == to_resources({qml.H: 1}) + assert graph.resource_estimate(op2) == to_resources({qml.H: 1}) + assert graph.resource_estimate(op3) == to_resources({qml.CH: 1}) From ff9088226621c49cd4c1411dacd196dfef6e2bf1 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 12:52:22 -0400 Subject: [PATCH 06/14] one more test case --- tests/decomposition/test_decomposition_graph.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/decomposition/test_decomposition_graph.py b/tests/decomposition/test_decomposition_graph.py index 609df6660b0..956c17645ca 100644 --- a/tests/decomposition/test_decomposition_graph.py +++ b/tests/decomposition/test_decomposition_graph.py @@ -613,26 +613,35 @@ def test_nested_powers(self, _): def test_custom_symbolic_decompositions(self, _): """Tests that custom symbolic decompositions are used.""" + @qml.register_resources({qml.RX: 1}) + def my_adjoint_rx(theta, wires, **__): + qml.RX(-theta, wires=wires) + graph = DecompositionGraph( operations=[ qml.adjoint(qml.H(0)), qml.pow(qml.H(1), 3), qml.ops.Controlled(qml.H(0), control_wires=1), + qml.adjoint(qml.RX(0.5, wires=0)), ], - gate_set={"H", "CH"}, + fixed_decomps={"Adjoint(RX)": my_adjoint_rx}, + gate_set={"H", "CH", "RX"}, ) op1 = qml.adjoint(qml.H(0)) op2 = qml.pow(qml.H(1), 3) op3 = qml.ops.Controlled(qml.H(0), control_wires=1) + op4 = qml.adjoint(qml.RX(0.5, wires=0)) graph.solve() with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op1)(*op1.parameters, wires=op1.wires, **op1.hyperparameters) graph.decomposition(op2)(*op2.parameters, wires=op2.wires, **op2.hyperparameters) graph.decomposition(op3)(*op3.parameters, wires=op3.wires, **op3.hyperparameters) + graph.decomposition(op4)(*op4.parameters, wires=op4.wires, **op4.hyperparameters) - assert q.queue == [qml.H(0), qml.H(1), qml.CH(wires=[1, 0])] + assert q.queue == [qml.H(0), qml.H(1), qml.CH(wires=[1, 0]), qml.RX(-0.5, wires=0)] assert graph.resource_estimate(op1) == to_resources({qml.H: 1}) assert graph.resource_estimate(op2) == to_resources({qml.H: 1}) assert graph.resource_estimate(op3) == to_resources({qml.CH: 1}) + assert graph.resource_estimate(op4) == to_resources({qml.RX: 1}) From 956b9a275f49e26fa674a2fd0b4c7d4648d606dd Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 13:54:34 -0400 Subject: [PATCH 07/14] pylint --- tests/decomposition/conftest.py | 15 --------------- tests/decomposition/test_decomposition_rule.py | 3 --- 2 files changed, 18 deletions(-) diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index 111eaa5ab30..bd177694152 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -145,18 +145,3 @@ def _pow_hadamard(*_, wires, z, **__): decompositions["Pow(Hadamard)"] = [_pow_hadamard] - - -def _controlled_hadamard_resource(num_control_wires, num_zero_control_values, **__): - if num_control_wires > 1: - raise DecompositionNotApplicable - return {qml.CH: 1, qml.X: num_zero_control_values * 2} - - -@qml.register_resources(_controlled_hadamard_resource) -def _controlled_hadamard(*_, wires, control_values, **__): - if not control_values[0]: - qml.PauliX(wires=wires[0]) - qml.CH(wires=wires) - if not control_values[0]: - qml.PauliX(wires=wires[0]) diff --git a/tests/decomposition/test_decomposition_rule.py b/tests/decomposition/test_decomposition_rule.py index 26247100296..94f67064294 100644 --- a/tests/decomposition/test_decomposition_rule.py +++ b/tests/decomposition/test_decomposition_rule.py @@ -182,9 +182,6 @@ def custom_decomp4(theta, wires, **__): def test_custom_symbolic_decomposition(self): """Tests that custom decomposition rules for symbolic operators can be registered.""" - class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-methods - pass - @qml.register_resources({qml.RX: 1, qml.RZ: 1}) def my_adjoint_custom_op(theta, wires, **__): qml.RX(theta, wires=wires[0]) From 32dba0bf071d2b610dd1e986e344fcf8ddbbd0bc Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Tue, 29 Apr 2025 14:03:00 -0400 Subject: [PATCH 08/14] pylint --- tests/decomposition/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index bd177694152..f57a0267eaf 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -19,7 +19,7 @@ from collections import defaultdict import pennylane as qml -from pennylane.decomposition import DecompositionNotApplicable, Resources +from pennylane.decomposition import Resources from pennylane.decomposition.decomposition_rule import _auto_wrap decompositions = defaultdict(list) From 26a314d2abf8c37fe8d62546918cf4bd6c9efed4 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 1 May 2025 14:42:45 -0400 Subject: [PATCH 09/14] Apply suggestions from code review --- doc/releases/changelog-dev.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 76a096fd64f..f6f17f75b39 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -99,7 +99,7 @@ * Symbolic operator types (e.g., `Adjoint`, `Controlled`, and `Pow`) can now be specified as strings in various parts of the new graph-based decomposition system, specifically: * The `gate_set` argument of the :func:`~.transforms.decompose` transform now supports adding symbolic - operators to the target gate set. + operators in the target gate set. [(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331) ```python from functools import partial @@ -137,7 +137,7 @@ @partial( qml.transforms.decompose, - gate_set={"RX", "CNOT"}, + gate_set={"RX", "RY", "CNOT"}, fixed_decomps={"Adjoint(RX)": my_adjoint_rx} ) @qml.qnode(qml.device("default.qubit")) From 02f08f29b013640a3e33ed99f52f6e46fecfe1e9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 1 May 2025 15:07:06 -0400 Subject: [PATCH 10/14] Apply suggestions from code review Co-authored-by: Pietropaolo Frisoni --- doc/releases/changelog-dev.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f6f17f75b39..ed17d74b52a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -124,7 +124,7 @@ registered for symbolic operators. [(#7347)](https://github.com/PennyLaneAI/pennylane/pull/7347) ```python - @register_resources({qml.RY: 1}) + @qml.register_resources({qml.RY: 1}) def my_adjoint_ry(phi, wires, **_): qml.RY(-phi, wires=wires) @@ -142,9 +142,9 @@ ) @qml.qnode(qml.device("default.qubit")) def circuit(): - qml.adjoint(qml.RX(0.5), wires=[0]) + qml.adjoint(qml.RX(0.5, wires=[0])) qml.CNOT(wires=[0, 1]) - qml.adjoint(qml.RY(0.5), wires=[1]) + qml.adjoint(qml.RY(0.5, wires=[1])) return qml.expval(qml.Z(0)) ``` ```pycon From fa54ba23378fad326431ae70d7e1c9d94ce0994e Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Thu, 1 May 2025 15:18:37 -0400 Subject: [PATCH 11/14] lol --- pennylane/decomposition/decomposition_graph.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pennylane/decomposition/decomposition_graph.py b/pennylane/decomposition/decomposition_graph.py index 66a63a42b23..4813ff4cfa9 100644 --- a/pennylane/decomposition/decomposition_graph.py +++ b/pennylane/decomposition/decomposition_graph.py @@ -54,8 +54,6 @@ ) from .utils import DecompositionError, DecompositionNotApplicable, translate_op_alias -NULL = "null" # sentinel value for the start node in the graph - class DecompositionGraph: # pylint: disable=too-many-instance-attributes """A graph that models a decomposition problem. @@ -151,7 +149,7 @@ def __init__( self._visitor = None # Construct the decomposition graph - self._start = self._graph.add_node(NULL) + self._start = self._graph.add_node(None) self._construct_graph(operations) def _get_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: @@ -452,7 +450,7 @@ def examine_edge(self, edge): return # nothing is to be done for edges leading to an operator node if target_idx not in self.distances: self.distances[target_idx] = Resources() # initialize with empty resource - if src_node == NULL: + if src_node is None: return # special case for when the decomposition produces nothing self.distances[target_idx] += self.distances[src_idx] * target_node.count(src_node) if target_idx not in self._num_edges_examined: @@ -469,7 +467,7 @@ def edge_relaxed(self, edge): """Triggered when an edge is relaxed during the Dijkstra search.""" src_idx, target_idx, _ = edge target_node = self._graph[target_idx] - if self._graph[src_idx] == NULL and not isinstance(target_node, _DecompositionNode): + if self._graph[src_idx] is None and not isinstance(target_node, _DecompositionNode): self.distances[target_idx] = Resources({target_node: 1}) elif isinstance(target_node, CompressedResourceOp): self.predecessors[target_idx] = src_idx From a85a40fb666911c72041789cc7bc90adbda370e9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Fri, 2 May 2025 16:08:36 -0400 Subject: [PATCH 12/14] minor fix --- pennylane/decomposition/symbolic_decomposition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/decomposition/symbolic_decomposition.py b/pennylane/decomposition/symbolic_decomposition.py index dc4acf40084..61ee6322fc6 100644 --- a/pennylane/decomposition/symbolic_decomposition.py +++ b/pennylane/decomposition/symbolic_decomposition.py @@ -76,12 +76,12 @@ def _adjoint_adjoint_resource(*_, base_params, **__): return {resource_rep(base_class, **base_params): 1} +# pylint: disable=protected-access @register_resources(_adjoint_adjoint_resource) def cancel_adjoint(*params, wires, base): # pylint: disable=unused-argument """Decompose the adjoint of the adjoint of a gate.""" - _, [_, metadata] = base.base._flatten() # pylint: disable=protected-access - new_struct = wires, metadata - base.base._unflatten(params, new_struct) # pylint: disable=protected-access + _, struct = base.base._flatten() + base.base._unflatten(params, struct) def _adjoint_controlled_resource(base_class, base_params): From 5db0fc29567c0bab832e701b3a369b34997f70ff Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Mon, 5 May 2025 10:53:09 -0400 Subject: [PATCH 13/14] get rid of funny business --- pennylane/decomposition/symbolic_decomposition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pennylane/decomposition/symbolic_decomposition.py b/pennylane/decomposition/symbolic_decomposition.py index 61ee6322fc6..15f6fb94128 100644 --- a/pennylane/decomposition/symbolic_decomposition.py +++ b/pennylane/decomposition/symbolic_decomposition.py @@ -80,8 +80,7 @@ def _adjoint_adjoint_resource(*_, base_params, **__): @register_resources(_adjoint_adjoint_resource) def cancel_adjoint(*params, wires, base): # pylint: disable=unused-argument """Decompose the adjoint of the adjoint of a gate.""" - _, struct = base.base._flatten() - base.base._unflatten(params, struct) + base.base._unflatten(*base.base._flatten()) def _adjoint_controlled_resource(base_class, base_params): From 74604ab97f3d8f8e0929986c626e9cb8c081d270 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 7 May 2025 09:41:16 -0400 Subject: [PATCH 14/14] minor rename --- .../decomposition/decomposition_graph.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/pennylane/decomposition/decomposition_graph.py b/pennylane/decomposition/decomposition_graph.py index 4813ff4cfa9..3d668424cb6 100644 --- a/pennylane/decomposition/decomposition_graph.py +++ b/pennylane/decomposition/decomposition_graph.py @@ -152,24 +152,24 @@ def __init__( self._start = self._graph.add_node(None) self._construct_graph(operations) - def _get_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: + def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]: """Helper function to get a list of decomposition rules.""" - op_name = _to_name(op) + op_name = _to_name(op_node) if op_name in self._fixed_decomps: return [self._fixed_decomps[op_name]] decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name) - if issubclass(op.op_type, qml.ops.Adjoint): - decomps.extend(self._get_adjoint_decompositions(op)) + if issubclass(op_node.op_type, qml.ops.Adjoint): + decomps.extend(self._get_adjoint_decompositions(op_node)) - elif issubclass(op.op_type, qml.ops.Pow): - decomps.extend(self._get_pow_decompositions(op)) + elif issubclass(op_node.op_type, qml.ops.Pow): + decomps.extend(self._get_pow_decompositions(op_node)) - elif op.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): - decomps.extend(self._get_controlled_decompositions(op)) + elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): + decomps.extend(self._get_controlled_decompositions(op_node)) return decomps @@ -215,10 +215,10 @@ def _add_decomp_rule_to_op( except DecompositionNotApplicable: pass # ignore decompositions that are not applicable to the given op params. - def _get_adjoint_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: + def _get_adjoint_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]: """Retrieves a list of decomposition rules for an adjoint operator.""" - base_class, base_params = op.params["base_class"], op.params["base_params"] + base_class, base_params = (op_node.params["base_class"], op_node.params["base_params"]) if issubclass(base_class, qml.ops.Adjoint): return [cancel_adjoint] @@ -242,21 +242,23 @@ def _get_adjoint_decompositions(self, op: CompressedResourceOp) -> list[Decompos return [AdjointDecomp(base_rule) for base_rule in self._get_decompositions(base_rep)] @staticmethod - def _get_pow_decompositions(op: CompressedResourceOp) -> list[DecompositionRule]: + def _get_pow_decompositions(op_node: CompressedResourceOp) -> list[DecompositionRule]: """Retrieves a list of decomposition rules for a power operator.""" - base_class = op.params["base_class"] + base_class = op_node.params["base_class"] if issubclass(base_class, qml.ops.Pow): return [merge_powers] return [repeat_pow_base] - def _get_controlled_decompositions(self, op: CompressedResourceOp) -> list[DecompositionRule]: + def _get_controlled_decompositions( + self, op_node: CompressedResourceOp + ) -> list[DecompositionRule]: """Adds a controlled decomposition node to the graph.""" - base_class = op.params["base_class"] - num_control_wires = op.params["num_control_wires"] + base_class = op_node.params["base_class"] + num_control_wires = op_node.params["num_control_wires"] # Handle controlled global phase if base_class is qml.GlobalPhase: @@ -272,7 +274,7 @@ def _get_controlled_decompositions(self, op: CompressedResourceOp) -> list[Decom return [CustomControlledDecomposition(custom_op_type)] # General case - base_rep = resource_rep(base_class, **op.params["base_params"]) + base_rep = resource_rep(base_class, **op_node.params["base_params"]) return [ControlledBaseDecomposition(rule) for rule in self._get_decompositions(base_rep)] def _recursively_add_decomposition_node(