diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1df4c4769c9..2a13656ddd5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -91,32 +91,66 @@ qml.RZ(omega, wires=wires[0]) qml.GlobalPhase(-phase) + # This decomposition will be ignored for `QubitUnitary` on more than one wire. qml.add_decomps(QubitUnitary, zyz_decomposition) ``` - - This decomposition will be ignored for `QubitUnitary` on more than one wire. - -* The :func:`~.transforms.decompose` transform now supports symbolic operators (e.g., `Adjoint` and `Controlled`) specified as strings in the `gate_set` argument - when the new graph-based decomposition system is enabled. - [(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331) - ```python - from functools import partial - import pennylane as qml +* Symbolic operator types (e.g., `Adjoint`, `Controlled`, and `Pow`) can now be specified as strings + in various parts of the new graph-based decomposition system, specifically: + * The `gate_set` argument of the :func:`~.transforms.decompose` transform now supports adding symbolic + operators in the target gate set. + [(#7331)](https://github.com/PennyLaneAI/pennylane/pull/7331) + ```python + from functools import partial + import pennylane as qml - qml.decomposition.enable_graph() + qml.decomposition.enable_graph() - @partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"}) - @qml.qnode(qml.device("default.qubit")) - def circuit(): - qml.Toffoli(wires=[0, 1, 2]) - ``` - ```pycon - >>> print(qml.draw(circuit)()) - 0: ───────────╭●───────────╭●────╭●──T──╭●─┤ - 1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤ - 2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤ - ``` + @partial(qml.transforms.decompose, gate_set={"T", "Adjoint(T)", "H", "CNOT"}) + @qml.qnode(qml.device("default.qubit")) + def circuit(): + qml.Toffoli(wires=[0, 1, 2]) + ``` + ```pycon + >>> print(qml.draw(circuit)()) + 0: ───────────╭●───────────╭●────╭●──T──╭●─┤ + 1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤ + 2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤ + ``` + * Symbolic operator types can now be given as strings to the `op_type` argument of :func:`~.decomposition.add_decomps`, + or as keys of the dictionaries passed to the `alt_decomps` and `fixed_decomps` arguments of the + :func:`~.transforms.decompose` transform, allowing custom decomposition rules to be defined and + registered for symbolic operators. + [(#7347)](https://github.com/PennyLaneAI/pennylane/pull/7347) + ```python + @qml.register_resources({qml.RY: 1}) + def my_adjoint_ry(phi, wires, **_): + qml.RY(-phi, wires=wires) + + @qml.register_resources({qml.RX: 1}) + def my_adjoint_rx(phi, wires, **__): + qml.RX(-phi, wires) + + # Registers a decomposition rule for the adjoint of RY globally + qml.add_decomps("Adjoint(RY)", my_adjoint_ry) + + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY", "CNOT"}, + fixed_decomps={"Adjoint(RX)": my_adjoint_rx} + ) + @qml.qnode(qml.device("default.qubit")) + def circuit(): + qml.adjoint(qml.RX(0.5, wires=[0])) + qml.CNOT(wires=[0, 1]) + qml.adjoint(qml.RY(0.5, wires=[1])) + return qml.expval(qml.Z(0)) + ``` + ```pycon + >>> print(qml.draw(circuit)()) + 0: ──RX(-0.50)─╭●────────────┤ + 1: ────────────╰X──RY(-0.50)─┤ + ```

Improvements 🛠

