@@ -18,92 +18,172 @@ pip install horqrux
18
18
Let's have a look at primitive gates first.
19
19
20
20
``` python exec="on" source="material-block"
21
- from horqrux.gates import X
22
- from horqrux.utils import prepare_state
23
- from horqrux.ops import apply_gate
21
+ from horqrux import X, random_state, apply_gate
24
22
25
- state = prepare_state (2 )
23
+ state = random_state (2 )
26
24
new_state = apply_gate(state, X(0 ))
27
25
```
28
26
29
27
We can also make any gate controlled, in the case of X, we have to pass the target qubit first!
30
28
31
29
``` python exec="on" source="material-block"
32
30
import jax.numpy as jnp
33
- from horqrux.gates import X
34
- from horqrux.utils import prepare_state, equivalent_state
35
- from horqrux.ops import apply_gate
31
+ from horqrux import X, product_state, equivalent_state, apply_gate
36
32
37
33
n_qubits = 2
38
- state = prepare_state(n_qubits, ' 11' )
34
+ state = product_state( ' 11' )
39
35
control = 0
40
36
target = 1
41
37
# This is equivalent to performing CNOT(0,1)
42
38
new_state= apply_gate(state, X(target,control))
43
- assert jnp.allclose(new_state, prepare_state(n_qubits, ' 10' ))
39
+ assert jnp.allclose(new_state, product_state( ' 10' ))
44
40
```
45
41
46
- When applying parametric gates, we pass the numeric value for the parameter first
42
+ When applying parametric gates, we can either pass a numeric value or a parameter name for the parameter as the first argument.
47
43
48
44
``` python exec="on" source="material-block"
49
45
import jax.numpy as jnp
50
- from horqrux.gates import Rx
51
- from horqrux.utils import prepare_state
52
- from horqrux.ops import apply_gate
46
+ from horqrux import RX , random_state, apply_gate
53
47
54
48
target_qubit = 1
55
- state = prepare_state (target_qubit+ 1 )
49
+ state = random_state (target_qubit+ 1 )
56
50
param_value = 1 / 4 * jnp.pi
57
- new_state = apply_gate(state, Rx(param_value, target_qubit))
51
+ new_state = apply_gate(state, RX(param_value, target_qubit))
52
+ # Parametric horqrux gates also accept parameter names in the form of strings.
53
+ # Simply pass a dictionary of parameter names and values to the 'apply_gate' function
54
+ new_state = apply_gate(state, RX(' theta' , target_qubit), {' theta' : jnp.pi})
58
55
```
59
56
60
57
We can also make any parametric gate controlled simply by passing a control qubit.
61
58
62
59
``` python exec="on" source="material-block"
63
60
import jax.numpy as jnp
64
- from horqrux.gates import Rx
65
- from horqrux.utils import prepare_state
66
- from horqrux.ops import apply_gate
61
+ from horqrux import RX , product_state, apply_gate
67
62
68
63
n_qubits = 2
69
64
target_qubit = 1
70
65
control_qubit = 0
71
- state = prepare_state( 2 , ' 11' )
66
+ state = product_state( ' 11' )
72
67
param_value = 1 / 4 * jnp.pi
73
- new_state = apply_gate(state, Rx (param_value, target_qubit, control_qubit))
68
+ new_state = apply_gate(state, RX (param_value, target_qubit, control_qubit))
74
69
```
75
70
76
- A fully differentiable variational circuit is simply a sequence of gates which are applied to a state.
71
+ We can now build a fully differentiable variational circuit by simply defining a sequence of gates
72
+ and a set of initial parameter values we want to optimize.
73
+ Lets fit a function using a simple circuit class wrapper.
74
+
75
+ ``` python exec="on" source="material-block" html="1"
76
+ from __future__ import annotations
77
77
78
- ``` python exec="on" source="material-block"
79
78
import jax
79
+ from jax import grad, jit, Array, value_and_grad, vmap
80
+ from dataclasses import dataclass
80
81
import jax.numpy as jnp
81
- from horqrux import gates
82
- from horqrux.utils import prepare_state, overlap
83
- from horqrux.ops import apply_gate
82
+ import optax
83
+ from functools import reduce , partial
84
+ from operator import add
85
+ from typing import Any, Callable
86
+ from uuid import uuid4
87
+
88
+ from horqrux.abstract import Operator
89
+ from horqrux import Z, RX , RY , NOT , zero_state, apply_gate, overlap
90
+
91
+
92
+ n_qubits = 5
93
+ n_params = 3
94
+ n_layers = 3
84
95
85
- n_qubits = 2
86
- state = prepare_state(2 , ' 00' )
87
96
# Lets define a sequence of rotations
88
- ops = [gates.Rx, gates.Ry, gates.Rx]
97
+ def ansatz_w_params (n_qubits : int , n_layers : int ) -> tuple[list , list ]:
98
+ all_ops = []
99
+ param_names = []
100
+ rots_fns = [RX ,RY , RX ]
101
+ for _ in range (n_layers):
102
+ for i in range (n_qubits):
103
+ ops = [fn(str (uuid4()), qubit) for fn, qubit in zip (rots_fns, [i for _ in range (len (rots_fns))])]
104
+ param_names += [op.param for op in ops]
105
+ ops += [NOT((i+ 1 ) % n_qubits, i % n_qubits) for i in range (n_qubits)]
106
+ all_ops += ops
107
+
108
+ return all_ops, param_names
109
+
110
+ # We need a function to fit and use it to produce training data
111
+ fn = lambda x , degree : .05 * reduce (add, (jnp.cos(i* x) + jnp.sin(i* x) for i in range (degree)), 0 )
112
+ x = jnp.linspace(0 , 10 , 100 )
113
+ y = fn(x, 5 )
114
+
115
+ @dataclass
116
+ class Circuit :
117
+ n_qubits: int
118
+ n_layers: int
119
+
120
+ def __post_init__ (self ) -> None :
121
+ # We will use a featuremap of RX rotations to encode some classical data
122
+ self .feature_map: list[Operator] = [RX(' phi' , i) for i in range (n_qubits)]
123
+ self .ansatz, self .param_names = ansatz_w_params(self .n_qubits, self .n_layers)
124
+ self .observable: list[Operator] = [Z(0 )]
125
+
126
+ @partial (vmap, in_axes = (None , None , 0 ))
127
+ def forward (self , param_values : Array, x : Array) -> Array:
128
+ state = zero_state(self .n_qubits)
129
+ param_dict = {name: val for name, val in zip (self .param_names, param_values)}
130
+ state = apply_gate(state, self .feature_map + self .ansatz, {** param_dict, ** {' phi' : x}})
131
+ return overlap(state, apply_gate(state, self .observable))
132
+
133
+ def __call__ (self , param_values : Array, x : Array) -> Array:
134
+ return self .forward(param_values, x)
135
+
136
+ @ property
137
+ def n_vparams (self ) -> int :
138
+ return len (self .param_names)
139
+
140
+ circ = Circuit(n_qubits, n_layers)
89
141
# Create random initial values for the parameters
90
- key = jax.random.PRNGKey(0 )
91
- params = jax.random.uniform(key, shape = (n_qubits * len (ops),))
92
-
93
- def circ (state ) -> jax.Array:
94
- for qubit in range (n_qubits):
95
- for gate,param in zip (ops, params):
96
- state = apply_gate(state, gate(param, qubit))
97
- state = apply_gate(state,gates.NOT(1 , 0 ))
98
- projection = apply_gate(state, gates.Z(0 ))
99
- return overlap(state, projection)
100
-
101
- # Let's compute both values and gradients for a set of parameters and compile the circuit.
102
- circ = jax.jit(jax.value_and_grad(circ))
103
- # Run it on a state.
104
- expval_and_grads = circ(state)
105
- expval = expval_and_grads[0 ]
106
- grads = expval_and_grads[1 :]
107
- print (f ' Expval: { expval} ; '
108
- f ' Grads: { grads} ' )
142
+ key = jax.random.PRNGKey(42 )
143
+ param_vals = jax.random.uniform(key, shape = (circ.n_vparams,))
144
+ # Check the initial predictions using randomly initialized parameters
145
+ y_init = circ(param_vals, x)
146
+
147
+ optimizer = optax.adam(learning_rate = 0.01 )
148
+ opt_state = optimizer.init(param_vals)
149
+
150
+ # Define a loss function
151
+ def loss_fn (param_vals : Array, x : Array, y : Array) -> Array:
152
+ y_pred = circ(param_vals, x)
153
+ return jnp.mean(optax.l2_loss(y_pred, y))
154
+
155
+
156
+ def optimize_step (params : dict[str , Array], opt_state : Array, grads : dict[str , Array]) -> tuple :
157
+ updates, opt_state = optimizer.update(grads, opt_state)
158
+ params = optax.apply_updates(params, updates)
159
+ return params, opt_state
160
+
161
+ @jit
162
+ def train_step (i : int , inputs : tuple
163
+ ) -> tuple :
164
+ param_vals, opt_state = inputs
165
+ loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
166
+ param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
167
+ return param_vals, opt_state
168
+
169
+
170
+ n_epochs = 200
171
+ param_vals, opt_state = jax.lax.fori_loop(0 , n_epochs, train_step, (param_vals, opt_state))
172
+ y_final = circ(param_vals, x)
173
+
174
+ # Lets plot the results
175
+ import matplotlib.pyplot as plt
176
+ plt.plot(x, y, label = " truth" )
177
+ plt.plot(x, y_init, label = " initial" )
178
+ plt.plot(x, y_final, " --" , label = " final" , linewidth = 3 )
179
+ plt.legend()
180
+
181
+ from io import StringIO # markdown-exec: hide
182
+ from matplotlib.figure import Figure # markdown-exec: hide
183
+ def fig_to_html (fig : Figure) -> str : # markdown-exec: hide
184
+ buffer = StringIO() # markdown-exec: hide
185
+ fig.savefig(buffer, format = " svg" ) # markdown-exec: hide
186
+ return buffer.getvalue() # markdown-exec: hide
187
+ # from docs import docutils # markdown-exec: hide
188
+ print (fig_to_html(plt.gcf())) # markdown-exec: hide
109
189
```
0 commit comments