Handle real-valued phases with RY-type rotations in state preparations#9561
Handle real-valued phases with RY-type rotations in state preparations#9561dwierichs wants to merge 6 commits into
RY-type rotations in state preparations#9561Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #9561 +/- ##
==========================================
- Coverage 99.45% 99.45% -0.01%
==========================================
Files 613 613
Lines 67867 67872 +5
==========================================
+ Hits 67496 67500 +4
- Misses 371 372 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| if ( | ||
| type(obj).__name__ != "AbstractArray" | ||
| and not is_abstract(obj) | ||
| and allclose(ar.imag(obj), 0.0) | ||
| ): | ||
| obj = ar.real(obj) | ||
| return not get_dtype_name(obj).startswith("complex") | ||
| if not get_dtype_name(obj).startswith("complex"): | ||
| return True | ||
| if type(obj).__name__ != "AbstractArray": | ||
| imag = ar.imag(obj) | ||
| if not is_abstract(imag) and allclose(imag, 0.0): | ||
| return True | ||
| return False |
There was a problem hiding this comment.
This previously was hitting an annoying edge case if state_vector was a closure variable of a JITted function that is a jax.numpy.ndarray. Calling ar.imag on it would convert it to a tracer, so that we get a BoolTracer back from this function, even though we checked that state_vector itself is not a tracer :D
There was a problem hiding this comment.
Thanks a lot 😓 the previous logic block also reads like something I would never parse correctly within 1 min before I get a supercomputer implanted in my brain
| for i in range(num_iterations): | ||
| shapes.append([int(2 ** (i + 1)), -1]) | ||
| probs_aux = math.reshape(probs, [1, -1]) | ||
|
|
||
| # From Eq. 5 of arXiv:quant-ph/0208112. | ||
| for itx in range(i + 1): | ||
| probs_denominator = math.sum(probs_aux, axis=1) | ||
| probs_aux = math.reshape(probs_aux, shapes[itx]) | ||
| probs_numerator = math.sum(probs_aux, axis=1)[::2] | ||
|
|
||
| # arcos(x) = arctan2(sqrt(1-x^2), x) | ||
| thetas = 2 * math.arctan2( | ||
| math.sqrt(probs_denominator - probs_numerator), | ||
| math.sqrt(probs_numerator), | ||
| ) |
There was a problem hiding this comment.
This processing was basically a duplication of _get_alpha_y
| output = dev.execute(tape[0])[0] | ||
|
|
||
| assert np.allclose(state, output, atol=0.05) | ||
| assert np.allclose(state, output, atol=1e-5) |
There was a problem hiding this comment.
This very loose atol was not needed as far as I can see. I also wouldn't know why it would be needed...
| indices_numerator = (qp.math.arange(1, 2 ** (n - k + 1) + 1, 2) * 2 ** (k - 1))[ | ||
| :, None | ||
| ] + np.arange(2 ** (k - 1))[None] |
There was a problem hiding this comment.
Any chance we could also improve this part? I'll be surprised if anyone could understand the formula we are following immediately
| if k == 1: | ||
| # At the leaf level, use arctan2 with signed amplitudes to correctly encode | ||
| # the sign of real-valued states into the Y rotation angle, avoiding the need | ||
| # for subsequent Z rotations or DiagonalQubitUnitary gates. | ||
| even = qp.math.take(a, indices=qp.math.arange(0, 2**n, 2), axis=-1) | ||
| odd = qp.math.take(a, indices=qp.math.arange(1, 2**n, 2), axis=-1) | ||
| return 2 * qp.math.arctan2(odd, even) | ||
|
|
||
| indices_numerator = (qp.math.arange(1, 2 ** (n - k + 1) + 1, 2) * 2 ** (k - 1))[ | ||
| :, None | ||
| ] + np.arange(2 ** (k - 1))[None] | ||
| numerator = qp.math.take(a, indices=indices_numerator, axis=-1) | ||
| numerator = qp.math.sum(qp.math.abs(numerator) ** 2, axis=-1) | ||
|
|
||
| indices_denominator = (qp.math.arange(2 ** (n - k)) * 2**k)[:, None] + np.arange(2**k)[None] | ||
| denominator = qp.math.take(a, indices=indices_denominator, axis=-1) | ||
| denominator = qp.math.sum(qp.math.abs(denominator) ** 2, axis=-1) |
There was a problem hiding this comment.
| # Reshape a so that each block of size 2**k forms the final axis | |
| shape = qp.math.shape(a)[:-1] + (2 ** (n - k), 2**k) | |
| a_reshaped = qp.math.reshape(a, shape) | |
| if k == 1: # leaf level | |
| even = a_reshaped[..., 0] | |
| odd = a_reshaped[..., 1] | |
| return 2 * qp.math.arctan2(odd, even) | |
| # Precompute absolute squares to avoid doing it twice | |
| abs_sq = qp.math.abs(a_reshaped) ** 2 | |
| # Denominator is the sum over the entire block | |
| denominator = qp.math.sum(abs_sq, axis=-1) | |
| # Numerator is the sum over the second half of that block | |
| numerator = qp.math.sum(abs_sq[..., 2 ** (k - 1) :], axis=-1) |
I wonder if this is doing the same thing. If so, this feels much better
Co-authored-by: Yushao Chen (Jerry) <chenys13@outlook.com>
Co-authored-by: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com>
| if not is_real: | ||
| omega = math.angle(state_vector) | ||
| if math.is_abstract(omega) or math.requires_grad(omega) or not math.allclose(omega, 0): | ||
| qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires) |
There was a problem hiding this comment.
| qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires) | |
| qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires) |
ooooops my bad
Context:
We can arrange real-valued phases, i.e. signs, with
RY-type rotations and don't need phase gates for that.Currently,
MultiplexerStatePreparationandMottonenStatePreparationdo not know that.Description of the Change:
Tell them.
Benefits:
Possible Drawbacks:
Related GitHub Issues: