@@ -72,3 +72,38 @@ state = prepare_state(2, '11')
72
72
param_value = 1 / 4 * jnp.pi
73
73
new_state = apply_gate(state, Rx(param_value, target_qubit, control_qubit))
74
74
```
75
+
76
+ A fully differentiable variational circuit is simply a sequence of gates which are applied to a state.
77
+
78
+ ``` python exec="on" source="material-block"
79
+ import jax
80
+ import jax.numpy as jnp
81
+ from horqrux import gates
82
+ from horqrux.utils import prepare_state, overlap
83
+ from horqrux.ops import apply_gate
84
+
85
+ n_qubits = 2
86
+ state = prepare_state(2 , ' 00' )
87
+ # Lets define a sequence of rotations
88
+ ops = [gates.Rx, gates.Ry, gates.Rx]
89
+ # Create random initial values for the parameters
90
+ key = jax.random.PRNGKey(0 )
91
+ params = jax.random.uniform(key, shape = (n_qubits * len (ops),))
92
+
93
+ def circ (state ) -> jax.Array:
94
+ for qubit in range (n_qubits):
95
+ for gate,param in zip (ops, params):
96
+ state = apply_gate(state, gate(param, qubit))
97
+ state = apply_gate(state,gates.NOT(1 , 0 ))
98
+ projection = apply_gate(state, gates.Z(0 ))
99
+ return overlap(state, projection)
100
+
101
+ # Let's compute both values and gradients for a set of parameters and compile the circuit.
102
+ circ = jax.jit(jax.value_and_grad(circ))
103
+ # Run it on a state.
104
+ expval_and_grads = circ(state)
105
+ expval = expval_and_grads[0 ]
106
+ grads = expval_and_grads[1 :]
107
+ print (f ' Expval: { expval} ; '
108
+ f ' Grads: { grads} ' )
109
+ ```
0 commit comments