From 5cb88063b9cbb4eeeccd7bf875006d12030dbb1a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 6 May 2025 14:56:33 +0100 Subject: [PATCH] add PrepSelPrep to new decomposition framework --- pennylane/decomposition/resources.py | 2 +- pennylane/ops/op_math/composite.py | 1 + pennylane/ops/op_math/prod.py | 22 ++++++++++ pennylane/ops/qubit/state_preparation.py | 18 ++++++++ pennylane/templates/embeddings/amplitude.py | 7 +-- .../templates/state_preparations/mottonen.py | 19 ++++++++ .../templates/subroutines/prepselprep.py | 31 +++++++++++++ pennylane/templates/subroutines/select.py | 44 +++++++++++++++++-- tests/ops/op_math/test_prod.py | 41 +++++++++++++++++ tests/ops/qubit/test_state_prep.py | 27 ++++++++++++ 10 files changed, 202 insertions(+), 10 deletions(-) diff --git a/pennylane/decomposition/resources.py b/pennylane/decomposition/resources.py index d4f71c06aaa..652e4c68370 100644 --- a/pennylane/decomposition/resources.py +++ b/pennylane/decomposition/resources.py @@ -160,7 +160,7 @@ def _validate_resource_rep(op_type, params): if not issubclass(op_type, qml.operation.Operator): raise TypeError(f"op_type must be a type of Operator, got {op_type}") - if not isinstance(op_type.resource_keys, set): + if not isinstance(op_type.resource_keys, (set, frozenset)): raise TypeError( f"{op_type.__name__}.resource_keys must be a set, not a {type(op_type.resource_keys)}" ) diff --git a/pennylane/ops/op_math/composite.py b/pennylane/ops/op_math/composite.py index 1cdfdb379e9..c26cda61614 100644 --- a/pennylane/ops/op_math/composite.py +++ b/pennylane/ops/op_math/composite.py @@ -88,6 +88,7 @@ def __init__( self._pauli_rep = self._build_pauli_rep() if _pauli_rep is None else _pauli_rep self.queue() self._batch_size = _UNSET_BATCH_SIZE + self._hyperparameters = {"operands": operands} @handle_recursion_error def _check_batching(self): diff --git a/pennylane/ops/op_math/prod.py b/pennylane/ops/op_math/prod.py index 0108408aec7..532d53b60a3 100644 --- a/pennylane/ops/op_math/prod.py +++ b/pennylane/ops/op_math/prod.py @@ -16,6 +16,7 @@ computing the product between operations. """ import itertools +from collections import Counter from copy import copy from functools import reduce, wraps from itertools import combinations @@ -230,6 +231,13 @@ def circuit(weights): """ + resource_keys = frozenset({"resources"}) + + @property + def resource_params(self): + resources = Counter(qml.resource_rep(type(op), **op.resource_params) for op in self) + return {"resources": resources} + _op_symbol = "@" _math_op = math.prod grad_method = None @@ -467,6 +475,20 @@ def terms(self): return coeffs, ops +def _prod_resources(resources): + return resources + + +# pylint: disable=unused-argument +@qml.register_resources(_prod_resources) +def _prod_decomp(*_, wires=None, operands): + for op in operands: + op._unflatten(*op._flatten()) # pylint: disable=protected-access + + +qml.add_decomps(Prod, _prod_decomp) + + def _swappable_ops(op1, op2, wire_map: dict = None) -> bool: """Boolean expression that indicates if op1 and op2 don't have intersecting wires and if they should be swapped when sorting them by wire values. diff --git a/pennylane/ops/qubit/state_preparation.py b/pennylane/ops/qubit/state_preparation.py index d0268afb2f2..6d301a500dd 100644 --- a/pennylane/ops/qubit/state_preparation.py +++ b/pennylane/ops/qubit/state_preparation.py @@ -337,6 +337,12 @@ def circuit(state=None): """ + resource_keys = frozenset({"num_wires"}) + + @property + def resource_params(self): + return {"num_wires": len(self.wires)} + num_params = 1 """int: Number of trainable parameters that the operator depends on.""" @@ -574,6 +580,18 @@ def _preprocess_csr(state, wires, pad_with, normalize, validate_norm): return state +def _stateprep_resources(num_wires): + return {qml.resource_rep(qml.MottonenStatePreparation, num_wires=num_wires): 1} + + +@register_resources(_stateprep_resources) +def _state_prep_decomp(state, wires, **_): + qml.MottonenStatePreparation(state, wires) + + +add_decomps(StatePrep, _state_prep_decomp) + + class QubitDensityMatrix(Operation): r"""QubitDensityMatrix(state, wires) Prepare subsystems using the given density matrix. diff --git a/pennylane/templates/embeddings/amplitude.py b/pennylane/templates/embeddings/amplitude.py index ea940fdfdd0..d2931afb067 100644 --- a/pennylane/templates/embeddings/amplitude.py +++ b/pennylane/templates/embeddings/amplitude.py @@ -14,13 +14,10 @@ r""" Contains the AmplitudeEmbedding template. """ -# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access from pennylane.ops import StatePrep -# tolerance for normalization -TOLERANCE = 1e-10 - +# pylint: disable=too-many-arguments class AmplitudeEmbedding(StatePrep): r"""Encodes :math:`2^n` features into the amplitude vector of :math:`n` qubits. @@ -108,7 +105,7 @@ def circuit(f=None): """ def __init__( - self, features, wires, pad_with=None, normalize=False, id=None, validate_norm=True + self, features, wires, *, pad_with=None, normalize=False, id=None, validate_norm=True ): super().__init__( features, diff --git a/pennylane/templates/state_preparations/mottonen.py b/pennylane/templates/state_preparations/mottonen.py index 82c5e114768..468a7ab8736 100644 --- a/pennylane/templates/state_preparations/mottonen.py +++ b/pennylane/templates/state_preparations/mottonen.py @@ -286,6 +286,12 @@ def circuit(state): """ + resource_keys = frozenset({"num_wires"}) + + @property + def resource_params(self): + return {"num_wires": len(self.wires)} + grad_method = None ndim_params = (1,) @@ -390,3 +396,16 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif op_list.extend([qml.GlobalPhase(global_phase, wires=wires)]) return op_list + + +def _mottonen_resources(num_wires): + n = sum(2**i for i in range(num_wires)) + + return {qml.GlobalPhase: 1, qml.RY: n, qml.RZ: n, qml.CNOT: 2 * (n - 1)} + + +mottonen_decomp = qml.register_resources( + _mottonen_resources, MottonenStatePreparation.compute_decomposition +) + +qml.add_decomps(MottonenStatePreparation, mottonen_decomp) diff --git a/pennylane/templates/subroutines/prepselprep.py b/pennylane/templates/subroutines/prepselprep.py index bb100bfeab2..0419da1a87b 100644 --- a/pennylane/templates/subroutines/prepselprep.py +++ b/pennylane/templates/subroutines/prepselprep.py @@ -71,6 +71,14 @@ class PrepSelPrep(Operation): [ 0.75 0.25]] """ + resource_keys = frozenset({"num_control", "op_reps"}) + + @property + def resource_params(self): + ops = self.lcu.terms()[1] + op_reps = tuple(qml.resource_rep(type(op), **op.resource_params) for op in ops) + return {"op_reps": op_reps, "num_control": len(self.control)} + grad_method = None def __init__(self, lcu, control=None, id=None): @@ -203,3 +211,26 @@ def target_wires(self): def wires(self): """All wires involved in the operation.""" return self.hyperparameters["control"] + self.hyperparameters["target_wires"] + + +def _prepselprep_resources(op_reps, num_control): + return { + qml.resource_rep(qml.Select, op_reps=op_reps, num_control_wires=num_control): 1, + qml.resource_rep(qml.StatePrep, num_wires=num_control): 1, + qml.resource_rep( + qml.ops.Adjoint, base_class=qml.StatePrep, base_params={"num_wires": num_control} + ): 1, + } + + +# pylint: disable=unused-argument +@qml.register_resources(_prepselprep_resources) +def _prepselprep_decomp(*_, wires, lcu, coeffs, ops, control, target_wires): + coeffs, ops = _get_new_terms(lcu) + sqrt_coeffs = qml.math.sqrt(coeffs) + qml.StatePrep(sqrt_coeffs, normalize=True, pad_with=0, wires=control) + qml.Select(ops, control) + qml.adjoint(qml.StatePrep(sqrt_coeffs, normalize=True, pad_with=0, wires=control)) + + +qml.add_decomps(PrepSelPrep, _prepselprep_decomp) diff --git a/pennylane/templates/subroutines/select.py b/pennylane/templates/subroutines/select.py index 66a54f09033..5d6d2522a37 100644 --- a/pennylane/templates/subroutines/select.py +++ b/pennylane/templates/subroutines/select.py @@ -64,6 +64,13 @@ class Select(Operation): """ + resource_keys = {"op_reps", "num_control_wires"} + + @property + def resource_params(self): + op_reps = tuple(qml.resource_rep(type(op), **op.resource_params) for op in self.ops) + return {"op_reps": op_reps, "num_control_wires": len(self.control)} + def _flatten(self): return (self.ops), (self.control) @@ -195,11 +202,10 @@ def compute_decomposition( Controlled(Y(2), control_wires=[0, 1], control_values=[True, False]), Controlled(SWAP(wires=[2, 3]), control_wires=[0, 1])] """ - states = list(itertools.product([0, 1], repeat=len(control))) - decomp_ops = [ - qml.ctrl(op, control, control_values=states[index]) for index, op in enumerate(ops) + return [ + qml.ctrl(op, control, control_values=state) + for state, op in zip(itertools.product([0, 1], repeat=len(control)), ops) ] - return decomp_ops @property def ops(self): @@ -220,3 +226,33 @@ def target_wires(self): def wires(self): """All wires involved in the operation.""" return self.hyperparameters["control"] + self.hyperparameters["target_wires"] + + +def _target(target_rep, num_control_wires, state): + return qml.resource_rep( + qml.ops.Controlled, + base_class=target_rep.op_type, + base_params=target_rep.params, + num_control_wires=num_control_wires, + num_work_wires=0, + num_zero_control_values=sum((1 - s for s in state)), + ) + + +def _select_resources(op_reps, num_control_wires): + state_iterator = itertools.product([0, 1], repeat=num_control_wires) + + resources = { + _target(rep, num_control_wires, state): 1 for rep, state in zip(op_reps, state_iterator) + } + return resources + + +# pylint: disable=unused-argument +@qml.register_resources(_select_resources) +def _select_decomp(*_, wires, ops, control, target_wires): + state_iterator = itertools.product([0, 1], repeat=len(control)) + return [qml.ctrl(op, control, control_values=state) for state, op in zip(state_iterator, ops)] + + +qml.add_decomps(Select, _select_decomp) diff --git a/tests/ops/op_math/test_prod.py b/tests/ops/op_math/test_prod.py index 81f7ac17fd3..8f391cf576a 100644 --- a/tests/ops/op_math/test_prod.py +++ b/tests/ops/op_math/test_prod.py @@ -1658,3 +1658,44 @@ def test_swappable_ops(self, op1, op2): def test_non_swappable_ops(self, op1, op2): """Test the check for non-swappable operators.""" assert not _swappable_ops(op1, op2) + + +class TestDecomposition: + + def test_resource_keys(self): + """Test that the resource keys of `Prod` are op_reps.""" + assert Prod.resource_keys == frozenset({"resources"}) + product = qml.X(0) @ qml.Y(1) @ qml.X(2) + resources = {qml.resource_rep(qml.X): 2, qml.resource_rep(qml.Y): 1} + assert product.resource_params == {"resources": resources} + + def test_registered_decomp(self): + """Test that the decomposition of prod is registered.""" + + decomps = qml.decomposition.list_decomps(Prod) + + default_decomp = decomps[0] + _ops = [qml.X(0), qml.X(1), qml.X(2), qml.MultiRZ(0.5, wires=(0, 1))] + resources = {qml.resource_rep(qml.X): 3, qml.resource_rep(qml.MultiRZ, num_wires=2): 1} + + resource_obj = default_decomp.compute_resources(resources=resources) + + assert resource_obj.num_gates == 4 + assert resource_obj.gate_counts == resources + + with qml.queuing.AnnotatedQueue() as q: + default_decomp(operands=_ops) + + assert q.queue == _ops[::-1] + + def test_integration(self, enable_graph_decomposition): + """Test that prod's can be integrated into the decomposition.""" + + op = qml.S(0) @ qml.S(1) @ qml.T(0) @ qml.Y(1) + + graph = qml.decomposition.DecompositionGraph([op], gate_set=qml.ops.__all__) + graph.solve() + with qml.queuing.AnnotatedQueue() as q: + graph.decomposition(op)(**op.hyperparameters) + + assert q.queue == list(op[::-1]) diff --git a/tests/ops/qubit/test_state_prep.py b/tests/ops/qubit/test_state_prep.py index 39bff4abc02..1582e1eeca4 100644 --- a/tests/ops/qubit/test_state_prep.py +++ b/tests/ops/qubit/test_state_prep.py @@ -134,6 +134,33 @@ def test_StatePrep_decomposition(self): assert isinstance(ops1[0], qml.MottonenStatePreparation) assert isinstance(ops2[0], qml.MottonenStatePreparation) + def test_stateprep_resources(self): + """Test the resources for StatePrep""" + + assert qml.StatePrep.resource_keys == frozenset({"num_wires"}) + + op = qml.StatePrep([0, 0, 0, 1], wires=(0, 1)) + assert op.resource_params == {"num_wires": 2} + + def test_decomposition_rule_stateprep(self): + """Test that stateprep has a correct decomposition rule registered.""" + + decomp = qml.list_decomps(qml.StatePrep)[0] + + resource_obj = decomp.compute_resources(num_wires=2) + assert resource_obj.num_gates == 1 + assert resource_obj.gate_counts == { + qml.resource_rep(qml.MottonenStatePreparation, num_wires=2): 1 + } + + with qml.queuing.AnnotatedQueue() as q: + decomp(np.array([0, 0, 0, 1]), wires=(0, 1)) + + qml.assert_equal(q.queue[0], qml.MottonenStatePreparation(np.array([0, 0, 0, 1]), (0, 1))) + + +class TestStatePrepIntegration: + @pytest.mark.parametrize( "state, pad_with, expected", [