Skip to content

Improved decomposition of DiagonalQubitUnitary #7370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 62 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
d6545e8
New decompostion; tests; changelog; dev comments; docs.
dwierichs Mar 31, 2025
dcbcd3f
format, bit_count
dwierichs Mar 31, 2025
98490c6
update
dwierichs Mar 31, 2025
2e9a1e2
Merge branch 'master' into pcphase-decomp
dwierichs Mar 31, 2025
d8ce851
extract controlled phase shift subroutine
dwierichs Apr 2, 2025
3aa0519
docstring
dwierichs Apr 2, 2025
e2b1a89
merge
dwierichs Apr 2, 2025
a0e041a
formatting
dwierichs Apr 2, 2025
3964084
Merge branch 'master' into pcphase-decomp
dwierichs Apr 2, 2025
0ab7fd6
[skip ci] new numbering
dwierichs Apr 2, 2025
ba40c67
Merge branch 'master' into pcphase-decomp
dwierichs Apr 3, 2025
82a079c
tiny
dwierichs Apr 3, 2025
07c4093
numbering is not easy in sphinx RST
dwierichs Apr 3, 2025
8bcec72
first review
dwierichs Apr 3, 2025
8d45023
some doc progress?
dwierichs Apr 3, 2025
f9ab42b
?
dwierichs Apr 3, 2025
f3d7b16
table?
dwierichs Apr 3, 2025
7e1dd75
table!
dwierichs Apr 3, 2025
575291a
outsource algo description to markdown file. replace docstring with e…
dwierichs Apr 7, 2025
282839a
polish
dwierichs Apr 7, 2025
f09eda7
polish md
dwierichs Apr 7, 2025
b7d728b
align test
dwierichs Apr 7, 2025
f33251b
align test
dwierichs Apr 7, 2025
2afde61
polish
dwierichs Apr 7, 2025
4f99267
whitespace
dwierichs Apr 28, 2025
5247589
polish
dwierichs Apr 28, 2025
f8a842e
Apply suggestions from code review
dwierichs Apr 29, 2025
ebd2bd1
link
dwierichs Apr 29, 2025
86d3ad2
Merge branch 'master' into pcphase-decomp
dwierichs Apr 29, 2025
8a74f7f
changelog
dwierichs Apr 29, 2025
e82409d
start
dwierichs May 2, 2025
dfd963b
docstring
dwierichs May 2, 2025
8e67cb9
conclude
dwierichs May 2, 2025
50f1274
merge
dwierichs May 2, 2025
c1685e1
revert unrelated
dwierichs May 2, 2025
ee6e4e1
changelog
dwierichs May 2, 2025
b46eefd
Merge branch 'master' into selectpaulirot-fwh
dwierichs May 2, 2025
98e4cc7
start
dwierichs May 3, 2025
3596a3f
black
dwierichs May 3, 2025
52d9e0d
n=0 case
dwierichs May 4, 2025
d69dca5
changelog
dwierichs May 4, 2025
0f11320
SelectPauliRot batching
dwierichs May 4, 2025
1f4c588
merge
dwierichs May 6, 2025
2a09b58
SelectPauliRot tests
dwierichs May 6, 2025
ec49c5e
gray code speed up
dwierichs May 6, 2025
7a10bef
merge
dwierichs May 6, 2025
d46921c
merge Uniformly controlled rotation PR in
dwierichs May 6, 2025
3c96fc0
test dependency fix
dwierichs May 6, 2025
79369cd
actualy PR contrib
dwierichs May 6, 2025
4089d47
Merge branch 'master' into selectpaulirot-fwh
dwierichs May 6, 2025
f75b4d4
review
dwierichs May 6, 2025
1a96b81
Merge branch 'master' into selectpaulirot-fwh
dwierichs May 7, 2025
6550b0a
Merge branch 'master' into selectpaulirot-fwh
dwierichs May 7, 2025
6b05fde
seed
dwierichs May 7, 2025
79cc670
Apply suggestions from code review
dwierichs May 7, 2025
ce03157
Merge branch 'selectpaulirot-fwh' into better-diag-decomp
dwierichs May 7, 2025
ac96fab
merge
dwierichs May 8, 2025
9990bb1
docstring improvement
dwierichs May 8, 2025
5f43940
Merge branch 'master' into better-diag-decomp
dwierichs May 8, 2025
6d205c7
empty
dwierichs May 8, 2025
b77349a
review
dwierichs May 9, 2025
6a6aab5
merge
dwierichs May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@