diff --git a/pennylane/decomposition/decomposition_graph.py b/pennylane/decomposition/decomposition_graph.py index cdbbba327f7..3d668424cb6 100644 --- a/pennylane/decomposition/decomposition_graph.py +++ b/pennylane/decomposition/decomposition_graph.py @@ -44,11 +44,11 @@ from .resources import CompressedResourceOp, Resources, resource_rep from .symbolic_decomposition import ( AdjointDecomp, - adjoint_adjoint_decomp, adjoint_controlled_decomp, adjoint_pow_decomp, - pow_decomp, - pow_pow_decomp, + cancel_adjoint, + merge_powers, + repeat_pow_base, same_type_adjoint_decomp, same_type_adjoint_ops, ) @@ -136,25 +136,42 @@ def __init__( # Tracks the node indices of various operators. self._original_ops_indices: set[int] = set() - self._target_ops_indices: set[int] = set() self._all_op_indices: dict[CompressedResourceOp, int] = {} # Stores the library of custom decomposition rules - self._fixed_decomps = fixed_decomps or {} - self._alt_decomps = alt_decomps or {} + fixed_decomps = fixed_decomps or {} + alt_decomps = alt_decomps or {} + self._fixed_decomps = {_to_name(k): v for k, v in fixed_decomps.items()} + self._alt_decomps = {_to_name(k): v for k, v in alt_decomps.items()} # Initializes the graph. self._graph = rx.PyDiGraph() self._visitor = None # Construct the decomposition graph + self._start = self._graph.add_node(None) self._construct_graph(operations) - def _get_decompositions(self, op_type) -> list[DecompositionRule]: + def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]: """Helper function to get a list of decomposition rules.""" - if op_type in self._fixed_decomps: - return [self._fixed_decomps[op_type]] - return self._alt_decomps.get(op_type, []) + list_decomps(op_type) + + op_name = _to_name(op_node) + + if op_name in self._fixed_decomps: + return [self._fixed_decomps[op_name]] + + decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name) + + if issubclass(op_node.op_type, qml.ops.Adjoint): + decomps.extend(self._get_adjoint_decompositions(op_node)) + + elif issubclass(op_node.op_type, qml.ops.Pow): + decomps.extend(self._get_pow_decompositions(op_node)) + + elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): + decomps.extend(self._get_controlled_decompositions(op_node)) + + return decomps def _construct_graph(self, operations): """Constructs the decomposition graph.""" @@ -179,20 +196,10 @@ def _recursively_add_op_node(self, op_node: CompressedResourceOp) -> int: self._all_op_indices[op_node] = op_node_idx if op_node.name in self._gate_set: - self._target_ops_indices.add(op_node_idx) + self._graph.add_edge(self._start, op_node_idx, 1) return op_node_idx - if op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp): - # This branch only applies to general controlled operators - return self._add_controlled_decomp_node(op_node, op_node_idx) - - if issubclass(op_node.op_type, qml.ops.Adjoint): - return self._add_adjoint_decomp_node(op_node, op_node_idx) - - if issubclass(op_node.op_type, qml.ops.Pow): - return self._add_pow_decomp_node(op_node, op_node_idx) - - for decomposition in self._get_decompositions(op_node.op_type): + for decomposition in self._get_decompositions(op_node): self._add_decomp_rule_to_op(decomposition, op_node, op_node_idx) return op_node_idx @@ -208,58 +215,46 @@ def _add_decomp_rule_to_op( except DecompositionNotApplicable: pass # ignore decompositions that are not applicable to the given op params. - def _add_adjoint_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: - """Adds an adjoint decomposition node.""" + def _get_adjoint_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]: + """Retrieves a list of decomposition rules for an adjoint operator.""" - base_class, base_params = op_node.params["base_class"], op_node.params["base_params"] + base_class, base_params = (op_node.params["base_class"], op_node.params["base_params"]) if issubclass(base_class, qml.ops.Adjoint): - rule = adjoint_adjoint_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [cancel_adjoint] if ( issubclass(base_class, qml.ops.Pow) and base_params["base_class"] in same_type_adjoint_ops() ): - rule = adjoint_pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [adjoint_pow_decomp] if base_class in same_type_adjoint_ops(): - rule = same_type_adjoint_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [same_type_adjoint_decomp] if ( issubclass(base_class, qml.ops.Controlled) and base_params["base_class"] in same_type_adjoint_ops() ): - rule = adjoint_controlled_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [adjoint_controlled_decomp] - for base_decomposition in self._get_decompositions(base_class): - rule = AdjointDecomp(base_decomposition) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - - return op_node_idx + base_rep = resource_rep(base_class, **base_params) + return [AdjointDecomp(base_rule) for base_rule in self._get_decompositions(base_rep)] - def _add_pow_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: - """Adds a power decomposition node to the graph.""" + @staticmethod + def _get_pow_decompositions(op_node: CompressedResourceOp) -> list[DecompositionRule]: + """Retrieves a list of decomposition rules for a power operator.""" base_class = op_node.params["base_class"] if issubclass(base_class, qml.ops.Pow): - rule = pow_pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [merge_powers] - rule = pow_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [repeat_pow_base] - def _add_controlled_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int: + def _get_controlled_decompositions( + self, op_node: CompressedResourceOp + ) -> list[DecompositionRule]: """Adds a controlled decomposition node to the graph.""" base_class = op_node.params["base_class"] @@ -267,29 +262,20 @@ def _add_controlled_decomp_node(self, op_node: CompressedResourceOp, op_node_idx # Handle controlled global phase if base_class is qml.GlobalPhase: - rule = controlled_global_phase_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [controlled_global_phase_decomp] # Handle controlled-X gates if base_class is qml.X: - rule = controlled_x_decomp - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [controlled_x_decomp] # Handle custom controlled ops if (base_class, num_control_wires) in base_to_custom_ctrl_op(): custom_op_type = base_to_custom_ctrl_op()[(base_class, num_control_wires)] - rule = CustomControlledDecomposition(custom_op_type) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - return op_node_idx + return [CustomControlledDecomposition(custom_op_type)] # General case - for base_decomposition in self._get_decompositions(base_class): - rule = ControlledBaseDecomposition(base_decomposition) - self._add_decomp_rule_to_op(rule, op_node, op_node_idx) - - return op_node_idx + base_rep = resource_rep(base_class, **op_node.params["base_params"]) + return [ControlledBaseDecomposition(rule) for rule in self._get_decompositions(base_rep)] def _recursively_add_decomposition_node( self, rule: DecompositionRule, decomp_resource: Resources @@ -303,6 +289,11 @@ def _recursively_add_decomposition_node( d_node = _DecompositionNode(rule, decomp_resource) d_node_idx = self._graph.add_node(d_node) + if not decomp_resource.gate_counts: + # If an operator decomposes to nothing (e.g., a Hadamard raised to a + # power of 2), we must still connect something to this decomposition + # node so that it is accounted for. + self._graph.add_edge(self._start, d_node_idx, 0) for op in decomp_resource.gate_counts: op_node_idx = self._recursively_add_op_node(op) self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx)) @@ -318,17 +309,12 @@ def solve(self, lazy=True): """ self._visitor = _DecompositionSearchVisitor(self._graph, self._original_ops_indices, lazy) - start = self._graph.add_node("dummy") - self._graph.add_edges_from( - [(start, op_node_idx, 1) for op_node_idx in self._target_ops_indices] - ) rx.dijkstra_search( self._graph, - source=[start], + source=[self._start], weight_fn=self._visitor.edge_weight, visitor=self._visitor, ) - self._graph.remove_node(start) if self._visitor.unsolved_op_indices: unsolved_ops = [self._graph[op_idx] for op_idx in self._visitor.unsolved_op_indices] op_names = set(op.name for op in unsolved_ops) @@ -466,6 +452,8 @@ def examine_edge(self, edge): return # nothing is to be done for edges leading to an operator node if target_idx not in self.distances: self.distances[target_idx] = Resources() # initialize with empty resource + if src_node is None: + return # special case for when the decomposition produces nothing self.distances[target_idx] += self.distances[src_idx] * target_node.count(src_node) if target_idx not in self._num_edges_examined: self._num_edges_examined[target_idx] = 0 @@ -481,7 +469,7 @@ def edge_relaxed(self, edge): """Triggered when an edge is relaxed during the Dijkstra search.""" src_idx, target_idx, _ = edge target_node = self._graph[target_idx] - if self._graph[src_idx] == "dummy": + if self._graph[src_idx] is None and not isinstance(target_node, _DecompositionNode): self.distances[target_idx] = Resources({target_node: 1}) elif isinstance(target_node, CompressedResourceOp): self.predecessors[target_idx] = src_idx @@ -498,3 +486,12 @@ class _DecompositionNode: def count(self, op: CompressedResourceOp): """Find the number of occurrences of an operator in the decomposition.""" return self.decomp_resource.gate_counts.get(op, 0) + + +def _to_name(op): + if isinstance(op, type): + return op.__name__ + if isinstance(op, CompressedResourceOp): + return op.name + assert isinstance(op, str) + return translate_op_alias(op) diff --git a/pennylane/decomposition/decomposition_rule.py b/pennylane/decomposition/decomposition_rule.py index 9d1f79b24a3..9625c4973de 100644 --- a/pennylane/decomposition/decomposition_rule.py +++ b/pennylane/decomposition/decomposition_rule.py @@ -24,6 +24,7 @@ from pennylane.operation import Operator from .resources import CompressedResourceOp, Resources, resource_rep +from .utils import translate_op_alias @overload @@ -191,7 +192,7 @@ def __init__(self, func: Callable, resources: Callable | dict): self._source = inspect.getsource(func) if isinstance(resources, dict): - def resource_fn(): + def resource_fn(*_, **__): return resources self._compute_resources = resource_fn @@ -233,10 +234,10 @@ def _auto_wrap(op_type): _decompositions = defaultdict(list) -"""dict[type, list[DecompositionRule]]: A dictionary mapping operator types to decomposition rules.""" +"""dict[str, list[DecompositionRule]]: A dictionary mapping operator names to decomposition rules.""" -def add_decomps(op_type: Type[Operator], *decomps: DecompositionRule) -> None: +def add_decomps(op_type: Type[Operator] | str, *decomps: DecompositionRule) -> None: """Globally registers new decomposition rules with an operator class. .. note:: @@ -253,7 +254,8 @@ def add_decomps(op_type: Type[Operator], *decomps: DecompositionRule) -> None: decomposition rules that may be chosen if they lead to a more resource-efficient decomposition. Args: - op_type: the operator type for which new decomposition rules are specified. + op_type (type or str): the operator type for which new decomposition rules are specified. + For symbolic operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. decomps (DecompositionRule): new decomposition rules to add to the given ``op_type``. A decomposition is a quantum function registered with a resource estimate using ``qml.register_resources``. @@ -289,6 +291,17 @@ def my_hadamard2(wires): for the duration of the session. To add alternative decompositions for a particular circuit as opposed to globally, use the ``alt_decomps`` argument of the :func:`~pennylane.transforms.decompose` transform. + Custom decomposition rules can also be specified for symbolic operators. In this case, the + operator type can be specified as a string. For example, + + .. code-block:: python + + @register_resources({qml.RY: 1}) + def adjoint_ry(phi, wires, **_): + qml.RY(-phi, wires=wires) + + qml.add_decomps("Adjoint(RY)", adjoint_ry) + .. seealso:: :func:`~pennylane.transforms.decompose` """ @@ -297,10 +310,12 @@ def my_hadamard2(wires): "A decomposition rule must be a qfunc with a resource estimate " "registered using qml.register_resources" ) - _decompositions[op_type].extend(decomps) + if isinstance(op_type, type): + op_type = op_type.__name__ + _decompositions[translate_op_alias(op_type)].extend(decomps) -def list_decomps(op_type: Type[Operator]) -> list[DecompositionRule]: +def list_decomps(op_type: Type[Operator] | str) -> list[DecompositionRule]: """Lists all stored decomposition rules for an operator class. .. note:: @@ -311,7 +326,8 @@ def list_decomps(op_type: Type[Operator]) -> list[DecompositionRule]: decomposition rules for an operator. Args: - op_type: the operator class to retrieve decomposition rules for. + op_type (type or str): the operator class to retrieve decomposition rules for. For symbolic + operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. Returns: list[DecompositionRule]: a list of decomposition rules registered for the given operator. @@ -338,10 +354,12 @@ def _crx_to_rx_cz(phi, wires, **__): 1: ──RX(0.25)─╰Z──RX(-0.25)─╰Z─┤ """ - return _decompositions[op_type][:] + if isinstance(op_type, type): + op_type = op_type.__name__ + return _decompositions[translate_op_alias(op_type)][:] -def has_decomp(op_type: Type[Operator]) -> bool: +def has_decomp(op_type: Type[Operator] | str) -> bool: """Checks whether an operator has decomposition rules defined. .. note:: @@ -352,10 +370,14 @@ def has_decomp(op_type: Type[Operator]) -> bool: decomposition rules for an operator. Args: - op_type: the operator class to check for decomposition rules. + op_type (type or str): the operator class to check for decomposition rules. For symbolic + operators, use strings such as ``"Adjoint(RY)"``, ``"Pow(H)"``, ``"C(RX)"``, etc. Returns: bool: whether decomposition rules are defined for the given operator. """ + if isinstance(op_type, type): + op_type = op_type.__name__ + op_type = translate_op_alias(op_type) return op_type in _decompositions and len(_decompositions[op_type]) > 0 diff --git a/pennylane/decomposition/symbolic_decomposition.py b/pennylane/decomposition/symbolic_decomposition.py index e9d6ada4997..15f6fb94128 100644 --- a/pennylane/decomposition/symbolic_decomposition.py +++ b/pennylane/decomposition/symbolic_decomposition.py @@ -76,12 +76,11 @@ def _adjoint_adjoint_resource(*_, base_params, **__): return {resource_rep(base_class, **base_params): 1} +# pylint: disable=protected-access @register_resources(_adjoint_adjoint_resource) -def adjoint_adjoint_decomp(*params, wires, base): # pylint: disable=unused-argument +def cancel_adjoint(*params, wires, base): # pylint: disable=unused-argument """Decompose the adjoint of the adjoint of a gate.""" - _, [_, metadata] = base.base._flatten() # pylint: disable=protected-access - new_struct = wires, metadata - base.base._unflatten(params, new_struct) # pylint: disable=protected-access + base.base._unflatten(*base.base._flatten()) def _adjoint_controlled_resource(base_class, base_params): @@ -157,7 +156,7 @@ def _pow_resource(base_class, base_params, z): @register_resources(_pow_resource) -def pow_decomp(*_, base, z, **__): +def repeat_pow_base(*_, base, z, **__): """Decompose the power of a gate.""" assert isinstance(z, int) and z >= 0 for _ in range(z): @@ -175,7 +174,7 @@ def _pow_pow_resource(base_class, base_params, z): # pylint: disable=unused-arg @register_resources(_pow_pow_resource) -def pow_pow_decomp(*_, base, z, **__): +def merge_powers(*_, base, z, **__): """Decompose the power of the power of a gate.""" qml.pow(base.base, z=z * base.z) diff --git a/pennylane/decomposition/utils.py b/pennylane/decomposition/utils.py index 3203ee217d2..b911aacd90f 100644 --- a/pennylane/decomposition/utils.py +++ b/pennylane/decomposition/utils.py @@ -41,7 +41,7 @@ def translate_op_alias(op_alias): """Translates an operator alias to its proper name.""" if op_alias in OP_NAME_ALIASES: return OP_NAME_ALIASES[op_alias] - if match := re.match(r"C\((\w+)\)", op_alias): + if match := re.match(r"(?:C|Controlled)\((\w+)\)", op_alias): base_op_name = match.group(1) return f"C({translate_op_alias(base_op_name)})" if match := re.match(r"Adjoint\((\w+)\)", op_alias): diff --git a/tests/decomposition/conftest.py b/tests/decomposition/conftest.py index b148d292144..f57a0267eaf 100644 --- a/tests/decomposition/conftest.py +++ b/tests/decomposition/conftest.py @@ -35,7 +35,7 @@ def _cz_to_cnot(*_, **__): raise NotImplementedError -decompositions[qml.CZ] = [_cz_to_cnot] +decompositions["CZ"] = [_cz_to_cnot] @qml.register_resources({qml.Hadamard: 2, qml.CZ: 1}) @@ -43,7 +43,7 @@ def _cnot_to_cz(*_, **__): raise NotImplementedError -decompositions[qml.CNOT] = [_cnot_to_cz] +decompositions["CNOT"] = [_cnot_to_cz] def _multi_rz_decomposition_resources(num_wires): @@ -55,7 +55,7 @@ def _multi_rz_decomposition(*_, **__): raise NotImplementedError -decompositions[qml.MultiRZ] = [_multi_rz_decomposition] +decompositions["MultiRZ"] = [_multi_rz_decomposition] @qml.register_resources({qml.RZ: 2, qml.RX: 1, qml.GlobalPhase: 1}) @@ -68,7 +68,7 @@ def _hadamard_to_rz_ry(*_, **__): raise NotImplementedError -decompositions[qml.Hadamard] = [_hadamard_to_rz_rx, _hadamard_to_rz_ry] +decompositions["Hadamard"] = [_hadamard_to_rz_rx, _hadamard_to_rz_ry] @qml.register_resources({qml.RX: 1, qml.RZ: 2}) @@ -76,7 +76,7 @@ def _ry_to_rx_rz(*_, **__): raise NotImplementedError -decompositions[qml.RY] = [_ry_to_rx_rz] +decompositions["RY"] = [_ry_to_rx_rz] @qml.register_resources({qml.RX: 2, qml.CZ: 2}) @@ -84,7 +84,7 @@ def _crx_to_rx_cz(*_, **__): raise NotImplementedError -decompositions[qml.CRX] = [_crx_to_rx_cz] +decompositions["CRX"] = [_crx_to_rx_cz] @qml.register_resources({qml.RZ: 3, qml.CNOT: 2, qml.GlobalPhase: 1}) @@ -92,7 +92,7 @@ def _cphase_to_rz_cnot(*_, **__): raise NotImplementedError -decompositions[qml.ControlledPhaseShift] = [_cphase_to_rz_cnot] +decompositions["ControlledPhaseShift"] = [_cphase_to_rz_cnot] @qml.register_resources({qml.RZ: 1, qml.GlobalPhase: 1}) @@ -100,7 +100,7 @@ def _phase_shift_to_rz_gp(*_, **__): raise NotImplementedError -decompositions[qml.PhaseShift] = [_phase_shift_to_rz_gp] +decompositions["PhaseShift"] = [_phase_shift_to_rz_gp] @qml.register_resources({qml.RX: 1, qml.GlobalPhase: 1}) @@ -108,7 +108,7 @@ def _x_to_rx(*_, **__): raise NotImplementedError -decompositions[qml.X] = [_x_to_rx] +decompositions["PauliX"] = [_x_to_rx] @qml.register_resources({qml.PhaseShift: 1}) @@ -116,7 +116,7 @@ def _u1_ps(phi, wires, **__): qml.PhaseShift(phi, wires=wires) -decompositions[qml.U1] = [_u1_ps] +decompositions["U1"] = [_u1_ps] @qml.register_resources({qml.PhaseShift: 1}) @@ -124,4 +124,24 @@ def _t_ps(wires, **__): raise NotImplementedError -decompositions[qml.T] = [_t_ps] +decompositions["T"] = [_t_ps] + + +@qml.register_resources({qml.H: 1}) +def _adjoint_hadamard(*_, wires, **__): + qml.H(wires) + + +decompositions["Adjoint(Hadamard)"] = [_adjoint_hadamard] + + +def _pow_hadamard_resource(z, **__): + return {qml.H: z % 2} + + +@qml.register_resources(_pow_hadamard_resource) +def _pow_hadamard(*_, wires, z, **__): + qml.cond(z % 2 == 1, qml.H)(wires=wires) + + +decompositions["Pow(Hadamard)"] = [_pow_hadamard] diff --git a/tests/decomposition/test_decomp_utils.py b/tests/decomposition/test_decomp_utils.py index 22ef50a4b2d..48b02d7a908 100644 --- a/tests/decomposition/test_decomp_utils.py +++ b/tests/decomposition/test_decomp_utils.py @@ -63,5 +63,6 @@ def test_translate_op_alias(base_op_alias, expected_op_name): assert translate_op_alias(base_op_alias) == expected_op_name assert translate_op_alias(f"C({base_op_alias})") == f"C({expected_op_name})" + assert translate_op_alias(f"Controlled({base_op_alias})") == f"C({expected_op_name})" assert translate_op_alias(f"Adjoint({base_op_alias})") == f"Adjoint({expected_op_name})" assert translate_op_alias(f"Pow({base_op_alias})") == f"Pow({expected_op_name})" diff --git a/tests/decomposition/test_decomposition_graph.py b/tests/decomposition/test_decomposition_graph.py index 67c777464df..956c17645ca 100644 --- a/tests/decomposition/test_decomposition_graph.py +++ b/tests/decomposition/test_decomposition_graph.py @@ -30,13 +30,15 @@ adjoint_resource_rep, controlled_resource_rep, pow_resource_rep, + resource_rep, ) +from pennylane.decomposition.decomposition_graph import _to_name @pytest.mark.unit @patch( "pennylane.decomposition.decomposition_graph.list_decomps", - side_effect=lambda x: decompositions[x], + side_effect=lambda x: decompositions[_to_name(x)], ) class TestDecompositionGraph: @@ -55,14 +57,14 @@ def custom_hadamard_2(wires): qml.RY(np.pi / 2, wires=wires) graph = DecompositionGraph(operations=[qml.Hadamard(0)], gate_set={"RX", "RY", "RZ"}) - assert graph._get_decompositions(qml.Hadamard) == decompositions[qml.Hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == decompositions["Hadamard"] graph = DecompositionGraph( operations=[qml.Hadamard(0)], gate_set={"RX", "RY", "RZ"}, fixed_decomps={qml.Hadamard: custom_hadamard}, ) - assert graph._get_decompositions(qml.Hadamard) == [custom_hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == [custom_hadamard] alt_dec = [custom_hadamard, custom_hadamard_2] graph = DecompositionGraph( @@ -70,8 +72,8 @@ def custom_hadamard_2(wires): gate_set={"RX", "RY", "RZ"}, alt_decomps={qml.Hadamard: alt_dec}, ) - exp_dec = alt_dec + decompositions[qml.Hadamard] - assert graph._get_decompositions(qml.Hadamard) == exp_dec + exp_dec = alt_dec + decompositions["Hadamard"] + assert graph._get_decompositions(resource_rep(qml.H)) == exp_dec graph = DecompositionGraph( operations=[qml.Hadamard(0)], @@ -79,24 +81,26 @@ def custom_hadamard_2(wires): alt_decomps={qml.Hadamard: alt_dec}, fixed_decomps={qml.Hadamard: custom_hadamard}, ) - assert graph._get_decompositions(qml.Hadamard) == [custom_hadamard] + assert graph._get_decompositions(resource_rep(qml.H)) == [custom_hadamard] def test_graph_construction(self, _): """Tests constructing a graph from a single Hadamard.""" op = qml.Hadamard(wires=[0]) graph = DecompositionGraph(operations=[op], gate_set={"RX", "RZ", "GlobalPhase"}) - # 5 ops and 3 decompositions (2 for Hadamard and 1 for RY) - assert len(graph._graph.nodes()) == 8 - # 8 edges from ops to decompositions and 3 from decompositions to ops - assert len(graph._graph.edges()) == 11 + # 5 ops and 3 decompositions (2 for Hadamard and 1 for RY) and 1 dummy starting node + assert len(graph._graph.nodes()) == 9 + # 8 edges from ops to decompositions, 3 from decompositions to ops, and 3 from the + # dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 14 # Check that graph construction stops at gates in the target gate set. graph2 = DecompositionGraph(operations=[op], gate_set={"RY", "RZ", "GlobalPhase"}) - # 5 ops and 2 decompositions (RY is in the target gate set now) - assert len(graph2._graph.nodes()) == 7 - # 6 edges from ops to decompositions and 2 from decompositions to ops - assert len(graph2._graph.edges()) == 8 + # 5 ops and 2 decompositions (RY is in the target gate set now), and the dummy starting node + assert len(graph2._graph.nodes()) == 8 + # 6 edges from ops to decompositions and 2 from decompositions to ops, + # and 3 from the dummy starting node to the target gate set. + assert len(graph2._graph.edges()) == 11 def test_graph_construction_non_applicable_rules(self, _): """Tests rules that raise DecompositionNotApplicable are skipped.""" @@ -133,10 +137,12 @@ def some_other_rule(*_, **__): gate_set={"CNOT", "RZ"}, alt_decomps={CustomOp: [some_rule, some_other_rule]}, ) - # 3 ops (CustomOp, CNOT, RZ) and 1 decompositions (only some_other_rule) - assert len(graph._graph.nodes()) == 4 - # 2 edges from ops to decompositions and 1 from decompositions to ops - assert len(graph._graph.edges()) == 3 + # 3 ops (CustomOp, CNOT, RZ) and 1 decompositions (only some_other_rule), + # and the dummy starting node + assert len(graph._graph.nodes()) == 5 + # 2 edges from ops to decompositions, 1 from decompositions to ops, + # and 2 from the dummy starting node to the target gate set + assert len(graph._graph.edges()) == 5 def test_gate_set(self, _): """Tests that graph construction stops at the target gate set.""" @@ -173,10 +179,12 @@ def custom_decomp(wires): fixed_decomps={CustomOp: custom_decomp}, ) - # 1 node for CustomOp, 1 decomposition node, and 5 for the ops in the decomposition - assert len(graph._graph.nodes()) == 7 - # 5 edges from ops to decompositions and 1 edge from decompositions to ops - assert len(graph._graph.edges()) == 6 + # 1 node for CustomOp, 1 decomposition node, 5 for the ops in the decomposition, + # and the dummy starting node. + assert len(graph._graph.nodes()) == 8 + # 5 edges from ops to decompositions, 1 edge from decompositions to ops, and 5 + # edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 11 def test_graph_solve(self, _): """Tests solving a simple graph for the optimal decompositions.""" @@ -284,9 +292,11 @@ def _custom_decomp(*_, **__): ) # 10 ops (CustomOp, MultiRZ(4), MultiRZ(3), CNOT, CZ, RX, RY, RZ, Hadamard, GlobalPhase) # 7 decompositions (1 for CustomOp, 1 for each of the two MultiRZs, 1 for CNOT, 2 for Hadamard, and 1 for RY) - assert len(graph._graph.nodes()) == 17 - # 16 edges from ops to decompositions and 7 from decompositions to ops - assert len(graph._graph.edges()) == 23 + # and the dummy starting node + assert len(graph._graph.nodes()) == 18 + # 16 edges from ops to decompositions and 7 from decompositions to ops, + # and 4 edges from the dummy starting node to the target gate set + assert len(graph._graph.edges()) == 27 graph.solve() assert graph.resource_estimate(op) == to_resources( @@ -317,10 +327,11 @@ def test_controlled_global_phase(self, _): op1 = qml.ctrl(qml.GlobalPhase(0.5), control=[1]) op2 = qml.ctrl(qml.GlobalPhase(0.5), control=[1, 2]) graph = DecompositionGraph([op1, op2], gate_set={"ControlledPhaseShift", "PhaseShift"}) - # 4 op nodes and 2 decomposition nodes. - assert len(graph._graph.nodes()) == 6 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 4 op nodes and 2 decomposition nodes, and 1 dummy starting node. + assert len(graph._graph.nodes()) == 7 + # 2 edges from decompositions to ops and 2 edges from ops to decompositions, + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 6 # Verify the decompositions graph.solve() @@ -342,10 +353,11 @@ def test_custom_controlled_op(self, _): operations=[op1, op2], gate_set={"CNOT", "CH"}, ) - # 4 op nodes and 2 decomposition nodes. - assert len(graph._graph.nodes()) == 6 + # 4 op nodes and 2 decomposition nodes, and the dummy starting node + assert len(graph._graph.nodes()) == 7 # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 6 # Verify the decompositions graph.solve() @@ -413,10 +425,11 @@ def custom_controlled_decomp(wires): CustomControlledOp: [custom_controlled_decomp], }, ) - # 18 op nodes and 16 decomposition nodes. - assert len(graph._graph.nodes()) == 34 + # 18 op nodes and 16 decomposition nodes, and the dummy starting node + assert len(graph._graph.nodes()) == 35 # 16 edges from decompositions to ops and 36 edges from ops to decompositions - assert len(graph._graph.edges()) == 52 + # and 6 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 58 graph.solve() @@ -443,8 +456,9 @@ def test_adjoint_adjoint(self, _): graph = DecompositionGraph(operations=[op], gate_set={"RX"}) # 2 operator nodes (Adjoint(Adjoint(RX)) and RX), and 1 decomposition node. - assert len(graph._graph.nodes()) == 3 - assert len(graph._graph.edges()) == 2 + # and the dummy starting node + assert len(graph._graph.nodes()) == 4 + assert len(graph._graph.edges()) == 3 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -459,19 +473,19 @@ def test_adjoint_pow(self, _): op = qml.adjoint(qml.pow(qml.H(0), z=3)) graph = DecompositionGraph(operations=[op], gate_set={"H"}) - # 3 operator nodes: Adjoint(Pow(H)), Pow(H), and H - # 2 decomposition nodes for Adjoint(Pow(H)) and Pow(H) - assert len(graph._graph.nodes()) == 5 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 3 operator nodes: Adjoint(Pow(H)), Pow(H), and H, and 1 dummy starting node + # 1 decomposition nodes for Adjoint(Pow(H)) and two decomposition nodes for Pow(H) + assert len(graph._graph.nodes()) == 7 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions + # and 1 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 7 graph.solve() with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op)(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), z=3)] - # TODO: There should just be a single `H` after we have full support of Pow decompositions. - assert graph.resource_estimate(op) == to_resources({qml.H: 3}) + assert graph.resource_estimate(op) == to_resources({qml.H: 1}) def test_adjoint_custom(self, _): """Tests adjoint of an operator that defines its own adjoint.""" @@ -479,9 +493,9 @@ def test_adjoint_custom(self, _): op = qml.adjoint(qml.RX(0.5, wires=[0])) graph = DecompositionGraph(operations=[op], gate_set={"RX"}) - # 2 operator nodes (Adjoint(RX) and RX), and 1 decomposition node. - assert len(graph._graph.nodes()) == 3 - assert len(graph._graph.edges()) == 2 + # 2 operator nodes (Adjoint(RX) and RX), and 1 decomposition node, and 1 dummy starting node + assert len(graph._graph.nodes()) == 4 + assert len(graph._graph.edges()) == 3 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -499,9 +513,11 @@ def test_adjoint_controlled(self, _): graph = DecompositionGraph(operations=[op, op2], gate_set={"ControlledPhaseShift", "CRX"}) # 5 operator nodes: Adjoint(C(RX)), Adjoint(C(U1)), CRX, C(U1), ControlledPhaseShift # 3 decomposition nodes leading into Adjoint(C(RX)), Adjoint(C(U1)), and C(U1) - assert len(graph._graph.nodes()) == 8 - # 3 edges from decompositions to ops and 3 edges from ops to decompositions - assert len(graph._graph.edges()) == 6 + # and the dummy starting node + assert len(graph._graph.nodes()) == 9 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions, + # and 2 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 8 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -544,10 +560,13 @@ def custom_decomp(phi, wires): alt_decomps={CustomOp: [custom_decomp]}, ) # 10 operator nodes: A(CustomOp), A(H), A(CNOT), A(RX), A(T), H, CNOT, RX, A(PhaseShift), PhaseShift - # 6 decomposition nodes for: A(CustomOp), A(H), A(CNOT), A(RX), A(T), A(PhaseShift) - assert len(graph._graph.nodes()) == 16 - # 9 edges from ops to decompositions and 6 edges from decompositions to ops. - assert len(graph._graph.edges()) == 15 + # 5 decomposition nodes for: A(CustomOp), A(CNOT), A(RX), A(T), A(PhaseShift) + # 2 decomposition nodes for A(H) + # 1 dummy starting node + assert len(graph._graph.nodes()) == 18 + # 10 edges from ops to decompositions and 7 edges from decompositions to ops. + # and 4 edges from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 21 graph.solve() with qml.queuing.AnnotatedQueue() as q: @@ -564,27 +583,65 @@ def custom_decomp(phi, wires): {qml.H: 1, qml.CNOT: 2, qml.RX: 1, qml.PhaseShift: 1} ) - def test_pow_pow(self, _): + def test_nested_powers(self, _): """Tests nested power decompositions.""" op = qml.pow(qml.pow(qml.H(0), 3), 2) graph = DecompositionGraph(operations=[op], gate_set={"H"}) # 3 operator nodes: Pow(Pow(H)), Pow(H), and H - # 2 decomposition nodes for Pow(Pow(H)) and Pow(H) - assert len(graph._graph.nodes()) == 5 - # 2 edges from decompositions to ops and 2 edges from ops to decompositions - assert len(graph._graph.edges()) == 4 + # 1 decomposition nodes for Pow(Pow(H)) and 2 nodes for Pow(H) + # and the dummy starting node + assert len(graph._graph.nodes()) == 7 + # 3 edges from decompositions to ops and 3 edges from ops to decompositions + # and 1 edge from the dummy starting node to the target gate set. + assert len(graph._graph.edges()) == 7 graph.solve() with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op)(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), 6)] - assert graph.resource_estimate(op) == to_resources({qml.H: 6}) + assert graph.resource_estimate(op) == to_resources({}) op2 = qml.pow(qml.H(0), 6) with qml.queuing.AnnotatedQueue() as q: graph.decomposition(op2)(*op2.parameters, wires=op2.wires, **op2.hyperparameters) - assert q.queue == [qml.H(0), qml.H(0), qml.H(0), qml.H(0), qml.H(0), qml.H(0)] - assert graph.resource_estimate(op2) == to_resources({qml.H: 6}) + assert q.queue == [] + assert graph.resource_estimate(op2) == to_resources({}) + + def test_custom_symbolic_decompositions(self, _): + """Tests that custom symbolic decompositions are used.""" + + @qml.register_resources({qml.RX: 1}) + def my_adjoint_rx(theta, wires, **__): + qml.RX(-theta, wires=wires) + + graph = DecompositionGraph( + operations=[ + qml.adjoint(qml.H(0)), + qml.pow(qml.H(1), 3), + qml.ops.Controlled(qml.H(0), control_wires=1), + qml.adjoint(qml.RX(0.5, wires=0)), + ], + fixed_decomps={"Adjoint(RX)": my_adjoint_rx}, + gate_set={"H", "CH", "RX"}, + ) + + op1 = qml.adjoint(qml.H(0)) + op2 = qml.pow(qml.H(1), 3) + op3 = qml.ops.Controlled(qml.H(0), control_wires=1) + op4 = qml.adjoint(qml.RX(0.5, wires=0)) + + graph.solve() + with qml.queuing.AnnotatedQueue() as q: + graph.decomposition(op1)(*op1.parameters, wires=op1.wires, **op1.hyperparameters) + graph.decomposition(op2)(*op2.parameters, wires=op2.wires, **op2.hyperparameters) + graph.decomposition(op3)(*op3.parameters, wires=op3.wires, **op3.hyperparameters) + graph.decomposition(op4)(*op4.parameters, wires=op4.wires, **op4.hyperparameters) + + assert q.queue == [qml.H(0), qml.H(1), qml.CH(wires=[1, 0]), qml.RX(-0.5, wires=0)] + assert graph.resource_estimate(op1) == to_resources({qml.H: 1}) + assert graph.resource_estimate(op2) == to_resources({qml.H: 1}) + assert graph.resource_estimate(op3) == to_resources({qml.CH: 1}) + assert graph.resource_estimate(op4) == to_resources({qml.RX: 1}) diff --git a/tests/decomposition/test_decomposition_rule.py b/tests/decomposition/test_decomposition_rule.py index 22b251fe097..94f67064294 100644 --- a/tests/decomposition/test_decomposition_rule.py +++ b/tests/decomposition/test_decomposition_rule.py @@ -177,7 +177,19 @@ def custom_decomp4(theta, wires, **__): with pytest.raises(TypeError, match="decomposition rule must be a qfunc with a resource"): qml.add_decomps(CustomOp, custom_decomp4) - _decompositions.pop(CustomOp) # cleanup + _decompositions.pop("CustomOp") # cleanup + + def test_custom_symbolic_decomposition(self): + """Tests that custom decomposition rules for symbolic operators can be registered.""" + + @qml.register_resources({qml.RX: 1, qml.RZ: 1}) + def my_adjoint_custom_op(theta, wires, **__): + qml.RX(theta, wires=wires[0]) + qml.RZ(theta, wires=wires[1]) + + qml.add_decomps("Adjoint(CustomOp)", my_adjoint_custom_op) + assert qml.decomposition.has_decomp("Adjoint(CustomOp)") + assert qml.list_decomps("Adjoint(CustomOp)") == [my_adjoint_custom_op] def test_auto_wrap_in_resource_op(self): """Tests that simply classes can be auto-wrapped in a ``CompressionResourceOp``.""" diff --git a/tests/decomposition/test_symbolic_decomposition.py b/tests/decomposition/test_symbolic_decomposition.py index 54ceea53cc6..b704fb0a6f1 100644 --- a/tests/decomposition/test_symbolic_decomposition.py +++ b/tests/decomposition/test_symbolic_decomposition.py @@ -20,11 +20,11 @@ from pennylane.decomposition.resources import Resources, pow_resource_rep from pennylane.decomposition.symbolic_decomposition import ( AdjointDecomp, - adjoint_adjoint_decomp, adjoint_controlled_decomp, adjoint_pow_decomp, - pow_decomp, - pow_pow_decomp, + cancel_adjoint, + merge_powers, + repeat_pow_base, same_type_adjoint_decomp, ) from tests.decomposition.conftest import to_resources @@ -39,12 +39,10 @@ def test_adjoint_adjoint(self): op = qml.adjoint(qml.adjoint(qml.RX(0.5, wires=0))) with qml.queuing.AnnotatedQueue() as q: - adjoint_adjoint_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + cancel_adjoint(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.RX(0.5, wires=0)] - assert adjoint_adjoint_decomp.compute_resources(**op.resource_params) == to_resources( - {qml.RX: 1} - ) + assert cancel_adjoint.compute_resources(**op.resource_params) == to_resources({qml.RX: 1}) @pytest.mark.jax def test_adjoint_adjoint_capture(self): @@ -58,7 +56,7 @@ def test_adjoint_adjoint_capture(self): qml.capture.enable() def circuit(): - adjoint_adjoint_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + cancel_adjoint(*op.parameters, wires=op.wires, **op.hyperparameters) plxpr = qml.capture.make_plxpr(circuit)() collector = CollectOpsandMeas() @@ -207,10 +205,10 @@ def test_pow_pow(self): op = qml.pow(qml.pow(qml.H(0), 3), 2) with qml.queuing.AnnotatedQueue() as q: - pow_pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + merge_powers(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.pow(qml.H(0), 6)] - assert pow_pow_decomp.compute_resources(**op.resource_params) == to_resources( + assert merge_powers.compute_resources(**op.resource_params) == to_resources( {pow_resource_rep(qml.H, {}, 6): 1} ) @@ -219,10 +217,10 @@ def test_pow_general(self): op = qml.pow(qml.H(0), 3) with qml.queuing.AnnotatedQueue() as q: - pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + repeat_pow_base(*op.parameters, wires=op.wires, **op.hyperparameters) assert q.queue == [qml.H(0), qml.H(0), qml.H(0)] - assert pow_decomp.compute_resources(**op.resource_params) == to_resources({qml.H: 3}) + assert repeat_pow_base.compute_resources(**op.resource_params) == to_resources({qml.H: 3}) @pytest.mark.jax def test_pow_general_capture(self): @@ -236,7 +234,7 @@ def test_pow_general_capture(self): qml.capture.enable() def circuit(): - pow_decomp(*op.parameters, wires=op.wires, **op.hyperparameters) + repeat_pow_base(*op.parameters, wires=op.wires, **op.hyperparameters) plxpr = qml.capture.make_plxpr(circuit)() collector = CollectOpsandMeas() @@ -251,7 +249,7 @@ def test_non_integer_pow_not_implemented(self): op = qml.pow(qml.H(0), 0.5) with pytest.raises(NotImplementedError, match="Non-integer or negative powers"): - pow_decomp.compute_resources(**op.resource_params) + repeat_pow_base.compute_resources(**op.resource_params) op = qml.pow(qml.H(0), -1) with pytest.raises(NotImplementedError, match="Non-integer or negative powers"): - pow_decomp.compute_resources(**op.resource_params) + repeat_pow_base.compute_resources(**op.resource_params)