Skip to content

[WIP] add PrepSelPrep to new decomposition framework #7385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pennylane/decomposition/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _validate_resource_rep(op_type, params):
if not issubclass(op_type, qml.operation.Operator):
raise TypeError(f"op_type must be a type of Operator, got {op_type}")

if not isinstance(op_type.resource_keys, set):
if not isinstance(op_type.resource_keys, (set, frozenset)):
raise TypeError(
f"{op_type.__name__}.resource_keys must be a set, not a {type(op_type.resource_keys)}"
)
Expand Down
1 change: 1 addition & 0 deletions pennylane/ops/op_math/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self._pauli_rep = self._build_pauli_rep() if _pauli_rep is None else _pauli_rep
self.queue()
self._batch_size = _UNSET_BATCH_SIZE
self._hyperparameters = {"operands": operands}

@handle_recursion_error
def _check_batching(self):
Expand Down
22 changes: 22 additions & 0 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
computing the product between operations.
"""
import itertools
from collections import Counter
from copy import copy
from functools import reduce, wraps
from itertools import combinations
Expand Down Expand Up @@ -230,6 +231,13 @@ def circuit(weights):

"""

resource_keys = frozenset({"resources"})

@property
def resource_params(self):
resources = Counter(qml.resource_rep(type(op), **op.resource_params) for op in self)
return {"resources": resources}

_op_symbol = "@"
_math_op = math.prod
grad_method = None
Expand Down Expand Up @@ -467,6 +475,20 @@ def terms(self):
return coeffs, ops


def _prod_resources(resources):
return resources


# pylint: disable=unused-argument
@qml.register_resources(_prod_resources)
def _prod_decomp(*_, wires=None, operands):
for op in operands:
op._unflatten(*op._flatten()) # pylint: disable=protected-access


qml.add_decomps(Prod, _prod_decomp)


def _swappable_ops(op1, op2, wire_map: dict = None) -> bool:
"""Boolean expression that indicates if op1 and op2 don't have intersecting wires and if they
should be swapped when sorting them by wire values.
Expand Down
18 changes: 18 additions & 0 deletions pennylane/ops/qubit/state_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,12 @@ def circuit(state=None):

"""

resource_keys = frozenset({"num_wires"})

@property
def resource_params(self):
return {"num_wires": len(self.wires)}

num_params = 1
"""int: Number of trainable parameters that the operator depends on."""

Expand Down Expand Up @@ -574,6 +580,18 @@ def _preprocess_csr(state, wires, pad_with, normalize, validate_norm):
return state


def _stateprep_resources(num_wires):
return {qml.resource_rep(qml.MottonenStatePreparation, num_wires=num_wires): 1}


@register_resources(_stateprep_resources)
def _state_prep_decomp(state, wires, **_):
qml.MottonenStatePreparation(state, wires)


add_decomps(StatePrep, _state_prep_decomp)


class QubitDensityMatrix(Operation):
r"""QubitDensityMatrix(state, wires)
Prepare subsystems using the given density matrix.
Expand Down
7 changes: 2 additions & 5 deletions pennylane/templates/embeddings/amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
r"""
Contains the AmplitudeEmbedding template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
from pennylane.ops import StatePrep

# tolerance for normalization
TOLERANCE = 1e-10


# pylint: disable=too-many-arguments
class AmplitudeEmbedding(StatePrep):
r"""Encodes :math:`2^n` features into the amplitude vector of :math:`n` qubits.

Expand Down Expand Up @@ -108,7 +105,7 @@ def circuit(f=None):
"""