<h3>Improvements 🛠</h3>

* The decomposition of `DiagonalQubitUnitary` has been improved to contain fewer gates.
[(#7370)](https://github.com/PennyLaneAI/pennylane/pull/7370)

* PennyLane supports `JAX` version 0.5.3.
[(#6919)](https://github.com/PennyLaneAI/pennylane/pull/6919)

Expand Down
109 changes: 71 additions & 38 deletions pennylane/ops/qubit/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""
# pylint:disable=arguments-differ
import warnings
from itertools import product
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -480,49 +479,83 @@ def compute_decomposition(D: TensorLike, wires: WiresLike) -> list["qml.operatio
Returns:
list[Operator]: decomposition into lower level operations

Implements Theorem 7 of `Shende et al. <https://arxiv.org/abs/quant-ph/0406176>`__.
Decomposing a ``DiagonalQubitUnitary`` on :math:`n` wires (:math:`n>1`) yields a
uniformly-controlled :math:`R_Z` gate, or :class:`~.SelectPauliRot` gate, as well as a
``DiagonalQubitUnitary`` on :math:`n-1` wires. For :math:`n=1` wires, the decomposition
yields a :class:`~.RZ` gate and a :class:`~.GlobalPhase`.
Resolving this recursion relationship, one would obtain :math:`n-1` ``SelectPauliRot``
gates with :math:`n, n-1, \dots, 1` controls each, a single ``RZ`` gate, and
a ``GlobalPhase``.

**Example:**

>>> diag = np.exp(1j * np.array([0.4, 2.1, 0.5, 1.8]))
>>> qml.DiagonalQubitUnitary.compute_decomposition(diag, wires=[0, 1])
[QubitUnitary(array([[0.36235775+0.93203909j, 0. +0.j ],
[0. +0.j , 0.36235775+0.93203909j]]), wires=[0]),
RZ(1.5000000000000002, wires=[1]),
RZ(-0.10000000000000003, wires=[0]),
IsingZZ(0.2, wires=[0, 1])]
[SelectPauliRot(array([1.7, 1.3]), wires=[0, 1]),
DiagonalQubitUnitary(array([0.31532236+0.94898462j, 0.40848744+0.91276394j]), wires=[0])]

"""
n = len(wires)

# Cast the diagonal into a complex dtype so that the logarithm works as expected
D_casted = qml.math.cast(D, "complex128")

phases = qml.math.real(qml.math.log(D_casted) * (-1j))
coeffs = _walsh_hadamard_transform(phases, n).T
global_phase = qml.math.exp(1j * coeffs[0])
# For all other gates, there is a prefactor -1/2 to be compensated.
coeffs = coeffs * (-2.0)

# TODO: Replace the following by a GlobalPhase gate.
ops = [QubitUnitary(qml.math.tensordot(global_phase, qml.math.eye(2), axes=0), wires[0])]
for wire0 in range(n):
# Single PauliZ generators correspond to the coeffs at powers of two
ops.append(qml.RZ(coeffs[1 << wire0], wires[n - 1 - wire0]))
# Double PauliZ generators correspond to the coeffs at the sum of two powers of two
ops.extend(
qml.IsingZZ(
coeffs[(1 << wire0) + (1 << wire1)],
[wires[n - 1 - wire0], wires[n - 1 - wire1]],
)
for wire1 in range(wire0)
)
**Finding the parameters:**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the following should be placed inside a .. details:: section?


# Add all multi RZ gates that are not generated by single or double PauliZ generators
ops.extend(
qml.MultiRZ(c, [wires[k] for k in np.where(term)[0]])
for c, term in zip(coeffs, product((0, 1), repeat=n))
if sum(term) > 2
)
return ops
Theorem 7 referenced above only tells us the structure of the circuit, but not the
parameters for the ``SelectPauliRot`` and ``DiagonalQubitUnitary`` in the decomposition.
In the following, we will only write out the diagonals of all gates.
Consider a ``DiagonalQubitUnitary`` on :math:`n` qubits that we want to decompose:

.. math::

D(\theta) = (\exp(i\theta_0), \exp(i\theta_1), \dots,
\exp(i\theta_{N-2}), \exp(i\theta_{N-1})).

Here, :math:`N=2^n` is the Hilbert space dimension for :math:`n` qubits, which is
the same as the number of parameters in :math:`D`.

A ``SelectPauliRot`` gate using ``RZ`` rotations, or multiplexed ``RZ`` rotation, using the
first :math:`n-1` qubits as controls and the last qubit as target, takes the form

.. math::

UCR_Z(\phi) = (\exp(-\frac{i}{2}\phi_0), \exp(\frac{i}{2}\phi_0), \dots,
\exp(-\frac{i}{2}\phi_{N/2-1}), \exp(\frac{i}{2}\phi_{N/2-1})),

i.e., it moves the phase of neighbouring pairs of computational basis states by
the same amount, but in opposite direction. There are :math:`N/2` parameters
in this gate.
Similarly, a ``DiagonalQubitUnitary`` acting on the first :math:`n-1` qubits only (the
ones that were controls for ``SelectPauliRot``) takes the form

.. math::

D'(\theta') = (\exp(i\theta'_0), \exp(i\theta'_0), \dots,
\exp(i\theta'_{N/2-1}), \exp(i\theta'_{N/2-1})).

That is, :math:`D'` moves the phase of neighbouring pairs of basis states by the same
amount and in the same direction. It, too, has :math:`N/2` parameters.
Now, we see that we can compute the rotation angles,
or phases, :math:`\phi` and :math:`\theta'` quite easily from the original :math:`\theta`:

.. math::

(\exp(i\theta_{2i}), \exp(i\theta_{2i+1})) &=
(\exp(-\frac{i}{2}\phi_i)\exp(i\theta'_i), \exp(\frac{i}{2}\phi_i)\exp(i\theta'_i))\\
\Rightarrow \qquad \theta'_i &=\frac{1}{2}(\theta_{2i}+\theta_{2i+1})\\
\phi_i &=\theta_{2i+1}-\theta_{2i}.

So the phases for the new gates arise simply as difference and average of the odd-indexed
and even-indexed phases.
"""
angles = qml.math.angle(D)
diff = angles[..., 1::2] - angles[..., ::2]
mean = (angles[..., ::2] + angles[..., 1::2]) / 2
if len(wires) == 1:
return [ # Squeeze away non-broadcasting axis (there is just one angle for RZ/GPhase
qml.GlobalPhase(-qml.math.squeeze(mean, axis=-1), wires=wires),
qml.RZ(qml.math.squeeze(diff, axis=-1), wires=wires),
]
return [ # Note that we use the first qubits as control, the reference uses the last qubits
qml.DiagonalQubitUnitary(np.exp(1j * mean), wires=wires[:-1]),
qml.SelectPauliRot(diff, control_wires=wires[:-1], target_wire=wires[-1]),
]

def adjoint(self) -> "DiagonalQubitUnitary":
return DiagonalQubitUnitary(qml.math.conj(self.parameters[0]), wires=self.wires)
Expand Down
97 changes: 42 additions & 55 deletions tests/ops/qubit/test_matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,27 +641,24 @@ def test_decomposition_single_qubit(self):
decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0]).decomposition()

ph = np.exp(3j * np.pi / 4)
for dec in (decomp, decomp2):
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.QubitUnitary(np.eye(2) * ph, 0))
qml.assert_equal(decomp[1], qml.RZ(np.pi / 2, 0))
qml.assert_equal(decomp[0], qml.RZ(np.pi / 2, 0))
qml.assert_equal(decomp[1], qml.GlobalPhase(-3 * np.pi / 4, 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the tests never actually test that the matrix of the original op and the matrix of the decomposed sequence are the same, can you add them?


def test_decomposition_single_qubit_broadcasted(self):
"""Test that a broadcasted single-qubit DiagonalQubitUnitary is decomposed correctly."""
D = np.stack(
[[1j, -1], np.exp(1j * np.array([np.pi / 8, -np.pi / 8])), [1j, -1j], [-1, -1]]
)
angles = np.array([np.pi / 2, -np.pi / 4, -np.pi, 0])
D = np.exp(1j * np.pi * np.array([[1 / 2, 1], [1 / 8, -1 / 8], [1 / 2, -1 / 2], [1, 1]]))

decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0]).decomposition()

ph = [np.exp(3j * np.pi / 4), 1, 1, -1]
angles = np.array([1 / 2, -1 / 4, -1, 0]) * np.pi
global_angles = np.array([3 / 4, 0, 0, 1]) * np.pi
for dec in (decomp, decomp2):
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.QubitUnitary(np.array([np.eye(2) * p for p in ph]), 0))
qml.assert_equal(decomp[1], qml.RZ(angles, 0))
qml.assert_equal(decomp[0], qml.RZ(angles, 0))
qml.assert_equal(decomp[1], qml.GlobalPhase(-global_angles, 0))

def test_decomposition_two_qubits(self):
"""Test that a two-qubit DiagonalQubitUnitary is decomposed correctly."""
Expand All @@ -670,12 +667,13 @@ def test_decomposition_two_qubits(self):
decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0, 1])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0, 1]).decomposition()

