Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 54 additions & 54 deletions demonstrations_v2/tutorial_shors_algorithm_catalyst/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def shors_algorithm(N):
# compilation of the *entire* algorithm from beginning to end. On the surface, it
# looks to be as simple as the following:

import pennylane as qml
import pennylane as qp


@qml.qjit
@qp.qjit
def shors_algorithm(N):
# Implementation goes here
return p, q
Expand All @@ -188,7 +188,7 @@ def shors_algorithm(N):
# important parts such that the signature can be as minimal as this:


@qml.qjit(autograph=True, static_argnums=(1))
@qp.qjit(autograph=True, static_argnums=(1))
def shors_algorithm(N, n_bits):
# Implementation goes here
return p, q
Expand Down Expand Up @@ -667,7 +667,7 @@ def phase_to_order(phase, max_denominator):
# Below, we have the implementations of the arithmetic circuits derived in the
# previous section.

import pennylane as qml
import pennylane as qp
import catalyst
from catalyst import measure

Expand All @@ -679,10 +679,10 @@ def QFT(wires):
shifts = jnp.array([2 * jnp.pi * 2**-i for i in range(2, len(wires) + 1)])

for i in range(len(wires)):
qml.Hadamard(wires[i])
qp.Hadamard(wires[i])

for j in range(len(shifts) - i):
qml.ControlledPhaseShift(shifts[j], wires=[wires[(i + 1) + j], wires[i]])
qp.ControlledPhaseShift(shifts[j], wires=[wires[(i + 1) + j], wires[i]])


def fourier_adder_phase_shift(a, wires):
Expand All @@ -694,30 +694,30 @@ def fourier_adder_phase_shift(a, wires):

for i in range(len(wires)):
if phases[i] != 0:
qml.PhaseShift(2 * jnp.pi * phases[i], wires=wires[i])
qp.PhaseShift(2 * jnp.pi * phases[i], wires=wires[i])


def doubly_controlled_adder(N, a, control_wires, wires, aux_wire):
"""Sends |c>|x>QFT(|b>)|0> -> |c>|x>QFT(|b + c x a) mod N>)|0>."""
qml.ctrl(fourier_adder_phase_shift, control=control_wires)(a, wires)
qp.ctrl(fourier_adder_phase_shift, control=control_wires)(a, wires)

qml.adjoint(fourier_adder_phase_shift)(N, wires)
qp.adjoint(fourier_adder_phase_shift)(N, wires)

qml.adjoint(QFT)(wires)
qml.CNOT(wires=[wires[0], aux_wire])
qp.adjoint(QFT)(wires)
qp.CNOT(wires=[wires[0], aux_wire])
QFT(wires)

qml.ctrl(fourier_adder_phase_shift, control=aux_wire)(N, wires)
qp.ctrl(fourier_adder_phase_shift, control=aux_wire)(N, wires)

qml.adjoint(qml.ctrl(fourier_adder_phase_shift, control=control_wires))(a, wires)
qp.adjoint(qp.ctrl(fourier_adder_phase_shift, control=control_wires))(a, wires)

qml.adjoint(QFT)(wires)
qml.PauliX(wires=wires[0])
qml.CNOT(wires=[wires[0], aux_wire])
qml.PauliX(wires=wires[0])
qp.adjoint(QFT)(wires)
qp.PauliX(wires=wires[0])
qp.CNOT(wires=[wires[0], aux_wire])
qp.PauliX(wires=wires[0])
QFT(wires)

qml.ctrl(fourier_adder_phase_shift, control=control_wires)(a, wires)
qp.ctrl(fourier_adder_phase_shift, control=control_wires)(a, wires)


def controlled_ua(N, a, control_wire, target_wires, aux_wires, mult_a_mask, mult_a_inv_mask):
Expand All @@ -735,14 +735,14 @@ def controlled_ua(N, a, control_wire, target_wires, aux_wires, mult_a_mask, mult
N, pow_a, [control_wire, target_wires[n - i - 1]], aux_wires[:-1], aux_wires[-1]
)

qml.adjoint(QFT)(wires=aux_wires[:-1])
qp.adjoint(QFT)(wires=aux_wires[:-1])

