Skip to content

Commit d54c02d

Browse files
Remake (#9)
Remake of horqrux in the style of pyqtorch, i.e., geared towards Qadence.
1 parent dbb49cb commit d54c02d

13 files changed

+800
-603
lines changed

docs/index.md

+127-47
Original file line numberDiff line numberDiff line change
@@ -18,92 +18,172 @@ pip install horqrux
1818
Let's have a look at primitive gates first.
1919

2020
```python exec="on" source="material-block"
21-
from horqrux.gates import X
22-
from horqrux.utils import prepare_state
23-
from horqrux.ops import apply_gate
21+
from horqrux import X, random_state, apply_gate
2422

25-
state = prepare_state(2)
23+
state = random_state(2)
2624
new_state = apply_gate(state, X(0))
2725
```
2826

2927
We can also make any gate controlled, in the case of X, we have to pass the target qubit first!
3028

3129
```python exec="on" source="material-block"
3230
import jax.numpy as jnp
33-
from horqrux.gates import X
34-
from horqrux.utils import prepare_state, equivalent_state
35-
from horqrux.ops import apply_gate
31+
from horqrux import X, product_state, equivalent_state, apply_gate
3632

3733
n_qubits = 2
38-
state = prepare_state(n_qubits, '11')
34+
state = product_state('11')
3935
control = 0
4036
target = 1
4137
# This is equivalent to performing CNOT(0,1)
4238
new_state= apply_gate(state, X(target,control))
43-
assert jnp.allclose(new_state, prepare_state(n_qubits, '10'))
39+
assert jnp.allclose(new_state, product_state('10'))
4440
```
4541

46-
When applying parametric gates, we pass the numeric value for the parameter first
42+
When applying parametric gates, we can either pass a numeric value or a parameter name for the parameter as the first argument.
4743

4844
```python exec="on" source="material-block"
4945
import jax.numpy as jnp
50-
from horqrux.gates import Rx
51-
from horqrux.utils import prepare_state
52-
from horqrux.ops import apply_gate
46+
from horqrux import RX, random_state, apply_gate
5347

5448
target_qubit = 1
55-
state = prepare_state(target_qubit+1)
49+
state = random_state(target_qubit+1)
5650
param_value = 1 / 4 * jnp.pi
57-
new_state = apply_gate(state, Rx(param_value, target_qubit))
51+
new_state = apply_gate(state, RX(param_value, target_qubit))
52+
# Parametric horqrux gates also accept parameter names in the form of strings.
53+
# Simply pass a dictionary of parameter names and values to the 'apply_gate' function
54+
new_state = apply_gate(state, RX('theta', target_qubit), {'theta': jnp.pi})
5855
```
5956

6057
We can also make any parametric gate controlled simply by passing a control qubit.
6158

6259
```python exec="on" source="material-block"
6360
import jax.numpy as jnp
64-
from horqrux.gates import Rx
65-
from horqrux.utils import prepare_state
66-
from horqrux.ops import apply_gate
61+
from horqrux import RX, product_state, apply_gate
6762

6863
n_qubits = 2
6964
target_qubit = 1
7065
control_qubit = 0
71-
state = prepare_state(2, '11')
66+
state = product_state('11')
7267
param_value = 1 / 4 * jnp.pi
73-
new_state = apply_gate(state, Rx(param_value, target_qubit, control_qubit))
68+
new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))
7469
```
7570

76-
A fully differentiable variational circuit is simply a sequence of gates which are applied to a state.
71+
We can now build a fully differentiable variational circuit by simply defining a sequence of gates
72+
and a set of initial parameter values we want to optimize.
73+
Lets fit a function using a simple circuit class wrapper.
74+
75+
```python exec="on" source="material-block" html="1"
76+
from __future__ import annotations
7777

78-
```python exec="on" source="material-block"
7978
import jax
79+
from jax import grad, jit, Array, value_and_grad, vmap
80+
from dataclasses import dataclass
8081
import jax.numpy as jnp
81-
from horqrux import gates
82-
from horqrux.utils import prepare_state, overlap
83-
from horqrux.ops import apply_gate
82+
import optax
83+
from functools import reduce, partial
84+
from operator import add
85+
from typing import Any, Callable
86+
from uuid import uuid4
87+
88+
from horqrux.abstract import Operator
89+
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap
90+
91+
92+
n_qubits = 5
93+
n_params = 3
94+
n_layers = 3
8495

85-
n_qubits = 2
86-
state = prepare_state(2, '00')
8796
# Lets define a sequence of rotations
88-
ops = [gates.Rx, gates.Ry, gates.Rx]
97+
def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
98+
all_ops = []
99+
param_names = []
100+
rots_fns = [RX ,RY, RX]
101+
for _ in range(n_layers):
102+
for i in range(n_qubits):
103+
ops = [fn(str(uuid4()), qubit) for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])]
104+
param_names += [op.param for op in ops]
105+
ops += [NOT((i+1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
106+
all_ops += ops
107+
108+
return all_ops, param_names
109+
110+
# We need a function to fit and use it to produce training data
111+
fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in range(degree)), 0)
112+
x = jnp.linspace(0, 10, 100)
113+
y = fn(x, 5)
114+
115+
@dataclass
116+
class Circuit:
117+
n_qubits: int
118+
n_layers: int
119+
120+
def __post_init__(self) -> None:
121+
# We will use a featuremap of RX rotations to encode some classical data
122+
self.feature_map: list[Operator] = [RX('phi', i) for i in range(n_qubits)]
123+
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
124+
self.observable: list[Operator] = [Z(0)]
125+
126+
@partial(vmap, in_axes=(None, None, 0))
127+
def forward(self, param_values: Array, x: Array) -> Array:
128+
state = zero_state(self.n_qubits)
129+
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
130+
state = apply_gate(state, self.feature_map + self.ansatz, {**param_dict, **{'phi': x}})
131+
return overlap(state, apply_gate(state, self.observable))
132+
133+
def __call__(self, param_values: Array, x: Array) -> Array:
134+
return self.forward(param_values, x)
135+
136+
@property
137+
def n_vparams(self) -> int:
138+
return len(self.param_names)
139+
140+
circ = Circuit(n_qubits, n_layers)
89141
# Create random initial values for the parameters
90-
key = jax.random.PRNGKey(0)
91-
params = jax.random.uniform(key, shape=(n_qubits * len(ops),))
92-
93-
def circ(state) -> jax.Array:
94-
for qubit in range(n_qubits):
95-
for gate,param in zip(ops, params):
96-
state = apply_gate(state, gate(param, qubit))
97-
state = apply_gate(state,gates.NOT(1, 0))
98-
projection = apply_gate(state, gates.Z(0))
99-
return overlap(state, projection)
100-
101-
# Let's compute both values and gradients for a set of parameters and compile the circuit.
102-
circ = jax.jit(jax.value_and_grad(circ))
103-
# Run it on a state.
104-
expval_and_grads = circ(state)
105-
expval = expval_and_grads[0]
106-
grads = expval_and_grads[1:]
107-
print(f'Expval: {expval};'
108-
f'Grads: {grads}')
142+
key = jax.random.PRNGKey(42)
143+
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
144+
# Check the initial predictions using randomly initialized parameters
145+
y_init = circ(param_vals, x)
146+
147+
optimizer = optax.adam(learning_rate=0.01)
148+
opt_state = optimizer.init(param_vals)
149+
150+
# Define a loss function
151+
def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
152+
y_pred = circ(param_vals, x)
153+
return jnp.mean(optax.l2_loss(y_pred, y))
154+
155+
156+
def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
157+
updates, opt_state = optimizer.update(grads, opt_state)
158+
params = optax.apply_updates(params, updates)
159+
return params, opt_state
160+
161+
@jit
162+
def train_step(i: int, inputs: tuple
163+
) -> tuple:
164+
param_vals, opt_state = inputs
165+
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
166+
param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
167+
return param_vals, opt_state
168+
169+
170+
n_epochs = 200
171+
param_vals, opt_state = jax.lax.fori_loop(0, n_epochs, train_step, (param_vals, opt_state))
172+
y_final = circ(param_vals, x)
173+
174+
# Lets plot the results
175+
import matplotlib.pyplot as plt
176+
plt.plot(x, y, label="truth")
177+
plt.plot(x, y_init, label="initial")
178+
plt.plot(x, y_final, "--", label="final", linewidth=3)
179+
plt.legend()
180+
181+
from io import StringIO # markdown-exec: hide
182+
from matplotlib.figure import Figure # markdown-exec: hide
183+
def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
184+
buffer = StringIO() # markdown-exec: hide
185+
fig.savefig(buffer, format="svg") # markdown-exec: hide
186+
return buffer.getvalue() # markdown-exec: hide
187+
# from docs import docutils # markdown-exec: hide
188+
print(fig_to_html(plt.gcf())) # markdown-exec: hide
109189
```

horqrux/__init__.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from __future__ import annotations
22

3-
from jax import config
4-
5-
config.update("jax_enable_x64", True) # you should really really do this
3+
from .apply import apply_gate, apply_operator
4+
from .parametric import PHASE, RX, RY, RZ
5+
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
6+
from .utils import (
7+
equivalent_state,
8+
hilbert_reshape,
9+
overlap,
10+
product_state,
11+
random_state,
12+
uniform_state,
13+
zero_state,
14+
)

horqrux/abstract.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Iterable, Tuple
5+
6+
import numpy as np
7+
from jax import Array
8+
from jax.tree_util import register_pytree_node_class
9+
10+
from .matrices import OPERATIONS_DICT
11+
from .utils import (
12+
ControlQubits,
13+
QubitSupport,
14+
TargetQubits,
15+
_dagger,
16+
_jacobian,
17+
_unitary,
18+
is_controlled,
19+
none_like,
20+
)
21+
22+
23+
@register_pytree_node_class
24+
@dataclass
25+
class Operator:
26+
"""Abstract class which stores information about generators target and control qubits
27+
of a particular quantum operator."""
28+
29+
generator_name: str
30+
target: QubitSupport
31+
control: QubitSupport
32+
33+
@staticmethod
34+
def parse_idx(
35+
idx: Tuple,
36+
) -> Tuple:
37+
if isinstance(idx, (int, np.int64)):
38+
return ((idx,),)
39+
elif isinstance(idx, tuple):
40+
return (idx,)
41+
else:
42+
return (idx.astype(int),)
43+
44+
def __post_init__(self) -> None:
45+
self.target = Operator.parse_idx(self.target)
46+
if self.control is None:
47+
self.control = none_like(self.target)
48+
else:
49+
self.control = Operator.parse_idx(self.control)
50+
51+
def __iter__(self) -> Iterable:
52+
return iter((self.generator_name, self.target, self.control))
53+
54+
def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]:
55+
children = ()
56+
aux_data = (self.generator_name, self.target, self.control)
57+
return (children, aux_data)
58+
59+
@classmethod
60+
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
61+
return cls(*children, *aux_data)
62+
63+
def unitary(self, values: dict[str, float] = dict()) -> Array:
64+
return OPERATIONS_DICT[self.generator_name]
65+
66+
def dagger(self, values: dict[str, float] = dict()) -> Array:
67+
return _dagger(self.unitary(values))
68+
69+
@property
70+
def name(self) -> str:
71+
return "C" + self.generator_name if is_controlled(self.control) else self.generator_name
72+
73+
def __repr__(self) -> str:
74+
return self.name + f"(target={self.target[0]}, control={self.control[0]})"
75+
76+
77+
Primitive = Operator
78+
79+
80+
@register_pytree_node_class
81+
@dataclass
82+
class Parametric(Primitive):
83+
"""Extension of the Primitive class adding the option to pass a parameter."""
84+
85+
generator_name: str
86+
target: QubitSupport
87+
control: QubitSupport
88+
param: str | float = ""
89+
90+
def __post_init__(self) -> None:
91+
super().__post_init__()
92+
93+
def parse_dict(values: dict[str, float] = dict()) -> float:
94+
return values[self.param] # type: ignore[index]
95+
96+
def parse_val(values: dict[str, float] = dict()) -> float:
97+
return self.param # type: ignore[return-value]
98+
99+
self.parse_values = parse_dict if isinstance(self.param, str) else parse_val
100+
101+
def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override]
102+
children = ()
103+
aux_data = (
104+
self.name,
105+
self.target,
106+
self.control,
107+
self.param,
108+
)
109+
return (children, aux_data)
110+
111+
def unitary(self, values: dict[str, float] = dict()) -> Array:
112+
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
113+
114+
def jacobian(self, values: dict[str, float] = dict()) -> Array:
115+
return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
116+
117+
@property
118+
def name(self) -> str:
119+
base_name = "R" + self.generator_name
120+
return "C" + base_name if is_controlled(self.control) else base_name
121+
122+
def __repr__(self) -> str:
123+
return (
124+
self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})"
125+
)

0 commit comments

Comments
 (0)