Skip to content

Commit 491cdb3

Browse files
[Feature] Add HamiltonianEvolution, Support py3.12 (#11)
* [Feature] Add HamiltonianEvolution, Support py3.12
1 parent a64c900 commit 491cdb3

File tree

9 files changed

+112
-28
lines changed

9 files changed

+112
-28
lines changed

.github/workflows/run-tests-and-mypy.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
runs-on: ubuntu-latest
2323
strategy:
2424
matrix:
25-
python-version: ["3.9", "3.10", "3.11"]
25+
python-version: ["3.9", "3.10", "3.11", "3.12"]
2626
steps:
2727
- name: Checkout main code and submodules
2828
uses: actions/checkout@v4

docs/index.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ from typing import Any, Callable
8787
from uuid import uuid4
8888

8989
from horqrux.adjoint import adjoint_expectation
90-
from horqrux.abstract import Operator
91-
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap
90+
from horqrux.abstract import Primitive
91+
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
9292

9393

9494
n_qubits = 5
@@ -121,9 +121,9 @@ class Circuit:
121121

122122
def __post_init__(self) -> None:
123123
# We will use a featuremap of RX rotations to encode some classical data
124-
self.feature_map: list[Operator] = [RX('phi', i) for i in range(n_qubits)]
124+
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(n_qubits)]
125125
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
126-
self.observable: list[Operator] = [Z(0)]
126+
self.observable: list[Primitive] = [Z(0)]
127127

128128
@partial(vmap, in_axes=(None, None, 0))
129129
def forward(self, param_values: Array, x: Array) -> Array:

horqrux/abstract.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
@register_pytree_node_class
2424
@dataclass
25-
class Operator:
26-
"""Abstract class which stores information about generators target and control qubits
25+
class Primitive:
26+
"""Primitive gate class which stores information about generators target and control qubits
2727
of a particular quantum operator."""
2828

2929
generator_name: str
@@ -42,11 +42,11 @@ def parse_idx(
4242
return (idx.astype(int),)
4343

4444
def __post_init__(self) -> None:
45-
self.target = Operator.parse_idx(self.target)
45+
self.target = Primitive.parse_idx(self.target)
4646
if self.control is None:
4747
self.control = none_like(self.target)
4848
else:
49-
self.control = Operator.parse_idx(self.control)
49+
self.control = Primitive.parse_idx(self.control)
5050

5151
def __iter__(self) -> Iterable:
5252
return iter((self.generator_name, self.target, self.control))
@@ -74,9 +74,6 @@ def __repr__(self) -> str:
7474
return self.name + f"(target={self.target[0]}, control={self.control[0]})"
7575

7676

77-
Primitive = Operator
78-
79-
8077
@register_pytree_node_class
8178
@dataclass
8279
class Parametric(Primitive):

horqrux/adjoint.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,46 @@
33
from typing import Tuple
44

55
from jax import Array, custom_vjp
6+
from jax.numpy import real as jnpreal
67

7-
from horqrux.abstract import Operator, Parametric
8+
from horqrux.abstract import Parametric, Primitive
89
from horqrux.apply import apply_gate
9-
from horqrux.utils import OperationType, overlap
10+
from horqrux.utils import OperationType, inner
1011

1112

1213
def expectation(
13-
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
14+
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
1415
) -> Array:
1516
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
1617
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
17-
return overlap(out_state, projected_state)
18+
return jnpreal(inner(out_state, projected_state))
1819

1920

2021
@custom_vjp
2122
def adjoint_expectation(
22-
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
23+
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
2324
) -> Array:
2425
return expectation(state, gates, observable, values)
2526

2627

2728
def adjoint_expectation_fwd(
28-
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
29-
) -> Tuple[Array, Tuple[Array, Array, list[Operator], dict[str, float]]]:
29+
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
30+
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
3031
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
3132
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
32-
return overlap(out_state, projected_state), (out_state, projected_state, gates, values)
33+
return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values)
3334

3435

3536
def adjoint_expectation_bwd(
36-
res: Tuple[Array, Array, list[Operator], dict[str, float]], tangent: Array
37+
res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
3738
) -> tuple:
3839
out_state, projected_state, gates, values = res
3940
grads = {}
4041
for gate in gates[::-1]:
4142
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
4243
if isinstance(gate, Parametric):
4344
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
44-
grads[gate.param] = tangent * 2 * overlap(mu, projected_state)
45+
grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state))
4546
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
4647
return (None, None, None, grads)
4748

horqrux/analog.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from jax import Array
6+
from jax.scipy.linalg import expm
7+
from jax.tree_util import register_pytree_node_class
8+
9+
from .abstract import Primitive, QubitSupport
10+
11+
12+
@register_pytree_node_class
13+
@dataclass
14+
class _HamiltonianEvolution(Primitive):
15+
"""
16+
A slim wrapper class which evolves a 'hamiltonian'
17+
given a 'time_evolution' parameter and applies it to 'state' psi by doing: matrixexp(-iHt)|psi>
18+
"""
19+
20+
generator_name: str
21+
target: QubitSupport
22+
control: QubitSupport
23+
24+
def unitary(self, values: dict[str, Array] = dict()) -> Array:
25+
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))
26+
27+
28+
def HamiltonianEvolution(
29+
target: QubitSupport, control: QubitSupport = (None,)
30+
) -> _HamiltonianEvolution:
31+
return _HamiltonianEvolution("I", target, control)

horqrux/apply.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from jax import Array
1010

11-
from horqrux.abstract import Operator
11+
from horqrux.abstract import Primitive
1212

1313
from .utils import OperationType, State, _controlled, is_controlled
1414

@@ -53,7 +53,7 @@ def apply_operator(
5353

5454
def apply_gate(
5555
state: State,
56-
gate: Operator | Iterable[Operator],
56+
gate: Primitive | Iterable[Primitive],
5757
values: dict[str, float] = dict(),
5858
op_type: OperationType = OperationType.UNITARY,
5959
) -> State:
@@ -68,7 +68,7 @@ def apply_gate(
6868
State after applying 'gate'.
6969
"""
7070
operator: Tuple[Array, ...]
71-
if isinstance(gate, Operator):
71+
if isinstance(gate, Primitive):
7272
operator_fn = getattr(gate, op_type)
7373
operator, target, control = (operator_fn(values),), gate.target, gate.control
7474
else:

horqrux/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,12 @@ def equivalent_state(s0: Array, s1: Array) -> bool:
118118
return jnp.allclose(overlap(s0, s1), 1.0, atol=ATOL) # type: ignore[no-any-return]
119119

120120

121+
def inner(state: Array, projection: Array) -> Array:
122+
return jnp.dot(jnp.conj(state.flatten()), projection.flatten())
123+
124+
121125
def overlap(state: Array, projection: Array) -> Array:
122-
return jnp.real(jnp.dot(jnp.conj(state.flatten()), projection.flatten()))
126+
return jnp.real(jnp.power(inner(state, projection), 2))
123127

124128

125129
def uniform_state(
@@ -150,3 +154,7 @@ def _normalize(wf: Array) -> Array:
150154
return _normalize(
151155
(jnp.sqrt(x / sumx) * jnp.exp(1j * phases)).reshape(tuple(2 for _ in range(n_qubits)))
152156
)
157+
158+
159+
def is_normalized(state: Array) -> bool:
160+
return equivalent_state(state, state)

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ authors = [
99
{ name = "Gert-Jan Both" , email = "[email protected]" },
1010
{ name = "Dominik Seitz", email = "[email protected]" },
1111
]
12-
requires-python = ">=3.9,<3.12"
12+
requires-python = ">=3.8,<3.13"
1313
license = {text = "Apache 2.0"}
1414

15-
version = "0.5.0"
15+
version = "0.6.0"
1616

1717
classifiers=[
1818
"License :: Other/Proprietary License",
@@ -21,6 +21,7 @@ classifiers=[
2121
"Programming Language :: Python :: 3.9",
2222
"Programming Language :: Python :: 3.10",
2323
"Programming Language :: Python :: 3.11",
24+
"Programming Language :: Python :: 3.12",
2425
"Programming Language :: Python :: Implementation :: CPython",
2526
"Programming Language :: Python :: Implementation :: PyPy",
2627
]

tests/test_analog.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
import pytest
6+
from jax import jit, vmap
7+
8+
from horqrux.analog import HamiltonianEvolution
9+
from horqrux.apply import apply_gate
10+
from horqrux.utils import is_normalized, overlap, random_state, uniform_state
11+
12+
sigmaz = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.cdouble))
13+
Hbase = jnp.kron(sigmaz, sigmaz)
14+
15+
Hamiltonian = jnp.kron(Hbase, Hbase)
16+
17+
18+
def test_hamevo_single() -> None:
19+
n_qubits = 4
20+
t_evo = jnp.pi / 4
21+
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)]))
22+
psi = uniform_state(n_qubits)
23+
psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo})
24+
result = overlap(psi_star, psi)
25+
assert jnp.isclose(result, 0.5)
26+
27+
28+
def Hamiltonian_general(n_qubits: int = 2, batch_size: int = 1) -> jnp.array:
29+
H_batch = jnp.zeros((batch_size, 2**n_qubits, 2**n_qubits), dtype=jnp.cdouble)
30+
for i in range(batch_size):
31+
H_0 = np.random.uniform(0.0, 1.0, (2**n_qubits, 2**n_qubits)).astype(np.cdouble)
32+
H = H_0 + jnp.conj(H_0.transpose(0, 1))
33+
H_batch.at[i, :, :].set(H)
34+
return H_batch
35+
36+
37+
@pytest.mark.parametrize("n_qubits, batch_size", [(2, 1), (4, 2)])
38+
def test_hamevo_general(n_qubits: int, batch_size: int) -> None:
39+
H = Hamiltonian_general(n_qubits, batch_size)
40+
t_evo = np.random.uniform(0, 1, (batch_size, 1))
41+
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)]))
42+
psi = random_state(n_qubits)
43+
psi_star = jit(vmap(apply_gate, in_axes=(None, None, {"hamiltonian": 0, "time_evolution": 0})))(
44+
psi, hamevo, {"hamiltonian": H, "time_evolution": t_evo}
45+
)
46+
assert jnp.all(vmap(is_normalized, in_axes=(0,))(psi_star))

0 commit comments

Comments
 (0)