angles = np.array([-2, 0.5])
new_D = np.exp(1j * np.array([0, 3 / 4]))

for dec in (decomp, decomp2):
assert len(dec) == 4
qml.assert_equal(decomp[0], qml.QubitUnitary(np.eye(2) * np.exp(0.375j), 0))
qml.assert_equal(decomp[1], qml.RZ(-0.75, 1))
qml.assert_equal(decomp[2], qml.RZ(0.75, 0))
qml.assert_equal(decomp[3], qml.IsingZZ(-1.25, [0, 1]))
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.SelectPauliRot(angles, [0], target_wire=1))
qml.assert_equal(decomp[1], qml.DiagonalQubitUnitary(new_D, wires=[0]))

def test_decomposition_two_qubits_broadcasted(self):
"""Test that a broadcasted two-qubit DiagonalQubitUnitary is decomposed correctly."""
Expand All @@ -684,14 +682,13 @@ def test_decomposition_two_qubits_broadcasted(self):
decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0, 1])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0, 1]).decomposition()

angles = [[-0.75, -0.8, 0.65], [0.75, -2.4, -0.55], [-1.25, 0.4, -1.35]]
ph = [np.exp(1j * 0.375), np.exp(1j * 0.9), np.exp(1j * 0.475)]
angles = np.array([[-2, 0.5], [-0.4, -1.2], [-0.7, 2.0]])
new_D = np.exp(1j * np.array([[0, 3 / 4], [2.1, -0.3], [0.75, 0.2]]))

