Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
727b88d
alloc, dealloc and compbasis
paul0403 Apr 8, 2026
cd4fb6b
py format
paul0403 Apr 8, 2026
d37d044
cpp format
paul0403 Apr 8, 2026
57dba08
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 8, 2026
655fb8a
tests on compbasis
paul0403 Apr 8, 2026
f5cc463
namedobs and hermitian; some tests
paul0403 Apr 8, 2026
1715398
more tests
paul0403 Apr 9, 2026
cec77df
docstrings on tests
paul0403 Apr 9, 2026
e6df156
changelog
paul0403 Apr 9, 2026
fbe62d6
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 9, 2026
cd07d03
unused imports
paul0403 Apr 9, 2026
83e53f7
[qref 3.2] Migrate gate-like ops' plxpr conversion to reference seman…
paul0403 Apr 9, 2026
1f0b496
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 10, 2026
88847c1
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 10, 2026
87fdd5c
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 10, 2026
0b156ed
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 10, 2026
6a24da4
add verify no quantum use after free pass to pipeline
paul0403 Apr 13, 2026
dcf8d73
add some tests for custom op
paul0403 Apr 13, 2026
e14e74c
multirz
paul0403 Apr 13, 2026
316f918
pcphase
paul0403 Apr 13, 2026
a346b33
paulirot op
paul0403 Apr 13, 2026
40360b2
gphase
paul0403 Apr 13, 2026
747542e
unitary
paul0403 Apr 13, 2026
0e3ed72
state prep
paul0403 Apr 13, 2026
0361f0f
set basis state
paul0403 Apr 13, 2026
74ab31f
measure op
paul0403 Apr 13, 2026
696bbbb
add tests for all gates inside a dynamic allocation
paul0403 Apr 14, 2026
9becb2a
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 14, 2026
cfb634e
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 14, 2026
7edcd57
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 15, 2026
c092fd6
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 15, 2026
3e75f6c
burn recorder and handler
paul0403 Apr 15, 2026
457c403
distinguish static vs dynamic number of qubits in alloc primitive
paul0403 Apr 15, 2026
9cc00a7
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 15, 2026
ff4e754
dyn alloc kwarg name
paul0403 Apr 15, 2026
d00e261
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 15, 2026
abb9e25
CI?
paul0403 Apr 15, 2026
0749dd5
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 15, 2026
8352228
update temp test script
paul0403 Apr 16, 2026
6f667c2
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 16, 2026
8e46cf8
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 16, 2026
ee24eb1
fix CI
paul0403 Apr 16, 2026
f7c89bb
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 16, 2026
57b9369
Merge branch 'paul0403/qref_frontend_alloc_and_obs' into paul0403/qre…
paul0403 Apr 16, 2026
fabd503
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 16, 2026
f87b762
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 17, 2026
4670842
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 17, 2026
ab6be06
changelog number
paul0403 Apr 17, 2026
e72cc6f
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 20, 2026
2a5a30c
[qref 3.5] Migrate adjoint op's plxpr conversion to reference semantics
paul0403 Apr 21, 2026
2b824d7
.
paul0403 Apr 21, 2026
2345e58
changelog number
paul0403 Apr 21, 2026
5a462a9
basic adjoint test
paul0403 Apr 21, 2026
c256800
test for alloc (outside adjoint)
paul0403 Apr 21, 2026
7e8199b
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 21, 2026
71329fc
Merge branch 'paul0403/qref_frontend_gates' into paul0403/qref_fronte…
paul0403 Apr 21, 2026
173caad
add test for allocate inside an adjoint
paul0403 Apr 21, 2026
2c3bd3f
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 21, 2026
d638fd8
Merge branch 'paul0403/qref_frontend_gates' into paul0403/qref_fronte…
paul0403 Apr 21, 2026
72fe9df
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 22, 2026
8cd8501
update tests to new get op constant folder
paul0403 Apr 22, 2026
ac6c262
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 22, 2026
fa7acd4
Merge branch 'paul0403/qref_frontend_gates' into paul0403/qref_fronte…
paul0403 Apr 22, 2026
1c1d917
update test for new const folder on get op
paul0403 Apr 22, 2026
f39c4b2
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 23, 2026
0e52909
just lower pl adjoint prim
paul0403 Apr 23, 2026
df7d6b7
add test with loop+adjoint
paul0403 Apr 23, 2026
303233b
pylint
paul0403 Apr 23, 2026
53aa090
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 24, 2026
3093c3f
update some adjoint tests in test_from_plxpr.py
paul0403 Apr 24, 2026
98e5ea4
add test for adjoint on paulirot
paul0403 Apr 24, 2026
29a152e
unused import
paul0403 Apr 24, 2026
61d25cf
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 25, 2026
2a65f1f
adjoint test for dynamic qubit allocation is no longer xfail
paul0403 Apr 27, 2026
3a26cf1
Merge branch 'paul0403/qref_frontend_main' into paul0403/qref_fronten…
paul0403 Apr 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@
[(#2672)](https://github.com/PennyLaneAI/catalyst/pull/2672)
[(#2694)](https://github.com/PennyLaneAI/catalyst/pull/2694)
[(#2717)](https://github.com/PennyLaneAI/catalyst/pull/2717)
[(#2720)](https://github.com/PennyLaneAI/catalyst/pull/2720)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self reminder to add ctrl+adjoint tests


* Removed the `condition` operand from `pbc.ppm` (Pauli Product Measurement) operations.
Conditional PPR decompositions in the `decompose-clifford-ppr` pass now emit the
Expand Down
76 changes: 5 additions & 71 deletions frontend/catalyst/from_plxpr/qfunc_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@
import jax.numpy as jnp
import pennylane as qp
from jax._src.sharding_impls import UNSPECIFIED
from jax._src.tree_util import tree_flatten
from jax.extend.core import ClosedJaxpr
from jax.interpreters.partial_eval import DynamicJaxprTracer, convert_constvars_jaxpr
from jax.interpreters.partial_eval import DynamicJaxprTracer
from pennylane.capture import PlxprInterpreter, pause
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim
Expand Down Expand Up @@ -56,7 +53,6 @@
from catalyst.jax_primitives import (
AbstractQbit,
MeasurementPlane,
adjoint_p,
cond_p,
counts_p,
decomprule_p,
Expand Down Expand Up @@ -229,7 +225,10 @@ def _check_measurement_with_dynamic_allocation(self, measurement):
dynamically allocated wires are present in the program.
"""))

if any(is_dynamically_allocated_wire(w) for w in measurement.wires):
if any(
isinstance(w, DynamicJaxprTracer) and isinstance(w.val.aval, QrefQubit)
for w in measurement.wires
):
raise CompileError(textwrap.dedent("""
Terminal measurements cannot take in dynamically allocated wires
since they must be temporary.
Expand Down Expand Up @@ -762,71 +761,6 @@ def handle_ctrl_transform(self, *invals, jaxpr, n_control, control_values, work_
return []


# pylint: disable=unused-argument
@PLxPRToQuantumJaxprInterpreter.register_primitive(plxpr_adjoint_transform_prim)
def handle_adjoint_transform(
self,
*plxpr_invals,
jaxpr,
lazy,
n_consts,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the adjoint primitive"""

if any(is_dynamically_allocated_wire(arg) for arg in plxpr_invals):
raise NotImplementedError(
"Dynamically allocated wires cannot be used in quantum adjoints yet."
)

assert jaxpr is not None
consts = plxpr_invals[:n_consts]
args = plxpr_invals[n_consts:]

# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()
qreg = self.init_qreg.get()

jaxpr = ClosedJaxpr(jaxpr, consts)

def calling_convention(*args_plus_qreg):
# The last arg is the scope argument for the body jaxpr
*args, qreg = args_plus_qreg

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(self)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

retvals = converter(jaxpr, *args)
init_qreg.insert_all_dangling_qubits()
return *retvals, converter.init_qreg.get()

converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*args, qreg)

converted_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
)
new_consts = converted_jaxpr_branch.consts
_, args_tree = tree_flatten((new_consts, args, [qreg]))
# Perform the binding
outvals = adjoint_p.bind(
*new_consts,
*args,
qreg,
jaxpr=converted_closed_jaxpr_branch,
args_tree=args_tree,
)

# We assume the last output value is the returned qreg.
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals


@PLxPRToQuantumJaxprInterpreter.register_primitive(transform_prim)
def _error_on_transform(*args, **kwargs):
raise NotImplementedError("transforms cannot currently be applied inside a QNode.")
Expand Down
35 changes: 35 additions & 0 deletions frontend/catalyst/from_plxpr/qref_jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
of quantum operations, measurements, and observables to reference semantics JAXPR.
"""

from jax._src import source_info_util
from jax._src.lib.mlir import ir
from jax.core import AbstractValue, ShapedArray
from jax.extend.core import Primitive
Expand All @@ -25,6 +26,7 @@
ExtUIOp,
)
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim

# TODO: remove after jax v0.7.2 upgrade
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
Expand Down Expand Up @@ -54,6 +56,7 @@
),
):
from mlir_quantum.dialects.qref import (
AdjointOp,
AllocOp,
ComputationalBasisOp,
CustomOp,
Expand Down Expand Up @@ -549,6 +552,37 @@ def _qref_unitary_lowering(
return ()


#
# PL adjoint primitive
#
# pylint: disable=unused-argument
def _pl_adjoint_lowering(
jax_ctx,
*plxpr_invals,
jaxpr,
lazy,
n_consts,
):
new_jaxpr = jaxpr.replace(constvars=(), invars=jaxpr.constvars + jaxpr.invars)

op = AdjointOp()
adjoint_block = op.regions[0].blocks.append()
with ir.InsertionPoint(adjoint_block):
source_info_util.extend_name_stack("adjoint")
_, _ = mlir.jaxpr_subcomp(
jax_ctx.module_context,
new_jaxpr,
jax_ctx.name_stack.extend("adjoint"),
mlir.TokenSet(),
[],
*plxpr_invals,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

return ()


#
# measure
#
Expand Down Expand Up @@ -685,4 +719,5 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value
(qref_compbasis_p, _qref_compbasis_lowering),
(qref_namedobs_p, _qref_named_obs_lowering),
(qref_hermitian_p, _qref_hermitian_lowering),
(plxpr_adjoint_transform_prim, _pl_adjoint_lowering),
)
2 changes: 1 addition & 1 deletion frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def _trace_classical_phase(
# with an extra computational cost

if any(isinstance(wire, qml.wires.DynamicWire) for wire in quantum_tape.wires):
msg = "qml.allocate() with qjit is only supported with program capture enabled."
msg = "qp.allocate() with qjit is only supported with program capture enabled."
raise CompileError(msg)

# 1. Recompute the original return
Expand Down
40 changes: 40 additions & 0 deletions frontend/test/lit/test_qref/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,43 @@ def test_while_loop_with_dynamic_allocation(i: int):


print(test_while_loop_with_dynamic_allocation.mlir)


# CHECK: func.func public @test_loop_with_adjoint() -> tensor<f64>
@qp.qjit(capture=True, autograph=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=3))
def test_loop_with_adjoint():
"""
Test loops when adjoints are present.
"""
# CHECK: [[one_index:%.+]] = arith.constant 1 : index
# CHECK: [[ten_index:%.+]] = arith.constant 10 : index
# CHECK: [[zero_index:%.+]] = arith.constant 0 : index

# CHECK: [[reg_device:%.+]] = qref.alloc( 3) : !qref.reg<3>

def f(q1: int, q2: int, n: int):
for _ in range(n):
qp.SWAP(wires=[q1, q2])

# CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2>
# CHECK: [[q1_alloc:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit
with qp.allocate(2) as q:

# CHECK: scf.for %arg0 = [[zero_index]] to [[ten_index]] step [[one_index]] {
# CHECK: qref.adjoint {
# CHECK: scf.for %arg1 = [[zero_index]] to %arg0 step [[one_index]] {
# CHECK: [[q0:%.+]] = qref.get [[reg_device]][ 0] : !qref.reg<3> -> !qref.bit
# CHECK: qref.custom "SWAP"() [[q1_alloc]], [[q0]] : !qref.bit, !qref.bit
# CHECK: }
# CHECK: }
# CHECK: }
for n in range(10):
qp.adjoint(f)(q[1], 0, n)

# qref.dealloc [[reg_alloc]] : !qref.reg<2>

return qp.expval(qp.X(0))


print(test_loop_with_adjoint.mlir)
102 changes: 102 additions & 0 deletions frontend/test/lit/test_qref/test_flat_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,105 @@ def test_set_basis_state():


print(test_set_basis_state.mlir)


# CHECK: func.func public @test_adjoint(%arg0: tensor<i64>) -> tensor<f64>
@qp.qjit(capture=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=4))
def test_adjoint(i: int):
"""
Test adjoint
"""
# CHECK-DAG: [[angle:%.+]] = arith.constant 1.000000e-01 : f64
# CHECK-DAG: [[angle_adj:%.+]] = arith.constant -1.000000e-01 : f64

# CHECK: [[reg:%.+]] = qref.alloc( 4) : !qref.reg<4>

# CHECK: qref.adjoint {
# CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit
# CHECK: [[i:%.+]] = tensor.extract %arg0[] : tensor<i64>
# CHECK: [[qi:%.+]] = qref.get [[reg]][[[i]]] : !qref.reg<4>, i64 -> !qref.bit
# CHECK: qref.custom "CNOT"() [[q0]], [[qi]] : !qref.bit, !qref.bit
# CHECK: }
qp.adjoint(qp.CNOT)(wires=[0, i])

# CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit
# CHECK: qref.custom "RX"([[angle_adj]]) [[q0]] : !qref.bit
qp.adjoint(qp.RX(0.1, wires=0))

# CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit
# CHECK: qref.paulirot ["X"]([[angle]]) [[q0]] adj : !qref.bit
qp.adjoint(qp.PauliRot(0.1, ["X"], 0))

def f(wires):
qp.X(wires)

# CHECK: qref.adjoint {
# CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit
# CHECK: qref.custom "PauliX"() [[q0]] : !qref.bit
# CHECK: }
qp.adjoint(f)(0)

# CHECK: qref.adjoint {
# CHECK: [[i:%.+]] = tensor.extract %arg0[] : tensor<i64>
# CHECK: [[qi:%.+]] = qref.get [[reg]][[[i]]] : !qref.reg<4>, i64 -> !qref.bit
# CHECK: qref.custom "PauliX"() [[qi]] : !qref.bit
# CHECK: }
qp.adjoint(f)(i)

return qp.expval(qp.X(0))


print(test_adjoint.mlir)


# CHECK: func.func public @test_adjoint_with_allocation() -> tensor<f64>
@qp.qjit(capture=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=4))
def test_adjoint_with_allocation():
"""
Test adjoint with dynamic qubit allocation
"""
# CHECK-DAG: [[angle:%.+]] = arith.constant 1.000000e-01 : f64

# CHECK-DAG: [[reg_device:%.+]] = qref.alloc( 4) : !qref.reg<4>

def f(wires):
qp.RX(0.1, wires)

# CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2>
# CHECK: [[alloc_q0:%.+]] = qref.get [[reg_alloc]][ 0] : !qref.reg<2> -> !qref.bit
# CHECK: [[alloc_q1:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit
with qp.allocate(2) as q:

# CHECK: qref.adjoint {
# CHECK: qref.custom "PauliX"() [[alloc_q0]] : !qref.bit
# CHECK: }
qp.adjoint(qp.X)(q[0])

# CHECK: qref.adjoint {
# CHECK: qref.custom "RX"([[angle]]) [[alloc_q1]] : !qref.bit
# CHECK: }
qp.adjoint(f)(q[1])
# CHECK: qref.dealloc [[reg_alloc]] : !qref.reg<2>

# CHECK: qref.adjoint {
# CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2>
# CHECK: [[alloc_q0:%.+]] = qref.get [[reg_alloc]][ 0] : !qref.reg<2> -> !qref.bit
# CHECK: [[alloc_q1:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit
# CHECK: qref.custom "PauliX"() [[alloc_q0]] : !qref.bit
# CHECK: [[q0:%.+]] = qref.get [[reg_device]][ 0] : !qref.reg<4> -> !qref.bit
# CHECK: qref.custom "CNOT"() [[q0]], [[alloc_q1]] : !qref.bit, !qref.bit
# CHECK: qref.dealloc [[reg_alloc]] : !qref.reg<2>
# CHECK: }
def g():
with qp.allocate(2) as q:
qp.X(q[0])
qp.CNOT(wires=[0, q[1]])

qp.adjoint(g)()

return qp.expval(qp.X(0))


print(test_adjoint_with_allocation.mlir)
24 changes: 8 additions & 16 deletions frontend/test/pytest/from_plxpr/test_from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import pennylane as qp
import pytest
from pennylane.capture.primitives import for_loop_prim, while_loop_prim
from pennylane.capture.primitives import adjoint_transform_prim, for_loop_prim, while_loop_prim

import catalyst
from catalyst.from_plxpr import from_plxpr
Expand All @@ -28,10 +28,6 @@
qref_get_p,
qref_qinst_p,
)
from catalyst.jax_primitives import (
adjoint_p,
qinsert_p,
)

pytestmark = pytest.mark.usefixtures("disable_capture")

Expand Down Expand Up @@ -441,8 +437,8 @@ def c():
catalyst_xpr = from_plxpr(plxpr)()
qfunc_xpr = catalyst_xpr.eqns[0].params["call_jaxpr"]

assert qfunc_xpr.eqns[-6].primitive == qref_qinst_p
assert qfunc_xpr.eqns[-6].params == {
assert qfunc_xpr.eqns[-5].primitive == qref_qinst_p
assert qfunc_xpr.eqns[-5].params == {
"adjoint": num_adjoints % 2 == 1,
"ctrl_len": 0,
"op": "S",
Expand Down Expand Up @@ -546,14 +542,12 @@ def c(x):
assert qfunc_xpr.eqns[1].primitive == qref_alloc_p
assert qfunc_xpr.eqns[2].primitive == qref_get_p
assert qfunc_xpr.eqns[3].primitive == qref_qinst_p
assert qfunc_xpr.eqns[4].primitive == qinsert_p

eqn = qfunc_xpr.eqns[5]
assert eqn.primitive == adjoint_p
assert eqn.invars[0] == qfunc_xpr.invars[0] # x
assert eqn.invars[1] == qfunc_xpr.eqns[4].outvars[0] # the qreg
assert eqn.outvars[0] == qfunc_xpr.eqns[6].invars[0] # also the qreg
assert len(eqn.outvars) == 1
eqn = qfunc_xpr.eqns[4]
assert eqn.primitive == adjoint_transform_prim
assert eqn.invars[0] == qfunc_xpr.eqns[1].outvars[0] # the qreg, as a closure variable
assert eqn.invars[1] == qfunc_xpr.invars[0] # x
assert len(eqn.outvars) == 0

target_xpr = eqn.params["jaxpr"]
assert target_xpr.eqns[1].primitive == qref_get_p
Expand All @@ -566,8 +560,6 @@ def c(x):
"params_len": 1,
"qubits_len": 2,
}
assert target_xpr.eqns[4].primitive == qinsert_p
assert target_xpr.eqns[5].primitive == qinsert_p

@pytest.mark.parametrize("as_qfunc", (True, False))
def test_dynamic_control_wires(self, as_qfunc):
Expand Down
Loading
Loading