def __init__(
self, features, wires, pad_with=None, normalize=False, id=None, validate_norm=True
self, features, wires, *, pad_with=None, normalize=False, id=None, validate_norm=True
):
super().__init__(
features,
Expand Down
19 changes: 19 additions & 0 deletions pennylane/templates/state_preparations/mottonen.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ def circuit(state):

"""

resource_keys = frozenset({"num_wires"})

@property
def resource_params(self):
return {"num_wires": len(self.wires)}

grad_method = None
ndim_params = (1,)

Expand Down Expand Up @@ -390,3 +396,16 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif
op_list.extend([qml.GlobalPhase(global_phase, wires=wires)])

return op_list


def _mottonen_resources(num_wires):
n = sum(2**i for i in range(num_wires))

return {qml.GlobalPhase: 1, qml.RY: n, qml.RZ: n, qml.CNOT: 2 * (n - 1)}


mottonen_decomp = qml.register_resources(
_mottonen_resources, MottonenStatePreparation.compute_decomposition
)

qml.add_decomps(MottonenStatePreparation, mottonen_decomp)
31 changes: 31 additions & 0 deletions pennylane/templates/subroutines/prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ class PrepSelPrep(Operation):
[ 0.75 0.25]]
"""

resource_keys = frozenset({"num_control", "op_reps"})

@property
def resource_params(self):
ops = self.lcu.terms()[1]
op_reps = tuple(qml.resource_rep(type(op), **op.resource_params) for op in ops)
return {"op_reps": op_reps, "num_control": len(self.control)}

grad_method = None

def __init__(self, lcu, control=None, id=None):
Expand Down Expand Up @@ -203,3 +211,26 @@ def target_wires(self):
def wires(self):
"""All wires involved in the operation."""
return self.hyperparameters["control"] + self.hyperparameters["target_wires"]


def _prepselprep_resources(op_reps, num_control):
return {
qml.resource_rep(qml.Select, op_reps=op_reps, num_control_wires=num_control): 1,
qml.resource_rep(qml.StatePrep, num_wires=num_control): 1,
qml.resource_rep(
qml.ops.Adjoint, base_class=qml.StatePrep, base_params={"num_wires": num_control}
): 1,
}


# pylint: disable=unused-argument
@qml.register_resources(_prepselprep_resources)
def _prepselprep_decomp(*_, wires, lcu, coeffs, ops, control, target_wires):
coeffs, ops = _get_new_terms(lcu)
sqrt_coeffs = qml.math.sqrt(coeffs)
qml.StatePrep(sqrt_coeffs, normalize=True, pad_with=0, wires=control)
qml.Select(ops, control)
qml.adjoint(qml.StatePrep(sqrt_coeffs, normalize=True, pad_with=0, wires=control))


qml.add_decomps(PrepSelPrep, _prepselprep_decomp)
44 changes: 40 additions & 4 deletions pennylane/templates/subroutines/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ class Select(Operation):

"""

resource_keys = {"op_reps", "num_control_wires"}

@property
def resource_params(self):
op_reps = tuple(qml.resource_rep(type(op), **op.resource_params) for op in self.ops)
return {"op_reps": op_reps, "num_control_wires": len(self.control)}

def _flatten(self):
return (self.ops), (self.control)

Expand Down Expand Up @@ -195,11 +202,10 @@ def compute_decomposition(
Controlled(Y(2), control_wires=[0, 1], control_values=[True, False]),
Controlled(SWAP(wires=[2, 3]), control_wires=[0, 1])]
"""
states = list(itertools.product([0, 1], repeat=len(control)))
decomp_ops = [
qml.ctrl(op, control, control_values=states[index]) for index, op in enumerate(ops)
return [
qml.ctrl(op, control, control_values=state)
for state, op in zip(itertools.product([0, 1], repeat=len(control)), ops)
]
return decomp_ops

@property
def ops(self):
Expand All @@ -220,3 +226,33 @@ def target_wires(self):
def wires(self):
"""All wires involved in the operation."""
return self.hyperparameters["control"] + self.hyperparameters["target_wires"]


def _target(target_rep, num_control_wires, state):
return qml.resource_rep(
qml.ops.Controlled,
base_class=target_rep.op_type,
base_params=target_rep.params,
num_control_wires=num_control_wires,
num_work_wires=0,
num_zero_control_values=sum((1 - s for s in state)),
)


def _select_resources(op_reps, num_control_wires):
state_iterator = itertools.product([0, 1], repeat=num_control_wires)

resources = {
_target(rep, num_control_wires, state): 1 for rep, state in zip(op_reps, state_iterator)
}
return resources


# pylint: disable=unused-argument
@qml.register_resources(_select_resources)
def _select_decomp(*_, wires, ops, control, target_wires):
state_iterator = itertools.product([0, 1], repeat=len(control))
return [qml.ctrl(op, control, control_values=state) for state, op in zip(state_iterator, ops)]


qml.add_decomps(Select, _select_decomp)
41 changes: 41 additions & 0 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,3 +1658,44 @@ def test_swappable_ops(self, op1, op2):
def test_non_swappable_ops(self, op1, op2):
"""Test the check for non-swappable operators."""
assert not _swappable_ops(op1, op2)


class TestDecomposition:

def test_resource_keys(self):
"""Test that the resource keys of `Prod` are op_reps."""
assert Prod.resource_keys == frozenset({"resources"})
product = qml.X(0) @ qml.Y(1) @ qml.X(2)
resources = {qml.resource_rep(qml.X): 2, qml.resource_rep(qml.Y): 1}
assert product.resource_params == {"resources": resources}

def test_registered_decomp(self):
"""Test that the decomposition of prod is registered."""

decomps = qml.decomposition.list_decomps(Prod)

default_decomp = decomps[0]
_ops = [qml.X(0), qml.X(1), qml.X(2), qml.MultiRZ(0.5, wires=(0, 1))]
resources = {qml.resource_rep(qml.X): 3, qml.resource_rep(qml.MultiRZ, num_wires=2): 1}

resource_obj = default_decomp.compute_resources(resources=resources)

assert resource_obj.num_gates == 4
assert resource_obj.gate_counts == resources

with qml.queuing.AnnotatedQueue() as q:
default_decomp(operands=_ops)

assert q.queue == _ops[::-1]

def test_integration(self, enable_graph_decomposition):
"""Test that prod's can be integrated into the decomposition."""

op = qml.S(0) @ qml.S(1) @ qml.T(0) @ qml.Y(1)

graph = qml.decomposition.DecompositionGraph([op], gate_set=qml.ops.__all__)
graph.solve()
with qml.queuing.AnnotatedQueue() as q:
graph.decomposition(op)(**op.hyperparameters)

assert q.queue == list(op[::-1])
27 changes: 27 additions & 0 deletions tests/ops/qubit/test_state_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,33 @@ def test_StatePrep_decomposition(self):
assert isinstance(ops1[0], qml.MottonenStatePreparation)
assert isinstance(ops2[0], qml.MottonenStatePreparation)

def test_stateprep_resources(self):
"""Test the resources for StatePrep"""

assert qml.StatePrep.resource_keys == frozenset({"num_wires"})

op = qml.StatePrep([0, 0, 0, 1], wires=(0, 1))
assert op.resource_params == {"num_wires": 2}

def test_decomposition_rule_stateprep(self):
"""Test that stateprep has a correct decomposition rule registered."""

decomp = qml.list_decomps(qml.StatePrep)[0]

resource_obj = decomp.compute_resources(num_wires=2)
assert resource_obj.num_gates == 1
assert resource_obj.gate_counts == {
qml.resource_rep(qml.MottonenStatePreparation, num_wires=2): 1
}

with qml.queuing.AnnotatedQueue() as q:
decomp(np.array([0, 0, 0, 1]), wires=(0, 1))

qml.assert_equal(q.queue[0], qml.MottonenStatePreparation(np.array([0, 0, 0, 1]), (0, 1)))


class TestStatePrepIntegration:

@pytest.mark.parametrize(
"state, pad_with, expected",
[
Expand Down