for dec in (decomp, decomp2):
assert len(dec) == 4
qml.assert_equal(decomp[0], qml.QubitUnitary(np.array([np.eye(2) * p for p in ph]), 0))
qml.assert_equal(decomp[1], qml.RZ(angles[0], 1))
qml.assert_equal(decomp[2], qml.RZ(angles[1], 0))
qml.assert_equal(decomp[3], qml.IsingZZ(angles[2], [0, 1]))
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.SelectPauliRot(angles, [0], target_wire=1))
qml.assert_equal(decomp[1], qml.DiagonalQubitUnitary(new_D, wires=[0]))

def test_decomposition_three_qubits(self):
"""Test that a three-qubit DiagonalQubitUnitary is decomposed correctly."""
Expand All @@ -700,16 +697,12 @@ def test_decomposition_three_qubits(self):
decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0, 1, 2])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0, 1, 2]).decomposition()

angles = np.array([-2, 0.5, -0.1, 1.7])
new_D = np.exp(1j * np.array([0, 3 / 4, 0.15, 1.45]))
for dec in (decomp, decomp2):
assert len(dec) == 8
qml.assert_equal(decomp[0], qml.QubitUnitary(np.eye(2) * np.exp(0.5875j), 0))
qml.assert_equal(decomp[1], qml.RZ(0.025, 2))
qml.assert_equal(decomp[2], qml.RZ(1.025, 1))
qml.assert_equal(decomp[3], qml.IsingZZ(-1.075, [1, 2]))
qml.assert_equal(decomp[4], qml.RZ(0.425, 0))
qml.assert_equal(decomp[5], qml.IsingZZ(-0.775, [0, 2]))
qml.assert_equal(decomp[6], qml.IsingZZ(-0.275, [0, 1]))
qml.assert_equal(decomp[7], qml.MultiRZ(-0.175, [0, 1, 2]))
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.SelectPauliRot(angles, [0, 1], target_wire=2))
qml.assert_equal(decomp[1], qml.DiagonalQubitUnitary(new_D, wires=[0, 1]))

