Skip to content

Commit a64c900

Browse files
[Feature] Add custom vjp with adjoint differentiation (#10)
Co-authored-by: Gert-Jan Both <[email protected]>
1 parent d54c02d commit a64c900

File tree

8 files changed

+155
-37
lines changed

8 files changed

+155
-37
lines changed

docs/index.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))
7070

7171
We can now build a fully differentiable variational circuit by simply defining a sequence of gates
7272
and a set of initial parameter values we want to optimize.
73-
Lets fit a function using a simple circuit class wrapper.
73+
Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
74+
which we can use to fit a function using a simple circuit class wrapper.
7475

7576
```python exec="on" source="material-block" html="1"
7677
from __future__ import annotations
@@ -85,6 +86,7 @@ from operator import add
8586
from typing import Any, Callable
8687
from uuid import uuid4
8788

89+
from horqrux.adjoint import adjoint_expectation
8890
from horqrux.abstract import Operator
8991
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap
9092

@@ -127,8 +129,7 @@ class Circuit:
127129
def forward(self, param_values: Array, x: Array) -> Array:
128130
state = zero_state(self.n_qubits)
129131
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
130-
state = apply_gate(state, self.feature_map + self.ansatz, {**param_dict, **{'phi': x}})
131-
return overlap(state, apply_gate(state, self.observable))
132+
return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
132133

133134
def __call__(self, param_values: Array, x: Array) -> Array:
134135
return self.forward(param_values, x)

horqrux/abstract.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __iter__(self) -> Iterable:
5353

5454
def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]:
5555
children = ()
56-
aux_data = (self.generator_name, self.target, self.control)
56+
aux_data = (self.generator_name, self.target[0], self.control[0])
5757
return (children, aux_data)
5858

5959
@classmethod
@@ -101,13 +101,20 @@ def parse_val(values: dict[str, float] = dict()) -> float:
101101
def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override]
102102
children = ()
103103
aux_data = (
104-
self.name,
105-
self.target,
106-
self.control,
104+
self.generator_name,
105+
self.target[0],
106+
self.control[0],
107107
self.param,
108108
)
109109
return (children, aux_data)
110110

111+
def __iter__(self) -> Iterable:
112+
return iter((self.generator_name, self.target, self.control, self.param))
113+
114+
@classmethod
115+
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
116+
return cls(*children, *aux_data)
117+
111118
def unitary(self, values: dict[str, float] = dict()) -> Array:
112119
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
113120

horqrux/adjoint.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
from typing import Tuple
4+
5+
from jax import Array, custom_vjp
6+
7+
from horqrux.abstract import Operator, Parametric
8+
from horqrux.apply import apply_gate
9+
from horqrux.utils import OperationType, overlap
10+
11+
12+
def expectation(
13+
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
14+
) -> Array:
15+
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
16+
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
17+
return overlap(out_state, projected_state)
18+
19+
20+
@custom_vjp
21+
def adjoint_expectation(
22+
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
23+
) -> Array:
24+
return expectation(state, gates, observable, values)
25+
26+
27+
def adjoint_expectation_fwd(
28+
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
29+
) -> Tuple[Array, Tuple[Array, Array, list[Operator], dict[str, float]]]:
30+
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
31+
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
32+
return overlap(out_state, projected_state), (out_state, projected_state, gates, values)
33+
34+
35+
def adjoint_expectation_bwd(
36+
res: Tuple[Array, Array, list[Operator], dict[str, float]], tangent: Array
37+
) -> tuple:
38+
out_state, projected_state, gates, values = res
39+
grads = {}
40+
for gate in gates[::-1]:
41+
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
42+
if isinstance(gate, Parametric):
43+
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
44+
grads[gate.param] = tangent * 2 * overlap(mu, projected_state)
45+
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
46+
return (None, None, None, grads)
47+
48+
49+
adjoint_expectation.defvjp(adjoint_expectation_fwd, adjoint_expectation_bwd)

horqrux/apply.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -10,67 +10,73 @@
1010

1111
from horqrux.abstract import Operator
1212

13-
from .utils import State, _controlled, is_controlled
13+
from .utils import OperationType, State, _controlled, is_controlled
1414

1515

1616
def apply_operator(
1717
state: State,
18-
unitary: Array,
18+
operator: Array,
1919
target: Tuple[int, ...],
2020
control: Tuple[int | None, ...],
2121
) -> State:
22-
"""Applies a unitary, i.e. a single array of shape [2, 2, ...], on a given state
22+
"""Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state
2323
of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits.
24-
In case of control qubits, the 'unitary' array will be embedded into a controlled array.
24+
In case of a controlled operation, the 'operator' array will be embedded into a controlled array.
2525
26-
Since dimension 'i' in 'state' corresponds to all amplitudes which are affected by qubit 'i',
27-
target and control qubits correspond to dimensions to contract 'unitary' over.
28-
Contraction over qubit 'i' means applying the 'dot' operation between 'unitary' and dimension 'i'
26+
Since dimension 'i' in 'state' corresponds to all amplitudes where qubit 'i' is 1,
27+
target and control qubits represent the dimensions over which to contract the 'operator'.
28+
Contraction means applying the 'dot' operation between the operator array and dimension 'i'
2929
of 'state', resulting in a new state where the result of the 'dot' operation has been moved to
3030
dimension 'i' of 'state'. To restore the former order of dimensions, the affected dimensions
3131
are moved to their original positions and the state is returned.
3232
3333
Arguments:
3434
state: State to operate on.
35-
unitary: Array to contract over 'state'.
35+
operator: Array to contract over 'state'.
3636
target: Tuple of target qubits on which to apply the 'operator'.
3737
control: Tuple of control qubits.
3838
3939
Returns:
40-
State after applying 'unitary'.
40+
State after applying 'operator'.
4141
"""
4242
state_dims: Tuple[int, ...] = target
4343
if is_controlled(control):
44-
unitary = _controlled(unitary, len(control))
44+
operator = _controlled(operator, len(control))
4545
state_dims = (*control, *target) # type: ignore[arg-type]
46-
n_qubits = int(np.log2(unitary.size))
47-
unitary = unitary.reshape(tuple(2 for _ in np.arange(n_qubits)))
48-
op_dims = tuple(np.arange(unitary.ndim // 2, unitary.ndim, dtype=int))
49-
state = jnp.tensordot(a=unitary, b=state, axes=(op_dims, state_dims))
46+
n_qubits = int(np.log2(operator.size))
47+
operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits)))
48+
op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
49+
state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims))
5050
new_state_dims = tuple(i for i in range(len(state_dims)))
5151
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)
5252

5353

5454
def apply_gate(
55-
state: State, gate: Operator | Iterable[Operator], values: dict[str, float] = dict()
55+
state: State,
56+
gate: Operator | Iterable[Operator],
57+
values: dict[str, float] = dict(),
58+
op_type: OperationType = OperationType.UNITARY,
5659
) -> State:
5760
"""Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state.
5861
Arguments:
5962
state: State to operate on.
6063
gate: Gate(s) to apply.
64+
values: A dictionary with parameter values.
65+
op_type: The type of operation to perform: Unitary, Dagger or Jacobian.
6166
6267
Returns:
6368
State after applying 'gate'.
6469
"""
65-
unitary: Tuple[Array, ...]
70+
operator: Tuple[Array, ...]
6671
if isinstance(gate, Operator):
67-
unitary, target, control = (gate.unitary(values),), gate.target, gate.control
72+
operator_fn = getattr(gate, op_type)
73+
operator, target, control = (operator_fn(values),), gate.target, gate.control
6874
else:
69-
unitary = tuple(g.unitary(values) for g in gate)
75+
operator = tuple(getattr(g, op_type)(values) for g in gate)
7076
target = reduce(add, [g.target for g in gate])
7177
control = reduce(add, [g.control for g in gate])
7278
return reduce(
7379
lambda state, gate: apply_operator(state, *gate),
74-
zip(unitary, target, control),
80+
zip(operator, target, control),
7581
state,
7682
)

horqrux/primitive.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
1010
"""Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
1111
By providing tuple of ints to 'control', it turns into a controlled gate.
1212
13-
Example usage: I(1) applies I to qubit 1.
13+
Example usage: I(1) represents the instruction to apply I to qubit 1.
1414
1515
Args:
1616
target: Tuple of ints describing the qubits to apply to.
@@ -26,8 +26,8 @@ def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
2626
"""X gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
2727
By providing tuple of ints to 'control', it turns into a controlled gate.
2828
29-
Example usage: X(1) applies X to qubit 1.
30-
Example usage controlled: X(1, 0) applies CX / CNOT to qubit 1 with controlled qubit 0.
29+
Example usage: X(1) represents the instruction to apply X to qubit 1.
30+
Example usage controlled: X(1, 0) represents the instruction to apply CX / CNOT to qubit 1 with control qubit 0.
3131
3232
Args:
3333
target: Tuple of ints describing the qubits to apply to.
@@ -46,8 +46,8 @@ def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
4646
"""Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
4747
By providing tuple of ints to 'control', it turns into a controlled gate.
4848
49-
Example usage: Y(1) applies X to qubit 1.
50-
Example usage controlled: Y(1, 0) applies CY to qubit 1 with controlled qubit 0.
49+
Example usage: Y(1) represents the instruction to apply Y to qubit 1.
50+
Example usage controlled: Y(1, 0) represents the instruction to apply CY to qubit 1 with control qubit 0.
5151
5252
Args:
5353
target: Tuple of ints describing the qubits to apply to.
@@ -63,8 +63,8 @@ def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
6363
"""Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
6464
By providing tuple of ints to 'control', it turns into a controlled gate.
6565
66-
Example usage: Z(1) applies Z to qubit 1.
67-
Example usage controlled: Z(1, 0) applies CZ to qubit 1 with controlled qubit 0.
66+
Example usage: Z(1) represents the instruction to apply Z to qubit 1.
67+
Example usage controlled: Z(1, 0) represents the instruction to apply CZ to qubit 1 with control qubit 0.
6868
6969
Args:
7070
target: Tuple of ints describing the qubits to apply to.
@@ -80,7 +80,7 @@ def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
8080
"""H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
8181
By providing tuple of ints to 'control', it turns into a controlled gate.
8282
83-
Example usage: H(1) applies Hadamard to qubit 1.
83+
Example usage: H(1) represents the instruction to apply Hadamard to qubit 1.
8484
8585
Args:
8686
target: Tuple of ints describing the qubits to apply to.
@@ -96,7 +96,7 @@ def S(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
9696
"""S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
9797
By providing tuple of ints to 'control', it turns into a controlled gate.
9898
99-
Example usage: S(1) applies S to qubit 1.
99+
Example usage: S(1) represents the instruction to apply S to qubit 1.
100100
101101
Args:
102102
target: Tuple of ints describing the qubits to apply to.
@@ -112,7 +112,7 @@ def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
112112
"""T gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
113113
By providing tuple of ints to 'control', it turns into a controlled gate.
114114
115-
Example usage: T(1) applies Hadamard to qubit 1.
115+
Example usage: T(1) represents the instruction to apply T to qubit 1.
116116
117117
Args:
118118
target: Tuple of ints describing the qubits to apply to.

horqrux/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from enum import Enum
34
from typing import Any, Iterable, Tuple, Union
45

56
import jax
@@ -16,6 +17,23 @@
1617
ATOL = 1e-014
1718

1819

20+
class StrEnum(str, Enum):
21+
def __str__(self) -> str:
22+
"""Used when dumping enum fields in a schema."""
23+
ret: str = self.value
24+
return ret
25+
26+
@classmethod
27+
def list(cls) -> list[str]:
28+
return list(map(lambda c: c.value, cls)) # type: ignore
29+
30+
31+
class OperationType(StrEnum):
32+
UNITARY = "unitary"
33+
DAGGER = "dagger"
34+
JACOBIAN = "jacobian"
35+
36+
1937
def _dagger(operator: Array) -> Array:
2038
return jnp.conjugate(operator.T)
2139

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ authors = [
1212
requires-python = ">=3.9,<3.12"
1313
license = {text = "Apache 2.0"}
1414

15-
version = "0.4.0"
15+
version = "0.5.0"
1616

1717
classifiers=[
1818
"License :: Other/Proprietary License",

tests/test_adjoint.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
from jax import Array, grad
6+
7+
from horqrux import random_state
8+
from horqrux.adjoint import adjoint_expectation, expectation
9+
from horqrux.parametric import PHASE, RX, RY, RZ
10+
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z
11+
12+
MAX_QUBITS = 7
13+
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
14+
PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T)
15+
16+
17+
def test_gradcheck() -> None:
18+
ops = [RX("theta", 0), RY("epsilon", 0), RX("phi", 0), NOT(1, 0), RX("omega", 0, 1)]
19+
observable = [Z(0)]
20+
values = {
21+
"theta": np.random.uniform(0, 1),
22+
"epsilon": np.random.uniform(0, 1),
23+
"phi": np.random.uniform(0, 1),
24+
"omega": np.random.uniform(0, 1),
25+
}
26+
state = random_state(MAX_QUBITS)
27+
28+
def adjoint_expfn(values) -> Array:
29+
return adjoint_expectation(state, ops, observable, values)
30+
31+
def ad_expfn(values) -> Array:
32+
return expectation(state, ops, observable, values)
33+
34+
grads_adjoint = grad(adjoint_expfn)(values)
35+
grad_ad = grad(ad_expfn)(values)
36+
for param, ad_grad in grad_ad.items():
37+
assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09)

0 commit comments

Comments
 (0)