diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 5912d4362f..c6239e5009 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -761,6 +761,7 @@ [(#2672)](https://github.com/PennyLaneAI/catalyst/pull/2672) [(#2694)](https://github.com/PennyLaneAI/catalyst/pull/2694) [(#2717)](https://github.com/PennyLaneAI/catalyst/pull/2717) + [(#2720)](https://github.com/PennyLaneAI/catalyst/pull/2720) * Removed the `condition` operand from `pbc.ppm` (Pauli Product Measurement) operations. Conditional PPR decompositions in the `decompose-clifford-ppr` pass now emit the diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index e0ab12ca1f..c25f4c54b9 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -24,11 +24,8 @@ import jax.numpy as jnp import pennylane as qp from jax._src.sharding_impls import UNSPECIFIED -from jax._src.tree_util import tree_flatten -from jax.extend.core import ClosedJaxpr -from jax.interpreters.partial_eval import DynamicJaxprTracer, convert_constvars_jaxpr +from jax.interpreters.partial_eval import DynamicJaxprTracer from pennylane.capture import PlxprInterpreter, pause -from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim @@ -56,7 +53,6 @@ from catalyst.jax_primitives import ( AbstractQbit, MeasurementPlane, - adjoint_p, cond_p, counts_p, decomprule_p, @@ -229,7 +225,10 @@ def _check_measurement_with_dynamic_allocation(self, measurement): dynamically allocated wires are present in the program. """)) - if any(is_dynamically_allocated_wire(w) for w in measurement.wires): + if any( + isinstance(w, DynamicJaxprTracer) and isinstance(w.val.aval, QrefQubit) + for w in measurement.wires + ): raise CompileError(textwrap.dedent(""" Terminal measurements cannot take in dynamically allocated wires since they must be temporary. @@ -762,71 +761,6 @@ def handle_ctrl_transform(self, *invals, jaxpr, n_control, control_values, work_ return [] -# pylint: disable=unused-argument -@PLxPRToQuantumJaxprInterpreter.register_primitive(plxpr_adjoint_transform_prim) -def handle_adjoint_transform( - self, - *plxpr_invals, - jaxpr, - lazy, - n_consts, -): - """Handle the conversion from plxpr to Catalyst jaxpr for the adjoint primitive""" - - if any(is_dynamically_allocated_wire(arg) for arg in plxpr_invals): - raise NotImplementedError( - "Dynamically allocated wires cannot be used in quantum adjoints yet." - ) - - assert jaxpr is not None - consts = plxpr_invals[:n_consts] - args = plxpr_invals[n_consts:] - - # Add the iteration start and the qreg to the args - self.init_qreg.insert_all_dangling_qubits() - qreg = self.init_qreg.get() - - jaxpr = ClosedJaxpr(jaxpr, consts) - - def calling_convention(*args_plus_qreg): - # The last arg is the scope argument for the body jaxpr - *args, qreg = args_plus_qreg - - # Launch a new interpreter for the body region - # A new interpreter's root qreg value needs a new recorder - converter = copy(self) - converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) - converter.init_qreg = init_qreg - - retvals = converter(jaxpr, *args) - init_qreg.insert_all_dangling_qubits() - return *retvals, converter.init_qreg.get() - - converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*args, qreg) - - converted_closed_jaxpr_branch = ClosedJaxpr( - convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), () - ) - new_consts = converted_jaxpr_branch.consts - _, args_tree = tree_flatten((new_consts, args, [qreg])) - # Perform the binding - outvals = adjoint_p.bind( - *new_consts, - *args, - qreg, - jaxpr=converted_closed_jaxpr_branch, - args_tree=args_tree, - ) - - # We assume the last output value is the returned qreg. - # Update the current qreg and remove it from the output values. - self.init_qreg.set(outvals.pop()) - - # Return only the output values that match the plxpr output values - return outvals - - @PLxPRToQuantumJaxprInterpreter.register_primitive(transform_prim) def _error_on_transform(*args, **kwargs): raise NotImplementedError("transforms cannot currently be applied inside a QNode.") diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 3a83fb4784..f6173d0bf0 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -16,6 +16,7 @@ of quantum operations, measurements, and observables to reference semantics JAXPR. """ +from jax._src import source_info_util from jax._src.lib.mlir import ir from jax.core import AbstractValue, ShapedArray from jax.extend.core import Primitive @@ -25,6 +26,7 @@ ExtUIOp, ) from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp +from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim # TODO: remove after jax v0.7.2 upgrade # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between @@ -54,6 +56,7 @@ ), ): from mlir_quantum.dialects.qref import ( + AdjointOp, AllocOp, ComputationalBasisOp, CustomOp, @@ -549,6 +552,37 @@ def _qref_unitary_lowering( return () +# +# PL adjoint primitive +# +# pylint: disable=unused-argument +def _pl_adjoint_lowering( + jax_ctx, + *plxpr_invals, + jaxpr, + lazy, + n_consts, +): + new_jaxpr = jaxpr.replace(constvars=(), invars=jaxpr.constvars + jaxpr.invars) + + op = AdjointOp() + adjoint_block = op.regions[0].blocks.append() + with ir.InsertionPoint(adjoint_block): + source_info_util.extend_name_stack("adjoint") + _, _ = mlir.jaxpr_subcomp( + jax_ctx.module_context, + new_jaxpr, + jax_ctx.name_stack.extend("adjoint"), + mlir.TokenSet(), + [], + *plxpr_invals, + dim_var_values=jax_ctx.dim_var_values, + const_lowering=jax_ctx.const_lowering, + ) + + return () + + # # measure # @@ -685,4 +719,5 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value (qref_compbasis_p, _qref_compbasis_lowering), (qref_namedobs_p, _qref_named_obs_lowering), (qref_hermitian_p, _qref_hermitian_lowering), + (plxpr_adjoint_transform_prim, _pl_adjoint_lowering), ) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 45d30a7989..b5b05494b3 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1529,7 +1529,7 @@ def _trace_classical_phase( # with an extra computational cost if any(isinstance(wire, qml.wires.DynamicWire) for wire in quantum_tape.wires): - msg = "qml.allocate() with qjit is only supported with program capture enabled." + msg = "qp.allocate() with qjit is only supported with program capture enabled." raise CompileError(msg) # 1. Recompute the original return diff --git a/frontend/test/lit/test_qref/test_control_flow.py b/frontend/test/lit/test_qref/test_control_flow.py index 5f42f21860..c5f7b320fc 100644 --- a/frontend/test/lit/test_qref/test_control_flow.py +++ b/frontend/test/lit/test_qref/test_control_flow.py @@ -387,3 +387,43 @@ def test_while_loop_with_dynamic_allocation(i: int): print(test_while_loop_with_dynamic_allocation.mlir) + + +# CHECK: func.func public @test_loop_with_adjoint() -> tensor +@qp.qjit(capture=True, autograph=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=3)) +def test_loop_with_adjoint(): + """ + Test loops when adjoints are present. + """ + # CHECK: [[one_index:%.+]] = arith.constant 1 : index + # CHECK: [[ten_index:%.+]] = arith.constant 10 : index + # CHECK: [[zero_index:%.+]] = arith.constant 0 : index + + # CHECK: [[reg_device:%.+]] = qref.alloc( 3) : !qref.reg<3> + + def f(q1: int, q2: int, n: int): + for _ in range(n): + qp.SWAP(wires=[q1, q2]) + + # CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2> + # CHECK: [[q1_alloc:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit + with qp.allocate(2) as q: + + # CHECK: scf.for %arg0 = [[zero_index]] to [[ten_index]] step [[one_index]] { + # CHECK: qref.adjoint { + # CHECK: scf.for %arg1 = [[zero_index]] to %arg0 step [[one_index]] { + # CHECK: [[q0:%.+]] = qref.get [[reg_device]][ 0] : !qref.reg<3> -> !qref.bit + # CHECK: qref.custom "SWAP"() [[q1_alloc]], [[q0]] : !qref.bit, !qref.bit + # CHECK: } + # CHECK: } + # CHECK: } + for n in range(10): + qp.adjoint(f)(q[1], 0, n) + + # qref.dealloc [[reg_alloc]] : !qref.reg<2> + + return qp.expval(qp.X(0)) + + +print(test_loop_with_adjoint.mlir) diff --git a/frontend/test/lit/test_qref/test_flat_circuits.py b/frontend/test/lit/test_qref/test_flat_circuits.py index 7f466d75da..0ee9224912 100644 --- a/frontend/test/lit/test_qref/test_flat_circuits.py +++ b/frontend/test/lit/test_qref/test_flat_circuits.py @@ -361,3 +361,105 @@ def test_set_basis_state(): print(test_set_basis_state.mlir) + + +# CHECK: func.func public @test_adjoint(%arg0: tensor) -> tensor +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=4)) +def test_adjoint(i: int): + """ + Test adjoint + """ + # CHECK-DAG: [[angle:%.+]] = arith.constant 1.000000e-01 : f64 + # CHECK-DAG: [[angle_adj:%.+]] = arith.constant -1.000000e-01 : f64 + + # CHECK: [[reg:%.+]] = qref.alloc( 4) : !qref.reg<4> + + # CHECK: qref.adjoint { + # CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit + # CHECK: [[i:%.+]] = tensor.extract %arg0[] : tensor + # CHECK: [[qi:%.+]] = qref.get [[reg]][[[i]]] : !qref.reg<4>, i64 -> !qref.bit + # CHECK: qref.custom "CNOT"() [[q0]], [[qi]] : !qref.bit, !qref.bit + # CHECK: } + qp.adjoint(qp.CNOT)(wires=[0, i]) + + # CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit + # CHECK: qref.custom "RX"([[angle_adj]]) [[q0]] : !qref.bit + qp.adjoint(qp.RX(0.1, wires=0)) + + # CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit + # CHECK: qref.paulirot ["X"]([[angle]]) [[q0]] adj : !qref.bit + qp.adjoint(qp.PauliRot(0.1, ["X"], 0)) + + def f(wires): + qp.X(wires) + + # CHECK: qref.adjoint { + # CHECK: [[q0:%.+]] = qref.get [[reg]][ 0] : !qref.reg<4> -> !qref.bit + # CHECK: qref.custom "PauliX"() [[q0]] : !qref.bit + # CHECK: } + qp.adjoint(f)(0) + + # CHECK: qref.adjoint { + # CHECK: [[i:%.+]] = tensor.extract %arg0[] : tensor + # CHECK: [[qi:%.+]] = qref.get [[reg]][[[i]]] : !qref.reg<4>, i64 -> !qref.bit + # CHECK: qref.custom "PauliX"() [[qi]] : !qref.bit + # CHECK: } + qp.adjoint(f)(i) + + return qp.expval(qp.X(0)) + + +print(test_adjoint.mlir) + + +# CHECK: func.func public @test_adjoint_with_allocation() -> tensor +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=4)) +def test_adjoint_with_allocation(): + """ + Test adjoint with dynamic qubit allocation + """ + # CHECK-DAG: [[angle:%.+]] = arith.constant 1.000000e-01 : f64 + + # CHECK-DAG: [[reg_device:%.+]] = qref.alloc( 4) : !qref.reg<4> + + def f(wires): + qp.RX(0.1, wires) + + # CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2> + # CHECK: [[alloc_q0:%.+]] = qref.get [[reg_alloc]][ 0] : !qref.reg<2> -> !qref.bit + # CHECK: [[alloc_q1:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit + with qp.allocate(2) as q: + + # CHECK: qref.adjoint { + # CHECK: qref.custom "PauliX"() [[alloc_q0]] : !qref.bit + # CHECK: } + qp.adjoint(qp.X)(q[0]) + + # CHECK: qref.adjoint { + # CHECK: qref.custom "RX"([[angle]]) [[alloc_q1]] : !qref.bit + # CHECK: } + qp.adjoint(f)(q[1]) + # CHECK: qref.dealloc [[reg_alloc]] : !qref.reg<2> + + # CHECK: qref.adjoint { + # CHECK: [[reg_alloc:%.+]] = qref.alloc( 2) : !qref.reg<2> + # CHECK: [[alloc_q0:%.+]] = qref.get [[reg_alloc]][ 0] : !qref.reg<2> -> !qref.bit + # CHECK: [[alloc_q1:%.+]] = qref.get [[reg_alloc]][ 1] : !qref.reg<2> -> !qref.bit + # CHECK: qref.custom "PauliX"() [[alloc_q0]] : !qref.bit + # CHECK: [[q0:%.+]] = qref.get [[reg_device]][ 0] : !qref.reg<4> -> !qref.bit + # CHECK: qref.custom "CNOT"() [[q0]], [[alloc_q1]] : !qref.bit, !qref.bit + # CHECK: qref.dealloc [[reg_alloc]] : !qref.reg<2> + # CHECK: } + def g(): + with qp.allocate(2) as q: + qp.X(q[0]) + qp.CNOT(wires=[0, q[1]]) + + qp.adjoint(g)() + + return qp.expval(qp.X(0)) + + +print(test_adjoint_with_allocation.mlir) diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index ff30efd3cd..96ac228ab1 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -19,7 +19,7 @@ import numpy as np import pennylane as qp import pytest -from pennylane.capture.primitives import for_loop_prim, while_loop_prim +from pennylane.capture.primitives import adjoint_transform_prim, for_loop_prim, while_loop_prim import catalyst from catalyst.from_plxpr import from_plxpr @@ -28,10 +28,6 @@ qref_get_p, qref_qinst_p, ) -from catalyst.jax_primitives import ( - adjoint_p, - qinsert_p, -) pytestmark = pytest.mark.usefixtures("disable_capture") @@ -441,8 +437,8 @@ def c(): catalyst_xpr = from_plxpr(plxpr)() qfunc_xpr = catalyst_xpr.eqns[0].params["call_jaxpr"] - assert qfunc_xpr.eqns[-6].primitive == qref_qinst_p - assert qfunc_xpr.eqns[-6].params == { + assert qfunc_xpr.eqns[-5].primitive == qref_qinst_p + assert qfunc_xpr.eqns[-5].params == { "adjoint": num_adjoints % 2 == 1, "ctrl_len": 0, "op": "S", @@ -546,14 +542,12 @@ def c(x): assert qfunc_xpr.eqns[1].primitive == qref_alloc_p assert qfunc_xpr.eqns[2].primitive == qref_get_p assert qfunc_xpr.eqns[3].primitive == qref_qinst_p - assert qfunc_xpr.eqns[4].primitive == qinsert_p - eqn = qfunc_xpr.eqns[5] - assert eqn.primitive == adjoint_p - assert eqn.invars[0] == qfunc_xpr.invars[0] # x - assert eqn.invars[1] == qfunc_xpr.eqns[4].outvars[0] # the qreg - assert eqn.outvars[0] == qfunc_xpr.eqns[6].invars[0] # also the qreg - assert len(eqn.outvars) == 1 + eqn = qfunc_xpr.eqns[4] + assert eqn.primitive == adjoint_transform_prim + assert eqn.invars[0] == qfunc_xpr.eqns[1].outvars[0] # the qreg, as a closure variable + assert eqn.invars[1] == qfunc_xpr.invars[0] # x + assert len(eqn.outvars) == 0 target_xpr = eqn.params["jaxpr"] assert target_xpr.eqns[1].primitive == qref_get_p @@ -566,8 +560,6 @@ def c(x): "params_len": 1, "qubits_len": 2, } - assert target_xpr.eqns[4].primitive == qinsert_p - assert target_xpr.eqns[5].primitive == qinsert_p @pytest.mark.parametrize("as_qfunc", (True, False)) def test_dynamic_control_wires(self, as_qfunc): diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index d9bccf5c41..514e7e89cf 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -534,6 +534,25 @@ def circuit(): assert np.allclose(observed, expected) +@pytest.mark.usefixtures("use_capture") +def test_adjoint(backend): + """ + Test adjoints work. + """ + + @qjit + @qp.qnode(qp.device(backend, wires=2)) + def circuit(): + with qp.allocate(1) as q: + qp.adjoint(qp.X)(q[0]) + qp.CNOT(wires=[q[0], 0]) + return qp.probs(wires=[0, 1]) + + observed = circuit() + expected = [0, 0, 1, 0] + assert np.allclose(observed, expected) + + def test_no_capture(backend): """ Test error message when used without capture. @@ -614,24 +633,5 @@ def circuit(): return qp.probs(q) -@pytest.mark.usefixtures("use_capture") -def test_unsupported_adjoint(backend): - """ - Test that an error is raised when a dynamically allocated wire is passed into a adjoint. - """ - - with pytest.raises( - NotImplementedError, - match="Dynamically allocated wires cannot be used in quantum adjoints yet.", - ): - - @qjit - @qp.qnode(qp.device(backend, wires=2)) - def circuit(): - with qp.allocate(1) as q: - qp.adjoint(qp.X)(q[0]) - return qp.probs(wires=[0, 1]) - - if __name__ == "__main__": pytest.main(["-x", __file__])