Skip to content

Commit 0db46bd

Browse files
[Documentation] Add DQC Tutorial (#12)
1 parent 491cdb3 commit 0db46bd

9 files changed

+382
-173
lines changed

README.md

+18-21
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,43 @@
1-
# horqrux
1+
[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml)
2+
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
3+
[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/)
24

3-
**horqrux** is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning.
4-
It acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
5+
`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
56

67
## Installation
78

8-
`horqrux` (CPU-only) can be installed from PyPI with `pip` as follows:
9+
To install the CPU-only version, simply use `pip`:
910
```bash
1011
pip install horqrux
1112
```
12-
If you want to install the GPU version, simply do:
13+
If you intend to use GPU:
1314

1415
```bash
1516
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
1617
```
1718

18-
[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml)
19-
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
20-
[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/)
19+
## Getting started
20+
`horqrux` adopts a minimalistic and functional interface however the [docs](https://pasqal-io.github.io/horqrux/latest/) provide a comprehensive A-Z guide ranging from how to apply simple primitive and parametric gates, to using [adjoint differentiation](https://arxiv.org/abs/2009.02823) to fit a nonlinear function and implementing [DQC](https://arxiv.org/abs/2011.10395) to solve a partial differential equation.
2121

22+
## Contributing
2223

23-
## Install from source
24+
To learn how to contribute, please visit the [CONTRIBUTING](docs/CONTRIBUTING.md) page.
2425

25-
We recommend to use the [`hatch`](https://hatch.pypa.io/latest/) environment manager to install `horqrux` from source:
26+
When developing within `horqrux`, you can either use the python environment manager [`hatch`](https://hatch.pypa.io/latest/):
2627

2728
```bash
28-
python -m pip install hatch
29+
pip install hatch
2930

30-
# get into a shell with all the dependencies
31-
python -m hatch shell
31+
# enter a shell with containing all the dependencies
32+
hatch shell
3233

3334
# run a command within the virtual environment with all the dependencies
34-
python -m hatch run python my_script.py
35+
hatch run python my_script.py
3536
```
3637

37-
Please note that `hatch` will not combine nicely with other environment managers such Conda. If you want to use Conda, install `horqrux` from source using `pip`:
38+
When using any other environment manager like `venv` or `conda`, simply do:
3839

3940
```bash
40-
# within the Conda environment
41-
python -m pip install -e .
41+
# within the virtual environment
42+
pip install -e .
4243
```
43-
44-
## Contributing
45-
46-
Please refer to [CONTRIBUTING](docs/CONTRIBUTING.md) to learn how to contribute to `horqrux`.

docs/index.md

+220-13
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ choice and install it normally with `pip`:
1111
pip install horqrux
1212
```
1313

14-
## Gates
14+
## Digital operations
1515

1616
`horqrux` implements a large selection of both primitive and parametric single to n-qubit, digital quantum gates.
1717

@@ -68,10 +68,34 @@ param_value = 1 / 4 * jnp.pi
6868
new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))
6969
```
7070

71+
## Analog Operations
72+
73+
`horqrux` also allows for global state evolution via the `HamiltonianEvolution` operation.
74+
Note that it expects a hamiltonian and a time evolution parameter passed as `numpy` or `jax.numpy` arrays. To build arbitrary Pauli hamiltonians, we recommend using [Qadence](https://github.com/pasqal-io/qadence/blob/main/examples/backends/low_level/horqrux_analog.py).
75+
76+
```python exec="on" source="material-block"
77+
from jax.numpy import pi, array, diag, kron, cdouble
78+
from horqrux.analog import HamiltonianEvolution
79+
from horqrux.apply import apply_gate
80+
from horqrux.utils import uniform_state
81+
82+
sigmaz = diag(array([1.0, -1.0], dtype=cdouble))
83+
Hbase = kron(sigmaz, sigmaz)
84+
85+
Hamiltonian = kron(Hbase, Hbase)
86+
n_qubits = 4
87+
t_evo = pi / 4
88+
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)]))
89+
psi = uniform_state(n_qubits)
90+
psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo})
91+
```
92+
93+
## Fitting a nonlinear function using adjoint differentiation
94+
7195
We can now build a fully differentiable variational circuit by simply defining a sequence of gates
7296
and a set of initial parameter values we want to optimize.
73-
Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
74-
which we can use to fit a function using a simple circuit class wrapper.
97+
`horqrux` provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
98+
which we can use to fit a function using a simple `Circuit` class.
7599

76100
```python exec="on" source="material-block" html="1"
77101
from __future__ import annotations
@@ -87,7 +111,7 @@ from typing import Any, Callable
87111
from uuid import uuid4
88112

