Skip to content

Handle real-valued phases with RY-type rotations in state preparations#9561

Open
dwierichs wants to merge 6 commits into
mainfrom
mottonen-real-phases
Open

Handle real-valued phases with RY-type rotations in state preparations#9561
dwierichs wants to merge 6 commits into
mainfrom
mottonen-real-phases

Conversation

@dwierichs
Copy link
Copy Markdown
Contributor

Context:
We can arrange real-valued phases, i.e. signs, with RY-type rotations and don't need phase gates for that.
Currently, MultiplexerStatePreparation and MottonenStatePreparation do not know that.

Description of the Change:
Tell them.

Benefits:

Possible Drawbacks:

Related GitHub Issues:

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.45%. Comparing base (7ccf8c4) to head (a70d8ba).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #9561      +/-   ##
==========================================
- Coverage   99.45%   99.45%   -0.01%     
==========================================
  Files         613      613              
  Lines       67867    67872       +5     
==========================================
+ Hits        67496    67500       +4     
- Misses        371      372       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines -179 to +185
if (
type(obj).__name__ != "AbstractArray"
and not is_abstract(obj)
and allclose(ar.imag(obj), 0.0)
):
obj = ar.real(obj)
return not get_dtype_name(obj).startswith("complex")
if not get_dtype_name(obj).startswith("complex"):
return True
if type(obj).__name__ != "AbstractArray":
imag = ar.imag(obj)
if not is_abstract(imag) and allclose(imag, 0.0):
return True
return False
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This previously was hitting an annoying edge case if state_vector was a closure variable of a JITted function that is a jax.numpy.ndarray. Calling ar.imag on it would convert it to a tracer, so that we get a BoolTracer back from this function, even though we checked that state_vector itself is not a tracer :D

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot 😓 the previous logic block also reads like something I would never parse correctly within 1 min before I get a supercomputer implanted in my brain

Comment on lines -144 to -158
for i in range(num_iterations):
shapes.append([int(2 ** (i + 1)), -1])
probs_aux = math.reshape(probs, [1, -1])

# From Eq. 5 of arXiv:quant-ph/0208112.
for itx in range(i + 1):
probs_denominator = math.sum(probs_aux, axis=1)
probs_aux = math.reshape(probs_aux, shapes[itx])
probs_numerator = math.sum(probs_aux, axis=1)[::2]

# arcos(x) = arctan2(sqrt(1-x^2), x)
thetas = 2 * math.arctan2(
math.sqrt(probs_denominator - probs_numerator),
math.sqrt(probs_numerator),
)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processing was basically a duplication of _get_alpha_y

output = dev.execute(tape[0])[0]

assert np.allclose(state, output, atol=0.05)
assert np.allclose(state, output, atol=1e-5)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This very loose atol was not needed as far as I can see. I also wouldn't know why it would be needed...

Comment thread pennylane/templates/state_preparations/multiplexer_state_prep.py Outdated
Comment thread pennylane/math/__init__.py
Comment on lines 264 to 266
indices_numerator = (qp.math.arange(1, 2 ** (n - k + 1) + 1, 2) * 2 ** (k - 1))[
:, None
] + np.arange(2 ** (k - 1))[None]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance we could also improve this part? I'll be surprised if anyone could understand the formula we are following immediately

Comment on lines +256 to 272
if k == 1:
# At the leaf level, use arctan2 with signed amplitudes to correctly encode
# the sign of real-valued states into the Y rotation angle, avoiding the need
# for subsequent Z rotations or DiagonalQubitUnitary gates.
even = qp.math.take(a, indices=qp.math.arange(0, 2**n, 2), axis=-1)
odd = qp.math.take(a, indices=qp.math.arange(1, 2**n, 2), axis=-1)
return 2 * qp.math.arctan2(odd, even)

indices_numerator = (qp.math.arange(1, 2 ** (n - k + 1) + 1, 2) * 2 ** (k - 1))[
:, None
] + np.arange(2 ** (k - 1))[None]
numerator = qp.math.take(a, indices=indices_numerator, axis=-1)
numerator = qp.math.sum(qp.math.abs(numerator) ** 2, axis=-1)

indices_denominator = (qp.math.arange(2 ** (n - k)) * 2**k)[:, None] + np.arange(2**k)[None]
denominator = qp.math.take(a, indices=indices_denominator, axis=-1)
denominator = qp.math.sum(qp.math.abs(denominator) ** 2, axis=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Reshape a so that each block of size 2**k forms the final axis
shape = qp.math.shape(a)[:-1] + (2 ** (n - k), 2**k)
a_reshaped = qp.math.reshape(a, shape)
if k == 1: # leaf level
even = a_reshaped[..., 0]
odd = a_reshaped[..., 1]
return 2 * qp.math.arctan2(odd, even)
# Precompute absolute squares to avoid doing it twice
abs_sq = qp.math.abs(a_reshaped) ** 2
# Denominator is the sum over the entire block
denominator = qp.math.sum(abs_sq, axis=-1)
# Numerator is the sum over the second half of that block
numerator = qp.math.sum(abs_sq[..., 2 ** (k - 1) :], axis=-1)

I wonder if this is doing the same thing. If so, this feels much better

dwierichs and others added 2 commits June 3, 2026 16:53
Co-authored-by: Yushao Chen (Jerry) <chenys13@outlook.com>
Co-authored-by: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com>
if not is_real:
omega = math.angle(state_vector)
if math.is_abstract(omega) or math.requires_grad(omega) or not math.allclose(omega, 0):
qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires)
qp.DiagonalQubitUnitary(math.exp(1j * omega), wires=wires)

ooooops my bad

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants