Commit 92490b8

[Feature | Performance] Add circuit module, Merge gates acting on same qubits (#14)
1 parent 0db46bd commit 92490b8
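
Note on the performance half of the title: "merging gates acting on the same qubits" means fusing consecutive gate matrices once, so the (exponentially large) state vector is multiplied a single time instead of once per gate. The diff below exposes this as group_by_index in horqrux.apply. A minimal, self-contained sketch of the idea in plain JAX follows; it illustrates the technique and is not horqrux's implementation:

```python
# Minimal sketch of the gate-merging idea behind this commit, NOT horqrux's
# actual implementation: consecutive gates on the same qubit(s) can be fused
# into one unitary by multiplying their matrices before touching the state.
from functools import reduce
import jax.numpy as jnp

def merge_unitaries(unitaries: list) -> jnp.ndarray:
    # Gates apply left-to-right to the state, so the fused matrix is
    # U_n @ ... @ U_2 @ U_1.
    return reduce(lambda acc, u: u @ acc, unitaries)

def rx(theta):
    # Single-qubit X-rotation matrix.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])

def rz(theta):
    # Single-qubit Z-rotation matrix.
    return jnp.diag(jnp.array([jnp.exp(-1j * theta / 2), jnp.exp(1j * theta / 2)]))

# RX then RZ on the same qubit collapse to a single 2x2 unitary, so the state
# is multiplied once instead of twice.
fused = merge_unitaries([rx(0.3), rz(0.7)])
```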

7 files changed: +229 -121 lines changed

docs/index.md (+65 -111)
@@ -111,7 +111,9 @@ from typing import Any, Callable
 from uuid import uuid4
 
 from horqrux.adjoint import adjoint_expectation
+from horqrux.circuit import Circuit, hea
 from horqrux.primitive import Primitive
+from horqrux.parametric import Parametric
 from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
 
 
@@ -120,47 +122,25 @@ n_params = 3
 n_layers = 3
 
 # Let's define a sequence of rotations
-def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
-    all_ops = []
-    param_names = []
-    rots_fns = [RX, RY, RX]
-    for _ in range(n_layers):
-        for i in range(n_qubits):
-            ops = [fn(str(uuid4()), qubit) for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])]
-            param_names += [op.param for op in ops]
-            ops += [NOT((i+1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
-            all_ops += ops
-
-    return all_ops, param_names
 
 # We need a function to fit and use it to produce training data
 fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in range(degree)), 0)
 x = jnp.linspace(0, 10, 100)
 y = fn(x, 5)
 
-@dataclass
-class Circuit:
-    n_qubits: int
-    n_layers: int
 
+class DQC(Circuit):
     def __post_init__(self) -> None:
-        # We will use a featuremap of RX rotations to encode some classical data
-        self.feature_map: list[Primitive] = [RX('phi', i) for i in range(self.n_qubits)]
-        self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
         self.observable: list[Primitive] = [Z(0)]
+        self.state = zero_state(self.n_qubits)
 
     @partial(vmap, in_axes=(None, None, 0))
     def __call__(self, param_values: Array, x: Array) -> Array:
-        state = zero_state(self.n_qubits)
         param_dict = {name: val for name, val in zip(self.param_names, param_values)}
-        return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
+        return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
 
 
-    @property
-    def n_vparams(self) -> int:
-        return len(self.param_names)
-
-circ = Circuit(n_qubits, n_layers)
+circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers))
 # Create random initial values for the parameters
 key = jax.random.PRNGKey(42)
 param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
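
Note: the hand-rolled ansatz_w_params above is superseded by the new horqrux.circuit module. Judging only from the call sites in this diff (DQC(n_qubits=..., feature_map=..., ansatz=...), circ.n_vparams, and hea(n_qubits, n_layers)), Circuit looks like a dataclass bundling a feature map with an ansatz, and hea builds the same hardware-efficient rotation-plus-CNOT layers the removed helper produced. A hedged usage sketch, with field names inferred from the diff rather than from documented API:

```python
# Sketch inferred from call sites in this diff; the field names and the
# direct constructibility of Circuit are assumptions, not documented API.
from horqrux import RX
from horqrux.circuit import Circuit, hea

n_qubits, n_layers = 4, 3
circ = Circuit(
    n_qubits=n_qubits,
    # Data-encoding feature map: one RX('phi', .) rotation per qubit.
    feature_map=[RX("phi", i) for i in range(n_qubits)],
    # hea() stands in for the removed ansatz_w_params(): a hardware-efficient
    # ansatz with auto-generated parameter names.
    ansatz=hea(n_qubits, n_layers),
)
print(circ.n_vparams)  # number of variational parameters to initialize
```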
@@ -171,7 +151,7 @@ optimizer = optax.adam(learning_rate=0.01)
 opt_state = optimizer.init(param_vals)
 
 # Define a loss function
-def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
+def loss_fn(param_vals: Array) -> Array:
     y_pred = circ(param_vals, x)
     return jnp.mean(optax.l2_loss(y_pred, y))
 
@@ -185,7 +165,7 @@ def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
 def train_step(i: int, paramvals_w_optstate: tuple
 ) -> tuple:
     param_vals, opt_state = paramvals_w_optstate
-    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
+    loss, grads = value_and_grad(loss_fn)(param_vals)
     param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
     return param_vals, opt_state
 
@@ -221,7 +201,7 @@ from dataclasses import dataclass
 from functools import reduce
 from itertools import product
 from operator import add
-from uuid import uuid4
+from typing import Callable
 
 import jax
 import jax.numpy as jnp
@@ -231,75 +211,52 @@ import optax
 from jax import Array, jit, value_and_grad, vmap
 from numpy.random import uniform
 
+from horqrux.apply import group_by_index
+from horqrux.circuit import Circuit, hea
 from horqrux import NOT, RX, RY, Z, apply_gate, zero_state
 from horqrux.primitive import Primitive
+from horqrux.parametric import Parametric
 from horqrux.utils import inner
 
 LEARNING_RATE = 0.01
 N_QUBITS = 4
 DEPTH = 3
 VARIABLES = ("x", "y")
-X_POS = 0
-Y_POS = 1
-N_POINTS = 150
+NUM_VARIABLES = len(VARIABLES)
+X_POS, Y_POS = [i for i in range(NUM_VARIABLES)]
+BATCH_SIZE = 150
 N_EPOCHS = 1000
 
+def total_magnetization(n_qubits: int) -> Callable:
+    paulis = [Z(i) for i in range(n_qubits)]
 
-def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
-    all_ops = []
-    param_names = []
-    rots_fns = [RX, RY, RX]
-    for _ in range(n_layers):
-        for i in range(n_qubits):
-            ops = [
-                fn(str(uuid4()), qubit)
-                for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])
-            ]
-            param_names += [op.param for op in ops]
-            ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
-            all_ops += ops
-
-    return all_ops, param_names
-
-
-@dataclass
-class TotalMagnetization:
-    n_qubits: int
-
-    def __post_init__(self) -> None:
-        self.paulis = [Z(i) for i in range(self.n_qubits)]
-
-    def __call__(self, state: Array, values: dict) -> Array:
-        return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])
-
+    def _total_magnetization(out_state: Array, values: dict[str, Array]) -> Array:
+        projected_state = reduce(
+            add, [apply_gate(out_state, pauli, values) for pauli in paulis]
+        )
+        return inner(out_state, projected_state).real
+    return _total_magnetization
 
-@dataclass
-class Circuit:
-    n_qubits: int
-    n_layers: int
 
+class DQC(Circuit):
     def __post_init__(self) -> None:
-        self.feature_map: list[Primitive] = [RX("x", i) for i in range(self.n_qubits // 2)] + [
-            RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits)
-        ]
-        self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
-        self.observable = TotalMagnetization(self.n_qubits)
+        self.ansatz = group_by_index(self.ansatz)
+        self.observable = total_magnetization(self.n_qubits)
+        self.state = zero_state(self.n_qubits)
 
     def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
-        state = zero_state(self.n_qubits)
         param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
         out_state = apply_gate(
-            state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
+            self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
         )
-        projected_state = self.observable(state, param_dict)
-        return jnp.real(inner(out_state, projected_state))
-
-    @property
-    def n_vparams(self) -> int:
-        return len(self.param_names)
+        return self.observable(out_state, {})
 
 
-circ = Circuit(N_QUBITS, DEPTH)
+fm = [RX("x", i) for i in range(N_QUBITS // 2)] + [
+    RX("y", i) for i in range(N_QUBITS // 2, N_QUBITS)
+]
+ansatz = hea(N_QUBITS, DEPTH)
+circ = DQC(N_QUBITS, fm, ansatz)
 # Create random initial values for the parameters
 key = jax.random.PRNGKey(42)
 param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
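
Note: the TotalMagnetization dataclass has become a closure. total_magnetization(n) captures the list of Z gates and returns a function computing the real expectation <psi| sum_i Z_i |psi> via apply_gate and inner, mirroring expectation() in horqrux/adjoint.py below. A quick sanity check of that contract; this sketch assumes only the horqrux calls already imported in this diff:

```python
from horqrux import zero_state

n_qubits = 4
magnetization = total_magnetization(n_qubits)

# |0...0> is a +1 eigenstate of every Z_i, so the total magnetization
# <sum_i Z_i> should equal the qubit count.
state = zero_state(n_qubits)
assert magnetization(state, {}).item() == n_qubits
```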
@@ -308,25 +265,20 @@ optimizer = optax.adam(learning_rate=0.01)
 opt_state = optimizer.init(param_vals)
 
 
-def exp_fn(param_vals: Array, x: Array, y: Array) -> Array:
-    return circ(param_vals, x, y)
-
-
-def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
-    def pde_loss(x: float, y: float) -> Array:
-        l_b, r_b, t_b, b_b = list(
-            map(
-                lambda xy: exp_fn(param_vals, *xy),
-                [
-                    [jnp.zeros((1, 1)), y],  # u(0,y)=0
-                    [jnp.ones((1, 1)), y],  # u(L,y)=0
-                    [x, jnp.ones((1, 1))],  # u(x,H)=0
-                    [x, jnp.zeros((1, 1))],  # u(x,0)=f(x)
-                ],
-            )
+def loss_fn(param_vals: Array) -> Array:
+    def pde_loss(x: Array, y: Array) -> Array:
+        x = x.reshape(-1, 1)
+        y = y.reshape(-1, 1)
+        left = (jnp.zeros_like(y), y)  # u(0,y)=0
+        right = (jnp.ones_like(y), y)  # u(L,y)=0
+        top = (x, jnp.ones_like(x))  # u(x,H)=0
+        bottom = (x, jnp.zeros_like(x))  # u(x,0)=f(x)
+        terms = jnp.dstack(list(map(jnp.hstack, [left, right, top, bottom])))
+        loss_left, loss_right, loss_top, loss_bottom = vmap(lambda xy: circ(param_vals, xy[:, 0], xy[:, 1]), in_axes=(2,))(
+            terms
         )
-        b_b -= jnp.sin(jnp.pi * x)
-        hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))(
+        loss_bottom -= jnp.sin(jnp.pi * x)
+        hessian = jax.hessian(lambda xy: circ(param_vals, xy[0], xy[1]))(
             jnp.concatenate(
                 [
                     x.reshape(
@@ -338,10 +290,19 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
                 ]
             )
         )
-        interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS]  # uxx+uyy=0
-        return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior])))
+        loss_interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS]  # uxx+uyy=0
+        return jnp.sum(
+            jnp.concatenate(
+                list(
+                    map(
+                        lambda term: jnp.power(term, 2).reshape(-1, 1),
+                        [loss_left, loss_right, loss_top, loss_bottom, loss_interior],
+                    )
+                )
+            )
+        )
 
-    return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y))
+    return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(*uniform(0, 1.0, (NUM_VARIABLES, BATCH_SIZE))))
 
 
 def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple:
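
Note: the rewritten pde_loss still encodes the same Dirichlet problem for the 2D Laplace equation on the unit square, with the hessian term penalizing the interior residual u_xx + u_yy = 0 and the four stacked terms pinning the boundary. The analytic benchmark used further down, u(x, y) = exp(-pi*x) * sin(pi*y), vanishes on y = 0 and y = 1, which suggests the x/y labels in the boundary comments are swapped; the benchmark is in any case harmonic:

```latex
% The benchmark solution is harmonic: its second derivatives cancel.
u(x, y) = e^{-\pi x}\,\sin(\pi y), \qquad
\partial_{xx} u = \pi^{2} u, \quad
\partial_{yy} u = -\pi^{2} u
\;\Longrightarrow\;
\partial_{xx} u + \partial_{yy} u = 0 .
```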
@@ -350,32 +311,25 @@ def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array])
     return param_vals, opt_state
 
 
-# collocation points sampling and training
-def sample_points(n_in: int, n_p: int) -> Array:
-    return uniform(0, 1.0, (n_in, n_p))
-
-
 @jit
 def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
     param_vals, opt_state = paramvals_w_optstate
-    x, y = sample_points(2, N_POINTS)
-    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
+    loss, grads = value_and_grad(loss_fn)(param_vals)
     return optimize_step(param_vals, opt_state, grads)
 
 
 param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
 # compare the solution to known ground truth
-single_domain = jnp.linspace(0, 1, num=N_POINTS)
+single_domain = jnp.linspace(0, 1, num=BATCH_SIZE)
 domain = jnp.array(list(product(single_domain, single_domain)))
 # analytical solution
 analytic_sol = (
-    (np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(N_POINTS, N_POINTS).T
+    (np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(BATCH_SIZE, BATCH_SIZE).T
 )
 # DQC solution
-
-dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
-    N_POINTS, N_POINTS
-)
+dqc_sol = vmap(lambda domain: circ(param_vals, domain[0], domain[1]), in_axes=(0,))(
+    domain
+).reshape(BATCH_SIZE, BATCH_SIZE)
 # plot results
 fig, ax = plt.subplots(1, 2, figsize=(7, 7))
 ax[0].imshow(analytic_sol, cmap="turbo")

horqrux/adjoint.py (+12 -4)
@@ -3,7 +3,6 @@
 from typing import Tuple
 
 from jax import Array, custom_vjp
-from jax.numpy import real as jnpreal
 
 from horqrux.apply import apply_gate
 from horqrux.parametric import Parametric
@@ -14,9 +13,13 @@
 def expectation(
     state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
 ) -> Array:
+    """
+    Run 'state' through a sequence of 'gates' given parameters 'values'
+    and compute the expectation given an observable.
+    """
     out_state = apply_gate(state, gates, values, OperationType.UNITARY)
     projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
-    return jnpreal(inner(out_state, projected_state))
+    return inner(out_state, projected_state).real
 
 
 @custom_vjp
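
Note: @custom_vjp is what lets the hand-written adjoint rule below replace JAX's default autodiff for adjoint_expectation. For readers unfamiliar with the pattern, here is a generic, self-contained sketch of how a _fwd/_bwd pair is registered (standard JAX API, not horqrux's code):

```python
from jax import custom_vjp

@custom_vjp
def f(x):
    return x ** 2

def f_fwd(x):
    # Return the primal output plus residuals saved for the backward pass.
    return x ** 2, (x,)

def f_bwd(res, tangent):
    # Consume the residuals and the incoming cotangent; return one cotangent
    # per primal argument of f.
    (x,) = res
    return (2 * x * tangent,)

f.defvjp(f_fwd, f_bwd)
```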
@@ -31,19 +34,24 @@ def adjoint_expectation_fwd(
 ) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
     out_state = apply_gate(state, gates, values, OperationType.UNITARY)
     projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
-    return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values)
+    return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)
 
 
 def adjoint_expectation_bwd(
     res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
 ) -> tuple:
+    """Implementation of Algorithm 1 of https://arxiv.org/abs/2009.02823,
+    which computes the vector-jacobian product in O(P) time using O(1) state vectors,
+    where P is the number of parameters in the circuit.
+    """
+
     out_state, projected_state, gates, values = res
     grads = {}
     for gate in gates[::-1]:
         out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
         if isinstance(gate, Parametric):
             mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
-            grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state))
+            grads[gate.param] = tangent * 2 * inner(mu, projected_state).real
         projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
     return (None, None, None, grads)
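
Note: the factor 2 * inner(mu, projected_state).real in the backward pass is the chain rule applied to the standard identity for the derivative of a real expectation value of a Hermitian observable M; schematically, mu plays the role of the differentiated state at the current gate and projected_state carries M|psi> propagated backwards:

```latex
\partial_\theta \langle \psi(\theta) | M | \psi(\theta) \rangle
= \langle \partial_\theta \psi | M | \psi \rangle
+ \langle \psi | M | \partial_\theta \psi \rangle
= 2\,\operatorname{Re}\,\langle \partial_\theta \psi | M \psi \rangle .
```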