# Controlled SWAP the target and aux wires; note that the top-most aux wire
# is only to catch overflow, so we ignore it here.
for i in range(n):
qml.CNOT(wires=[aux_wires[i + 1], target_wires[i]])
qml.Toffoli(wires=[control_wire, target_wires[i], aux_wires[i + 1]])
qml.CNOT(wires=[aux_wires[i + 1], target_wires[i]])
qp.CNOT(wires=[aux_wires[i + 1], target_wires[i]])
qp.Toffoli(wires=[control_wire, target_wires[i], aux_wires[i + 1]])
qp.CNOT(wires=[aux_wires[i + 1], target_wires[i]])

# Adjoint of controlled multiplication with the modular inverse of a
a_mod_inv = modular_inverse(a, N)
Expand All @@ -752,7 +752,7 @@ def controlled_ua(N, a, control_wire, target_wires, aux_wires, mult_a_mask, mult
for i in range(n):
if mult_a_inv_mask[i] > 0:
pow_a_inv = (a_mod_inv * (2 ** (n - i - 1))) % N
qml.adjoint(doubly_controlled_adder)(
qp.adjoint(doubly_controlled_adder)(
N,
pow_a_inv,
[control_wire, target_wires[i]],
Expand All @@ -763,14 +763,14 @@ def controlled_ua(N, a, control_wire, target_wires, aux_wires, mult_a_mask, mult

######################################################################
# Finally, let's put it all together in the core portion of Shor's algorithm,
# under the ``@qml.qjit`` decorator.
# under the ``@qp.qjit`` decorator.

from jax import random


# ``n_bits`` is a static argument because ``jnp.arange`` does not currently
# support dynamically-shaped arrays when jitted.
@qml.qjit(autograph=True, static_argnums=(3))
@qp.qjit(autograph=True, static_argnums=(3))
def shors_algorithm(N, key, a, n_bits, n_trials):
# If no explicit a is passed (denoted by a = 0), randomly choose a
# non-trivial value of a that does not have a common factor with N.
Expand All @@ -783,10 +783,10 @@ def shors_algorithm(N, key, a, n_bits, n_trials):
target_wires = jnp.arange(n_bits) + 1
aux_wires = jnp.arange(n_bits + 2) + n_bits + 1

dev = qml.device("lightning.qubit", wires=2 * n_bits + 3)
dev = qp.device("lightning.qubit", wires=2 * n_bits + 3)

@qml.set_shots(1)
@qml.qnode(dev)
@qp.set_shots(1)
@qp.qnode(dev)
def run_qpe():
meas_results = jnp.zeros((n_bits,), dtype=jnp.int32)
cumulative_phase = jnp.array(0.0)
Expand All @@ -799,22 +799,22 @@ def run_qpe():
a_inv_mask = a_mask

# Initialize the target register of QPE in |1>
qml.PauliX(wires=target_wires[-1])
qp.PauliX(wires=target_wires[-1])

# The "first" QFT on the auxiliary register; required here because
# QFT are largely omitted in the modular arithmetic routines due to
# cancellation between adjacent blocks of the algorithm.
QFT(wires=aux_wires[:-1])

# First iteration: add a - 1 using the Fourier adder.
qml.Hadamard(wires=est_wire)
qp.Hadamard(wires=est_wire)

QFT(wires=target_wires)
qml.ctrl(fourier_adder_phase_shift, control=est_wire)(a - 1, target_wires)
qml.adjoint(QFT)(wires=target_wires)
qp.ctrl(fourier_adder_phase_shift, control=est_wire)(a - 1, target_wires)
qp.adjoint(QFT)(wires=target_wires)

# Measure the estimation wire and store the phase
qml.Hadamard(wires=est_wire)
qp.Hadamard(wires=est_wire)
meas_results[0] = measure(est_wire, reset=True)
cumulative_phase = -2 * jnp.pi * jnp.sum(meas_results / jnp.roll(phase_divisors, 1))

Expand All @@ -838,23 +838,23 @@ def run_qpe():
jnp.unpackbits(next_pow_a.view("uint8"), bitorder="little")[:n_bits]
)

qml.Hadamard(wires=est_wire)
qp.Hadamard(wires=est_wire)

controlled_ua(N, pow_cua, est_wire, target_wires, aux_wires, a_mask, a_inv_mask)

a_mask = a_mask + a_inv_mask
a_inv_mask = jnp.zeros_like(a_inv_mask)

# Rotate the estimation wire by the accumulated phase, then measure and reset it
qml.PhaseShift(cumulative_phase, wires=est_wire)
qml.Hadamard(wires=est_wire)
qp.PhaseShift(cumulative_phase, wires=est_wire)
qp.Hadamard(wires=est_wire)
meas_results[pow_a_idx] = measure(est_wire, reset=True)
cumulative_phase = (
-2 * jnp.pi * jnp.sum(meas_results / jnp.roll(phase_divisors, pow_a_idx + 1))
)

# The adjoint partner of the QFT from the beginning, to reset the auxiliary register
qml.adjoint(QFT)(wires=aux_wires[:-1])
qp.adjoint(QFT)(wires=aux_wires[:-1])

return meas_results

Expand Down Expand Up @@ -965,19 +965,19 @@ def run_qpe():
# circuits for many different :math:`N` using both the QJIT version, and the
# plain PennyLane version below. Note the standard PennyLane version makes use
# of many of the same subroutines and optimizations, but due to limitations on
# how PennyLane handles mid-circuit measurements, we must use ``qml.cond`` and
# explicit ``qml.PhaseShift`` gates.
# how PennyLane handles mid-circuit measurements, we must use ``qp.cond`` and
# explicit ``qp.PhaseShift`` gates.


def shors_algorithm_no_qjit(N, key, a, n_bits, n_trials):
est_wire = 0
target_wires = list(range(1, n_bits + 1))
aux_wires = list(range(n_bits + 1, 2 * n_bits + 3))

dev = qml.device("lightning.qubit", wires=2 * n_bits + 3)
dev = qp.device("lightning.qubit", wires=2 * n_bits + 3)

@qml.set_shots(1)
@qml.qnode(dev)
@qp.set_shots(1)
@qp.qnode(dev)
def run_qpe():
a_mask = jnp.zeros(n_bits, dtype=jnp.int64)
a_mask = a_mask.at[0].set(1) + jnp.array(
Expand All @@ -987,18 +987,18 @@ def run_qpe():

measurements = []

qml.PauliX(wires=target_wires[-1])
qp.PauliX(wires=target_wires[-1])

QFT(wires=aux_wires[:-1])

qml.Hadamard(wires=est_wire)
qp.Hadamard(wires=est_wire)

QFT(wires=target_wires)
qml.ctrl(fourier_adder_phase_shift, control=est_wire)(a - 1, target_wires)
qml.adjoint(QFT)(wires=target_wires)
qp.ctrl(fourier_adder_phase_shift, control=est_wire)(a - 1, target_wires)
qp.adjoint(QFT)(wires=target_wires)

qml.Hadamard(wires=est_wire)
measurements.append(qml.measure(est_wire, reset=True))
qp.Hadamard(wires=est_wire)
measurements.append(qp.measure(est_wire, reset=True))

powers_cua = jnp.array([repeated_squaring(a, 2**p, N) for p in range(n_bits)])

Expand All @@ -1016,7 +1016,7 @@ def run_qpe():
jnp.unpackbits(next_pow_a.view("uint8"), bitorder="little")[:n_bits]
)

qml.Hadamard(wires=est_wire)
qp.Hadamard(wires=est_wire)

controlled_ua(N, pow_cua, est_wire, target_wires, aux_wires, a_mask, a_inv_mask)

Expand All @@ -1025,16 +1025,16 @@ def run_qpe():

# The main difference with the QJIT version
for meas_idx, meas in enumerate(measurements):
qml.cond(meas, qml.PhaseShift)(
qp.cond(meas, qp.PhaseShift)(
-2 * jnp.pi / 2 ** (pow_a_idx + 2 - meas_idx), wires=est_wire
)

qml.Hadamard(wires=est_wire)
measurements.append(qml.measure(est_wire, reset=True))
qp.Hadamard(wires=est_wire)
measurements.append(qp.measure(est_wire, reset=True))

qml.adjoint(QFT)(wires=aux_wires[:-1])
qp.adjoint(QFT)(wires=aux_wires[:-1])

return qml.sample(measurements)
return qp.sample(measurements)

p, q = jnp.array(0, dtype=jnp.int32), jnp.array(0, dtype=jnp.int32)
successful_trials = jnp.array(0, dtype=jnp.int32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"executable_stable": true,
"executable_latest": true,
"dateOfPublication": "2025-04-04T09:00:00+00:00",
"dateOfLastModification": "2025-12-10T00:00:00+00:00",
"dateOfLastModification": "2026-04-17T00:00:00+00:00",
"categories": [
"Algorithms",
"Quantum Computing",
Expand Down
Loading
Loading