def test_decomposition_three_qubits_broadcasted(self):
"""Test that a broadcasted three-qubit DiagonalQubitUnitary is decomposed correctly."""
Expand All @@ -723,26 +716,12 @@ def test_decomposition_three_qubits_broadcasted(self):
decomp = qml.DiagonalQubitUnitary.compute_decomposition(D, [0, 1, 2])
decomp2 = qml.DiagonalQubitUnitary(D, wires=[0, 1, 2]).decomposition()

angles = [
[0.025, -0.55],
[1.025, -0.75],
[-1.075, -0.75],
[0.425, 0.3],
[-0.775, 0.4],
[-0.275, 0.5],
[-0.175, 0.1],
]
ph = [np.exp(0.5875j), np.exp(0.625j)]
angles = np.array([[-2, 0.5, -0.1, 1.7], [-0.8, 0.5, -1.8, -0.1]])
new_D = np.exp(1j * np.array([[0, 3 / 4, 0.15, 1.45], [0.6, 0.35, 1.4, 0.15]]))
for dec in (decomp, decomp2):
assert len(dec) == 8
qml.assert_equal(decomp[0], qml.QubitUnitary(np.array([np.eye(2) * p for p in ph]), 0))
qml.assert_equal(decomp[1], qml.RZ(angles[0], 2))
qml.assert_equal(decomp[2], qml.RZ(angles[1], 1))
qml.assert_equal(decomp[3], qml.IsingZZ(angles[2], [1, 2]))
qml.assert_equal(decomp[4], qml.RZ(angles[3], 0))
qml.assert_equal(decomp[5], qml.IsingZZ(angles[4], [0, 2]))
qml.assert_equal(decomp[6], qml.IsingZZ(angles[5], [0, 1]))
qml.assert_equal(decomp[7], qml.MultiRZ(angles[6], [0, 1, 2]))
assert len(dec) == 2
qml.assert_equal(decomp[0], qml.SelectPauliRot(angles, [0, 1], target_wire=2))
qml.assert_equal(decomp[1], qml.DiagonalQubitUnitary(new_D, wires=[0, 1]))

@pytest.mark.parametrize("n", [1, 2, 3])
def test_decomposition_matrix_match(self, n, seed):
Expand Down Expand Up @@ -775,7 +754,7 @@ def test_decomposition_matrix_match_broadcasted(self, n, seed):
assert qml.math.allclose(orig_mat, decomp_mat2)

@pytest.mark.parametrize(
"dtype", [np.float64, np.float32, np.int64, np.int32, np.complex128, np.complex64]
"dtype", [np.float64, np.float32, np.int64, np.int32, np.int16, np.complex128, np.complex64]
)
def test_decomposition_cast_to_complex128(self, dtype):
"""Test that the parameters of decomposed operations are of the correct dtype."""
Expand All @@ -784,10 +763,18 @@ def test_decomposition_cast_to_complex128(self, dtype):
decomp1 = qml.DiagonalQubitUnitary(D, wires).decomposition()
decomp2 = qml.DiagonalQubitUnitary.compute_decomposition(D, wires)

assert decomp1[0].data[0].dtype == np.complex128
assert decomp2[0].data[0].dtype == np.complex128
assert all(op.data[0].dtype == np.float64 for op in decomp1[1:])
assert all(op.data[0].dtype == np.float64 for op in decomp2[1:])
r_dtype = (
np.float64 if dtype in [np.float64, np.int64, np.int32, np.complex128] else np.float32
)
c_dtype = (
np.complex128
if dtype in [np.float64, np.int64, np.int32, np.complex128]
else np.complex64
)
assert decomp1[0].data[0].dtype == r_dtype
assert decomp2[0].data[0].dtype == r_dtype
assert decomp1[1].data[0].dtype == c_dtype
assert decomp2[1].data[0].dtype == c_dtype

def test_controlled(self):
"""Test that the correct controlled operation is created when controlling a qml.DiagonalQubitUnitary."""
Expand Down
Loading