89113
from horqrux.adjoint import adjoint_expectation
90-
from horqrux.abstract import Primitive
114+
from horqrux.primitive import Primitive
91115
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
92116

93117

@@ -121,18 +145,16 @@ class Circuit:
121145

122146
def __post_init__(self) -> None:
123147
# We will use a featuremap of RX rotations to encode some classical data
124-
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(n_qubits)]
148+
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(self.n_qubits)]
125149
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
126150
self.observable: list[Primitive] = [Z(0)]
127151

128152
@partial(vmap, in_axes=(None, None, 0))
129-
def forward(self, param_values: Array, x: Array) -> Array:
153+
def __call__(self, param_values: Array, x: Array) -> Array:
130154
state = zero_state(self.n_qubits)
131155
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
132156
return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
133157

134-
def __call__(self, param_values: Array, x: Array) -> Array:
135-
return self.forward(param_values, x)
136158

137159
@property
138160
def n_vparams(self) -> int:
@@ -154,15 +176,15 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
154176
return jnp.mean(optax.l2_loss(y_pred, y))
155177

156178

157-
def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
179+
def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
158180
updates, opt_state = optimizer.update(grads, opt_state)
159-
params = optax.apply_updates(params, updates)
160-
return params, opt_state
181+
param_vals = optax.apply_updates(param_vals, updates)
182+
return param_vals, opt_state
161183

