|
3 | 3 | from typing import Tuple
|
4 | 4 |
|
5 | 5 | from jax import Array, custom_vjp
|
| 6 | +from jax.numpy import real as jnpreal |
6 | 7 |
|
7 |
| -from horqrux.abstract import Operator, Parametric |
| 8 | +from horqrux.abstract import Parametric, Primitive |
8 | 9 | from horqrux.apply import apply_gate
|
9 |
| -from horqrux.utils import OperationType, overlap |
| 10 | +from horqrux.utils import OperationType, inner |
10 | 11 |
|
11 | 12 |
|
12 | 13 | def expectation(
|
13 |
| - state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float] |
| 14 | + state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] |
14 | 15 | ) -> Array:
|
15 | 16 | out_state = apply_gate(state, gates, values, OperationType.UNITARY)
|
16 | 17 | projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
|
17 |
| - return overlap(out_state, projected_state) |
| 18 | + return jnpreal(inner(out_state, projected_state)) |
18 | 19 |
|
19 | 20 |
|
20 | 21 | @custom_vjp
|
21 | 22 | def adjoint_expectation(
|
22 |
| - state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float] |
| 23 | + state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] |
23 | 24 | ) -> Array:
|
24 | 25 | return expectation(state, gates, observable, values)
|
25 | 26 |
|
26 | 27 |
|
27 | 28 | def adjoint_expectation_fwd(
|
28 |
| - state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float] |
29 |
| -) -> Tuple[Array, Tuple[Array, Array, list[Operator], dict[str, float]]]: |
| 29 | + state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] |
| 30 | +) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]: |
30 | 31 | out_state = apply_gate(state, gates, values, OperationType.UNITARY)
|
31 | 32 | projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
|
32 |
| - return overlap(out_state, projected_state), (out_state, projected_state, gates, values) |
| 33 | + return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values) |
33 | 34 |
|
34 | 35 |
|
35 | 36 | def adjoint_expectation_bwd(
|
36 |
| - res: Tuple[Array, Array, list[Operator], dict[str, float]], tangent: Array |
| 37 | + res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array |
37 | 38 | ) -> tuple:
|
38 | 39 | out_state, projected_state, gates, values = res
|
39 | 40 | grads = {}
|
40 | 41 | for gate in gates[::-1]:
|
41 | 42 | out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
|
42 | 43 | if isinstance(gate, Parametric):
|
43 | 44 | mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
|
44 |
| - grads[gate.param] = tangent * 2 * overlap(mu, projected_state) |
| 45 | + grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state)) |
45 | 46 | projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
|
46 | 47 | return (None, None, None, grads)
|
47 | 48 |
|
|
0 commit comments