Skip to content

Commit dbb49cb

Browse files
[Docs] Add circuit example to tutorial (#4)
1 parent ab99161 commit dbb49cb

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# horqrux
22

33
**horqrux** is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning.
4-
It acts as the a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
4+
It acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
55

66
## Installation
77

docs/index.md

+35
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,38 @@ state = prepare_state(2, '11')
7272
param_value = 1 / 4 * jnp.pi
7373
new_state = apply_gate(state, Rx(param_value, target_qubit, control_qubit))
7474
```
75+
76+
A fully differentiable variational circuit is simply a sequence of gates which are applied to a state.
77+
78+
```python exec="on" source="material-block"
79+
import jax
80+
import jax.numpy as jnp
81+
from horqrux import gates
82+
from horqrux.utils import prepare_state, overlap
83+
from horqrux.ops import apply_gate
84+
85+
n_qubits = 2
86+
state = prepare_state(2, '00')
87+
# Lets define a sequence of rotations
88+
ops = [gates.Rx, gates.Ry, gates.Rx]
89+
# Create random initial values for the parameters
90+
key = jax.random.PRNGKey(0)
91+
params = jax.random.uniform(key, shape=(n_qubits * len(ops),))
92+
93+
def circ(state) -> jax.Array:
94+
for qubit in range(n_qubits):
95+
for gate,param in zip(ops, params):
96+
state = apply_gate(state, gate(param, qubit))
97+
state = apply_gate(state,gates.NOT(1, 0))
98+
projection = apply_gate(state, gates.Z(0))
99+
return overlap(state, projection)
100+
101+
# Let's compute both values and gradients for a set of parameters and compile the circuit.
102+
circ = jax.jit(jax.value_and_grad(circ))
103+
# Run it on a state.
104+
expval_and_grads = circ(state)
105+
expval = expval_and_grads[0]
106+
grads = expval_and_grads[1:]
107+
print(f'Expval: {expval};'
108+
f'Grads: {grads}')
109+
```

horqrux/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,7 @@ def equivalent_state(state: Array, reference_state: str) -> bool:
6767
n_qubits = state.ndim
6868
ref_state = prepare_state(n_qubits, reference_state)
6969
return jnp.allclose(state, ref_state) # type: ignore[no-any-return]
70+
71+
72+
def overlap(state: Array, projection: Array) -> Array:
73+
return jnp.real(jnp.dot(jnp.conj(state.flatten()), projection.flatten()))

0 commit comments

Comments
 (0)