162184
@jit
163-
def train_step(i: int, inputs: tuple
185+
def train_step(i: int, paramvals_w_optstate: tuple
164186
) -> tuple:
165-
param_vals, opt_state = inputs
187+
param_vals, opt_state = paramvals_w_optstate
166188
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
167189
param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
168190
return param_vals, opt_state
@@ -188,3 +210,188 @@ def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
188210
# from docs import docutils # markdown-exec: hide
189211
print(fig_to_html(plt.gcf())) # markdown-exec: hide
190212
```
213+
## Fitting a partial differential equation using DQC
214+
215+
Finally, we show how [DQC](https://arxiv.org/abs/2011.10395) can be implemented in `horqrux` and solve a partial differential equation.
216+
217+
```python exec="on" source="material-block" html="1"
218+
from __future__ import annotations
219+
220+
from dataclasses import dataclass
221+
from functools import reduce
222+
from itertools import product
223+
from operator import add
224+
from uuid import uuid4
225+
226+
import jax
227+
import jax.numpy as jnp
228+
import matplotlib.pyplot as plt
229+
import numpy as np
230+
import optax
231+
from jax import Array, jit, value_and_grad, vmap
232+
from numpy.random import uniform
233+
234+
from horqrux import NOT, RX, RY, Z, apply_gate, zero_state
235+
from horqrux.primitive import Primitive
236+
from horqrux.utils import inner
237+
238+
LEARNING_RATE = 0.01
239+
N_QUBITS = 4
240+
DEPTH = 3
241+
VARIABLES = ("x", "y")
242+
X_POS = 0
243+
Y_POS = 1
244+
N_POINTS = 150
245+
N_EPOCHS = 1000
246+
247+
248+
def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
249+
all_ops = []
250+
param_names = []
251+
rots_fns = [RX, RY, RX]
252+
for _ in range(n_layers):
253+
for i in range(n_qubits):
254+
ops = [
255+
fn(str(uuid4()), qubit)
256+
for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])
257+
]
258+
param_names += [op.param for op in ops]
259+
ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
260+
all_ops += ops
261+
262+
return all_ops, param_names
263+
264+
265+
@dataclass
266+
class TotalMagnetization:
267+
n_qubits: int
268+
269+
def __post_init__(self) -> None:
270+
self.paulis = [Z(i) for i in range(self.n_qubits)]
271+
272+
def __call__(self, state: Array, values: dict) -> Array:
273+
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])
274+
275+
276+
@dataclass
277+
class Circuit:
278+
n_qubits: int
279+
n_layers: int
280+
281+
def __post_init__(self) -> None:
282+
self.feature_map: list[Primitive] = [RX("x", i) for i in range(self.n_qubits // 2)] + [
283+
RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits)
284+
]
285+
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
286+
self.observable = TotalMagnetization(self.n_qubits)
287+
288+
def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
289+
state = zero_state(self.n_qubits)
290+
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
291+
out_state = apply_gate(
292+
state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
293+
)
294+
projected_state = self.observable(state, param_dict)
295+
return jnp.real(inner(out_state, projected_state))
296+
297+
@property
298+
def n_vparams(self) -> int:
299+
return len(self.param_names)
300+
301+
302+
circ = Circuit(N_QUBITS, DEPTH)
303+
# Create random initial values for the parameters
304+
key = jax.random.PRNGKey(42)
305+
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
306+
307+
optimizer = optax.adam(learning_rate=0.01)
308+
opt_state = optimizer.init(param_vals)
309+
310+
311+
def exp_fn(param_vals: Array, x: Array, y: Array) -> Array:
312+
return circ(param_vals, x, y)
313+
314+
315+
def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
316+
def pde_loss(x: float, y: float) -> Array:
317+
l_b, r_b, t_b, b_b = list(
318+
map(
319+
lambda xy: exp_fn(param_vals, *xy),
320+
[
321+
[jnp.zeros((1, 1)), y], # u(0,y)=0
322+
[jnp.ones((1, 1)), y], # u(L,y)=0
323+
[x, jnp.ones((1, 1))], # u(x,H)=0
324+
[x, jnp.zeros((1, 1))], # u(x,0)=f(x)
325+
],
326+
)
327+
)
328+
b_b -= jnp.sin(jnp.pi * x)
329+
hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))(
330+
jnp.concatenate(
331+
[
332+
x.reshape(
333+
1,
334+
),
335+
y.reshape(
336+
1,
337+
),
338+
]
339+
)
340+
)
341+
interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS] # uxx+uyy=0
342+
return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior])))
343+
344+
return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y))
345+
346+
347+
def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple:
348+
updates, opt_state = optimizer.update(grads, opt_state, param_vals)
349+
param_vals = optax.apply_updates(param_vals, updates)
350+
return param_vals, opt_state
351+
352+
353+
# collocation points sampling and training
354+
def sample_points(n_in: int, n_p: int) -> Array:
355+
return uniform(0, 1.0, (n_in, n_p))
356+
357+
358+
@jit
359+
def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
360+
param_vals, opt_state = paramvals_w_optstate
361+
x, y = sample_points(2, N_POINTS)
362+
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
363+
return optimize_step(param_vals, opt_state, grads)
364+
365+
366+
param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
367+
# compare the solution to known ground truth
368+
single_domain = jnp.linspace(0, 1, num=N_POINTS)
369+
domain = jnp.array(list(product(single_domain, single_domain)))
370+
# analytical solution
371+
analytic_sol = (
372+
(np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(N_POINTS, N_POINTS).T
373+
)
374+
# DQC solution
375+
376+
dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
377+
N_POINTS, N_POINTS
378+
)
379+
# # plot results
380+
fig, ax = plt.subplots(1, 2, figsize=(7, 7))
381+
ax[0].imshow(analytic_sol, cmap="turbo")
382+
ax[0].set_xlabel("x")
383+
ax[0].set_ylabel("y")
384+
ax[0].set_title("Analytical solution u(x,y)")
385+
ax[1].imshow(dqc_sol, cmap="turbo")
386+
ax[1].set_xlabel("x")
387+
ax[1].set_ylabel("y")
388+
ax[1].set_title("DQC solution u(x,y)")
389+
from io import StringIO # markdown-exec: hide
390+
from matplotlib.figure import Figure # markdown-exec: hide
391+
def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
392+
buffer = StringIO() # markdown-exec: hide
393+
fig.savefig(buffer, format="svg") # markdown-exec: hide
394+
return buffer.getvalue() # markdown-exec: hide
395+
# from docs import docutils # markdown-exec: hide
396+
print(fig_to_html(plt.gcf())) # markdown-exec: hide
397+
```

0 commit comments

Comments
 (0)