@@ -111,7 +111,9 @@ from typing import Any, Callable
 from uuid import uuid4

 from horqrux.adjoint import adjoint_expectation
+from horqrux.circuit import Circuit, hea
 from horqrux.primitive import Primitive
+from horqrux.parametric import Parametric
 from horqrux import Z, RX, RY, NOT, zero_state, apply_gate

@@ -120,47 +122,25 @@ n_params = 3
 n_layers = 3

 # Let's define a sequence of rotations
-def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
-    all_ops = []
-    param_names = []
-    rots_fns = [RX, RY, RX]
-    for _ in range(n_layers):
-        for i in range(n_qubits):
-            ops = [fn(str(uuid4()), qubit) for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])]
-            param_names += [op.param for op in ops]
-            ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
-            all_ops += ops
-
-    return all_ops, param_names

 # We need a function to fit and use it to produce training data
 fn = lambda x, degree: 0.05 * reduce(add, (jnp.cos(i * x) + jnp.sin(i * x) for i in range(degree)), 0)
 x = jnp.linspace(0, 10, 100)
 y = fn(x, 5)

-@dataclass
-class Circuit:
-    n_qubits: int
-    n_layers: int
+class DQC(Circuit):

     def __post_init__(self) -> None:
-        # We will use a featuremap of RX rotations to encode some classical data
-        self.feature_map: list[Primitive] = [RX('phi', i) for i in range(self.n_qubits)]
-        self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
         self.observable: list[Primitive] = [Z(0)]
+        self.state = zero_state(self.n_qubits)

     @partial(vmap, in_axes=(None, None, 0))
     def __call__(self, param_values: Array, x: Array) -> Array:
-        state = zero_state(self.n_qubits)
         param_dict = {name: val for name, val in zip(self.param_names, param_values)}
-        return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
+        return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})

-    @property
-    def n_vparams(self) -> int:
-        return len(self.param_names)
-
-circ = Circuit(n_qubits, n_layers)
+circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers))
 # Create random initial values for the parameters
 key = jax.random.PRNGKey(42)
 param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
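The hand-rolled `ansatz_w_params` is gone: `hea` (presumably short for hardware-efficient ansatz) now builds the same alternating rotation-plus-CNOT layers inside the library, and the `Circuit` base class evidently carries `feature_map`, `ansatz`, `param_names`, and `n_vparams`, since `DQC` uses them without defining them. A minimal inspection sketch, under the assumption that `hea` returns a flat list of gates whose parametrized rotations expose the `.param` attribute the removed helper relied on:

```python
# A hedged sketch, not the library's documented API: verify against
# horqrux.circuit before relying on it.
from horqrux.circuit import hea
from horqrux.parametric import Parametric

ops = hea(2, 1)  # 2 qubits, 1 layer
param_names = [op.param for op in ops if isinstance(op, Parametric)]
print(len(ops), param_names)
```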
@@ -171,7 +151,7 @@ optimizer = optax.adam(learning_rate=0.01)
 opt_state = optimizer.init(param_vals)

 # Define a loss function
-def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
+def loss_fn(param_vals: Array) -> Array:
     y_pred = circ(param_vals, x)
     return jnp.mean(optax.l2_loss(y_pred, y))

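`loss_fn` now closes over the module-level `x` and `y`, so `value_and_grad` differentiates with respect to `param_vals` alone. For reference, `optax.l2_loss` is the squared error scaled by one half:

```python
import jax.numpy as jnp
import optax

# l2_loss(pred, target) == 0.5 * (pred - target) ** 2 elementwise, so the
# objective above is a half mean squared error.
print(optax.l2_loss(jnp.array([2.0]), jnp.array([0.0])))  # [2.]
```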
@@ -185,7 +165,7 @@ def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
 def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
     param_vals, opt_state = paramvals_w_optstate
-    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
+    loss, grads = value_and_grad(loss_fn)(param_vals)
     param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
     return param_vals, opt_state

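`train_step` keeps the loop-index parameter `i` even though it is unused, because the training loop drives it through `jax.lax.fori_loop`, whose body must accept the index plus a carry:

```python
import jax

# jax.lax.fori_loop(lower, upper, body, init) computes body(i, carry) -> carry
# for i in [lower, upper); here the carry is the (param_vals, opt_state) pair.
print(jax.lax.fori_loop(0, 5, lambda i, acc: acc + i, 0))  # 0+1+2+3+4 = 10
```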
@@ -221,7 +201,7 @@ from dataclasses import dataclass
 from functools import reduce
 from itertools import product
 from operator import add
-from uuid import uuid4
+from typing import Callable

 import jax
 import jax.numpy as jnp
@@ -231,75 +211,52 @@ import optax
 from jax import Array, jit, value_and_grad, vmap
 from numpy.random import uniform

+from horqrux.apply import group_by_index
+from horqrux.circuit import Circuit, hea
 from horqrux import NOT, RX, RY, Z, apply_gate, zero_state
 from horqrux.primitive import Primitive
+from horqrux.parametric import Parametric
 from horqrux.utils import inner

 LEARNING_RATE = 0.01
 N_QUBITS = 4
 DEPTH = 3
 VARIABLES = ("x", "y")
-X_POS = 0
-Y_POS = 1
-N_POINTS = 150
+NUM_VARIABLES = len(VARIABLES)
+X_POS, Y_POS = [i for i in range(NUM_VARIABLES)]
+BATCH_SIZE = 150
 N_EPOCHS = 1000

+def total_magnetization(n_qubits: int) -> Callable:
+    paulis = [Z(i) for i in range(n_qubits)]

-def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
-    all_ops = []
-    param_names = []
-    rots_fns = [RX, RY, RX]
-    for _ in range(n_layers):
-        for i in range(n_qubits):
-            ops = [
-                fn(str(uuid4()), qubit)
-                for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])
-            ]
-            param_names += [op.param for op in ops]
-            ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
-            all_ops += ops
-
-    return all_ops, param_names
-
-
-@dataclass
-class TotalMagnetization:
-    n_qubits: int
-
-    def __post_init__(self) -> None:
-        self.paulis = [Z(i) for i in range(self.n_qubits)]
-
-    def __call__(self, state: Array, values: dict) -> Array:
-        return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])
-
+    def _total_magnetization(out_state: Array, values: dict[str, Array]) -> Array:
+        projected_state = reduce(
+            add, [apply_gate(out_state, pauli, values) for pauli in paulis]
+        )
+        return inner(out_state, projected_state).real
+    return _total_magnetization
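As a sanity check on what this observable returns: `_total_magnetization` evaluates the expectation value `<psi| sum_i Z_i |psi>` by applying each Z to the output state and taking the inner product with the original state. A self-contained single-qubit illustration in plain jax.numpy (no horqrux involved):

```python
import jax.numpy as jnp

# For |psi> = cos(t)|0> + sin(t)|1>, the expectation <psi|Z|psi> is cos(2t).
t = 0.3
psi = jnp.array([jnp.cos(t), jnp.sin(t)])
z = jnp.diag(jnp.array([1.0, -1.0]))
print(jnp.allclose(jnp.vdot(psi, z @ psi).real, jnp.cos(2 * t)))  # True
```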

-@dataclass
-class Circuit:
-    n_qubits: int
-    n_layers: int
+class DQC(Circuit):

     def __post_init__(self) -> None:
-        self.feature_map: list[Primitive] = [RX("x", i) for i in range(self.n_qubits // 2)] + [
-            RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits)
-        ]
-        self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
-        self.observable = TotalMagnetization(self.n_qubits)
+        self.ansatz = group_by_index(self.ansatz)
+        self.observable = total_magnetization(self.n_qubits)
+        self.state = zero_state(self.n_qubits)

     def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
-        state = zero_state(self.n_qubits)
         param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
         out_state = apply_gate(
-            state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
+            self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
         )
-        projected_state = self.observable(state, param_dict)
-        return jnp.real(inner(out_state, projected_state))
-
-    @property
-    def n_vparams(self) -> int:
-        return len(self.param_names)
+        return self.observable(out_state, {})
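`group_by_index` is new here; judging from its name and its use on the ansatz, it regroups the gate sequence by the qubit indices the gates act on so that `apply_gate` can process them more efficiently. That reading is an assumption; a hedged usage sketch (consult horqrux.apply for the authoritative contract):

```python
# Hedged sketch: group_by_index is assumed to take and return a sequence of
# gates, regrouping them by target qubit index without changing the circuit's
# semantics.
from horqrux.apply import group_by_index
from horqrux.circuit import hea

grouped = group_by_index(hea(4, 2))  # hardware-efficient ansatz, 4 qubits, depth 2
```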

-circ = Circuit(N_QUBITS, DEPTH)
+fm = [RX("x", i) for i in range(N_QUBITS // 2)] + [
+    RX("y", i) for i in range(N_QUBITS // 2, N_QUBITS)
+]
+ansatz = hea(N_QUBITS, DEPTH)
+circ = DQC(N_QUBITS, fm, ansatz)
 # Create random initial values for the parameters
 key = jax.random.PRNGKey(42)
 param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
@@ -308,25 +265,20 @@ optimizer = optax.adam(learning_rate=0.01)
 opt_state = optimizer.init(param_vals)


-def exp_fn(param_vals: Array, x: Array, y: Array) -> Array:
-    return circ(param_vals, x, y)
-
-
-def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
-    def pde_loss(x: float, y: float) -> Array:
-        l_b, r_b, t_b, b_b = list(
-            map(
-                lambda xy: exp_fn(param_vals, *xy),
-                [
-                    [jnp.zeros((1, 1)), y],  # u(0,y)=0
-                    [jnp.ones((1, 1)), y],  # u(L,y)=0
-                    [x, jnp.ones((1, 1))],  # u(x,H)=0
-                    [x, jnp.zeros((1, 1))],  # u(x,0)=f(x)
-                ],
-            )
+def loss_fn(param_vals: Array) -> Array:
+    def pde_loss(x: Array, y: Array) -> Array:
+        x = x.reshape(-1, 1)
+        y = y.reshape(-1, 1)
+        left = (jnp.zeros_like(y), y)  # u(0,y)=0
+        right = (jnp.ones_like(y), y)  # u(L,y)=0
+        top = (x, jnp.ones_like(x))  # u(x,H)=0
+        bottom = (x, jnp.zeros_like(x))  # u(x,0)=f(x)
+        terms = jnp.dstack(list(map(jnp.hstack, [left, right, top, bottom])))
+        loss_left, loss_right, loss_top, loss_bottom = vmap(lambda xy: circ(param_vals, xy[:, 0], xy[:, 1]), in_axes=(2,))(
+            terms
         )
-        b_b -= jnp.sin(jnp.pi * x)
-        hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))(
+        loss_bottom -= jnp.sin(jnp.pi * x)
+        hessian = jax.hessian(lambda xy: circ(param_vals, xy[0], xy[1]))(
             jnp.concatenate(
                 [
                     x.reshape(
@@ -338,10 +290,19 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
                 ]
             )
         )
-        interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS]  # uxx+uyy=0
-        return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior])))
+        loss_interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS]  # uxx+uyy=0
+        return jnp.sum(
+            jnp.concatenate(
+                list(
+                    map(
+                        lambda term: jnp.power(term, 2).reshape(-1, 1),
+                        [loss_left, loss_right, loss_top, loss_bottom, loss_interior],
+                    )
+                )
+            )
+        )

-    return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y))
+    return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(*uniform(0, 1.0, (NUM_VARIABLES, BATCH_SIZE))))


 def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple:
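The interior term penalizes u_xx + u_yy, so the model is fit to Laplace's equation alongside the Dirichlet boundary terms. The closed-form solution used for comparison at the end, u(x, y) = exp(-pi*x) * sin(pi*y), is indeed harmonic, which the same jax.hessian machinery confirms:

```python
import jax
import jax.numpy as jnp

# u_xx = pi^2 * u and u_yy = -pi^2 * u, so u_xx + u_yy = 0 everywhere.
u = lambda xy: jnp.exp(-jnp.pi * xy[0]) * jnp.sin(jnp.pi * xy[1])
H = jax.hessian(u)(jnp.array([0.3, 0.7]))
print(jnp.allclose(H[0, 0] + H[1, 1], 0.0, atol=1e-4))  # True
```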
@@ -350,32 +311,25 @@ def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple:
     return param_vals, opt_state


-# collocation points sampling and training
-def sample_points(n_in: int, n_p: int) -> Array:
-    return uniform(0, 1.0, (n_in, n_p))
-
-
 @jit
 def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
     param_vals, opt_state = paramvals_w_optstate
-    x, y = sample_points(2, N_POINTS)
-    loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
+    loss, grads = value_and_grad(loss_fn)(param_vals)
     return optimize_step(param_vals, opt_state, grads)


 param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
 # compare the solution to known ground truth
-single_domain = jnp.linspace(0, 1, num=N_POINTS)
+single_domain = jnp.linspace(0, 1, num=BATCH_SIZE)
 domain = jnp.array(list(product(single_domain, single_domain)))
 # analytical solution
 analytic_sol = (
-    (np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(N_POINTS, N_POINTS).T
+    (np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(BATCH_SIZE, BATCH_SIZE).T
 )
 # DQC solution
-
-dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
-    N_POINTS, N_POINTS
-)
+dqc_sol = vmap(lambda domain: circ(param_vals, domain[0], domain[1]), in_axes=(0,))(
+    domain
+).reshape(BATCH_SIZE, BATCH_SIZE)
 # plot results
 fig, ax = plt.subplots(1, 2, figsize=(7, 7))
 ax[0].imshow(analytic_sol, cmap="turbo")