Skip to content

Commit 3000919

Browse files
authored
Add support for StatePrep and BasisState in QJIT with PLxPR program capture (#1631)
**Context:** Currently, when you try to execute a circuit containing `qml.StatePrep()` through the QJIT pipeline with PLxPR program capture enabled it results in an error, since the operation is bound to the `qinst_p` primitive rather than the `set_state_p` primitive. `qinst_p` primitives expect input parameters to be of type `float64`, but `StatePrep` can have complex input values, hence the error (see issue #1630 for details). **Description of the Change:** Binds `qml.StatePrep` operations to the `set_state_p` primitive to correctly support StatePrep in QJIT with PLxPR program capture. This PR also adds support for `qml.BasisState()` in QJIT with PLxPR capture since it is closely related to StatePrep. **Related GitHub Issues:** Fixes #1630 [[sc-88638](https://app.shortcut.com/xanaduai/story/88638/make-qml-stateprep-qjit-compatible-with-plxpr-program-capture), [sc-88653](https://app.shortcut.com/xanaduai/story/88653/stateprep-not-supported-in-qjit-with-plxpr-program-capture-enabled)]
1 parent 4006459 commit 3000919

File tree

4 files changed

+185
-6
lines changed

4 files changed

+185
-6
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
<h3>Bug fixes 🐛</h3>
2121

22+
* Catalyst now correctly supports `qml.StatePrep()` and `qml.BasisState()` operations in the
23+
experimental PennyLane program-capture pipeline.
24+
[(#1631)](https://github.com/PennyLaneAI/catalyst/pull/1631)
25+
2226
<h3>Internal changes ⚙️</h3>
2327

2428
<h3>Documentation 📝</h3>

frontend/catalyst/from_plxpr.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import jax
2222
import jax.core
23+
import jax.numpy as jnp
2324
import pennylane as qml
2425
from jax.extend.linear_util import wrap_init
2526
from jax.interpreters.partial_eval import convert_constvars_jaxpr
@@ -65,6 +66,8 @@
6566
quantum_kernel_p,
6667
qunitary_p,
6768
sample_p,
69+
set_basis_state_p,
70+
set_state_p,
6871
state_p,
6972
var_p,
7073
while_p,
@@ -423,6 +426,37 @@ def handle_global_phase(self, phase, *wires, n_wires):
423426
gphase_p.bind(phase, ctrl_len=0, adjoint=False)
424427

425428

429+
@QFuncPlxprInterpreter.register_primitive(qml.BasisState._primitive)
430+
def handle_basis_state(self, *invals, n_wires):
431+
"""Handle the conversion from plxpr to Catalyst jaxpr for the BasisState primitive"""
432+
state_inval = invals[0]
433+
wires_inval = invals[1:]
434+
435+
state = jax.lax.convert_element_type(state_inval, jnp.dtype(jnp.bool))
436+
wires = [self.get_wire(w) for w in wires_inval]
437+
out_wires = set_basis_state_p.bind(*wires, state)
438+
439+
for wire_values, new_wire in zip(wires_inval, out_wires):
440+
self.wire_map[wire_values] = new_wire
441+
442+
443+
# pylint: disable=unused-argument
444+
@QFuncPlxprInterpreter.register_primitive(qml.StatePrep._primitive)
445+
def handle_state_prep(self, *invals, n_wires, **kwargs):
446+
"""Handle the conversion from plxpr to Catalyst jaxpr for the StatePrep primitive"""
447+
state_inval = invals[0]
448+
wires_inval = invals[1:]
449+
450+
# jnp.complex128 is the top element in the type promotion lattice so it is ok to do this:
451+
# https://jax.readthedocs.io/en/latest/type_promotion.html
452+
state = jax.lax.convert_element_type(state_inval, jnp.dtype(jnp.complex128))
453+
wires = [self.get_wire(w) for w in wires_inval]
454+
out_wires = set_state_p.bind(*wires, state)
455+
456+
for wire_values, new_wire in zip(wires_inval, out_wires):
457+
self.wire_map[wire_values] = new_wire
458+
459+
426460
# pylint: disable=unused-argument, too-many-arguments
427461
@QFuncPlxprInterpreter.register_primitive(plxpr_cond_prim)
428462
def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):

frontend/test/pytest/test_capture_integration.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,57 @@ def g(x):
177177

178178
assert jnp.allclose(actual, desired)
179179

180+
@pytest.mark.parametrize(
181+
"n_wires, basis_state",
182+
[
183+
(1, jnp.array([0])),
184+
(1, jnp.array([1])),
185+
(2, jnp.array([0, 0])),
186+
(2, jnp.array([0, 1])),
187+
(2, jnp.array([1, 0])),
188+
(2, jnp.array([1, 1])),
189+
],
190+
)
191+
def test_basis_state(self, backend, n_wires, basis_state):
192+
"""Test the integration for a circuit with BasisState."""
193+
dev = qml.device(backend, wires=n_wires)
194+
195+
@qml.qnode(dev)
196+
def circuit(_basis_state):
197+
qml.BasisState(_basis_state, wires=list(range(n_wires)))
198+
return qml.state()
199+
200+
desired = circuit(basis_state)
201+
actual = qjit(circuit, experimental_capture=True)(basis_state)
202+
203+
assert jnp.allclose(actual, desired)
204+
205+
@pytest.mark.parametrize(
206+
"n_wires, init_state",
207+
[
208+
(1, jnp.array([1.0, 0.0], dtype=jnp.complex128)),
209+
(1, jnp.array([0.0, 1.0], dtype=jnp.complex128)),
210+
(1, jnp.array([1.0, 1.0], dtype=jnp.complex128) / jnp.sqrt(2.0)),
211+
(1, jnp.array([0.0, 1.0], dtype=jnp.float64)),
212+
(1, jnp.array([0, 1], dtype=jnp.int64)),
213+
(2, jnp.array([1.0, 0.0, 0.0, 0.0], dtype=jnp.complex128)),
214+
(2, jnp.array([0.0, 1.0, 0.0, 0.0], dtype=jnp.complex128)),
215+
],
216+
)
217+
def test_state_prep(self, backend, n_wires, init_state):
218+
"""Test the integration for a circuit with StatePrep."""
219+
dev = qml.device(backend, wires=n_wires)
220+
221+
@qml.qnode(dev)
222+
def circuit(init_state):
223+
qml.StatePrep(init_state, wires=list(range(n_wires)))
224+
return qml.state()
225+
226+
desired = circuit(init_state)
227+
actual = qjit(circuit, experimental_capture=True)(init_state)
228+
229+
assert jnp.allclose(actual, desired)
230+
180231
@pytest.mark.xfail(reason="Adjoint not supported.")
181232
@pytest.mark.parametrize("theta, val", [(jnp.pi, 0), (-100.0, 1)])
182233
def test_adjoint(self, backend, theta, val):

frontend/test/pytest/test_from_plxpr.py

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,40 @@ def capture(self, args):
4848
return JAXPRRunner(fn=lambda: None, compile_options=catalyst.CompileOptions())
4949

5050

51-
def compare_call_jaxprs(jaxpr1, jaxpr2, skip_eqns=()):
51+
def compare_call_jaxprs(jaxpr1, jaxpr2, skip_eqns=(), ignore_order=False):
5252
"""Compares two call jaxprs and validates that they are essentially equal."""
5353
for inv1, inv2 in zip(jaxpr1.invars, jaxpr2.invars):
5454
assert inv1.aval == inv2.aval, f"{inv1.aval}, {inv2.aval}"
5555
for ov1, ov2 in zip(jaxpr1.outvars, jaxpr2.outvars):
5656
assert ov1.aval == ov2.aval
57-
assert len(jaxpr1.eqns) == len(jaxpr2.eqns)
58-
59-
for i, (eqn1, eqn2) in enumerate(zip(jaxpr1.eqns, jaxpr2.eqns)):
60-
if i not in skip_eqns:
61-
compare_eqns(eqn1, eqn2)
57+
assert len(jaxpr1.eqns) == len(
58+
jaxpr2.eqns
59+
), f"Number of equations differ: {len(jaxpr1.eqns)} vs {len(jaxpr2.eqns)}"
60+
61+
if not ignore_order:
62+
# Assert that equations in both jaxprs are equivalent and in same order
63+
for i, (eqn1, eqn2) in enumerate(zip(jaxpr1.eqns, jaxpr2.eqns)):
64+
if i not in skip_eqns:
65+
compare_eqns(eqn1, eqn2)
66+
67+
else:
68+
# Assert that equations in both jaxprs are equivalent but in any order
69+
eqns1 = [eqn for i, eqn in enumerate(jaxpr1.eqns) if i not in skip_eqns]
70+
eqns2 = [eqn for i, eqn in enumerate(jaxpr2.eqns) if i not in skip_eqns]
71+
72+
for eqn1 in eqns1:
73+
found_match = False
74+
for i, eqn2 in enumerate(eqns2):
75+
try:
76+
compare_eqns(eqn1, eqn2)
77+
# Remove the matched equation to prevent double-matching
78+
eqns2.pop(i)
79+
found_match = True
80+
break # Exit inner loop after finding a match
81+
except AssertionError:
82+
pass # Continue to the next equation in eqns2
83+
if not found_match:
84+
raise AssertionError(f"No matching equation found for: {eqn1}")
6285

6386

6487
def compare_eqns(eqn1, eqn2):
@@ -428,6 +451,73 @@ def circuit():
428451

429452
compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c)
430453

454+
def test_basis_state(self, disable_capture):
455+
"""Test comparison and execution of a jaxpr containing BasisState."""
456+
dev = qml.device("lightning.qubit", wires=2)
457+
458+
@qml.qnode(dev)
459+
def circuit(_basis_state):
460+
qml.BasisState(_basis_state, wires=[0, 1])
461+
return qml.state()
462+
463+
basis_state = np.array([1, 1])
464+
expected_state_vector = np.array([0, 0, 0, 1], dtype=np.complex128)
465+
466+
qml.capture.enable()
467+
plxpr = jax.make_jaxpr(circuit)(basis_state)
468+
converted = from_plxpr(plxpr)(basis_state)
469+
qml.capture.disable()
470+
471+
assert converted.eqns[0].primitive == catalyst.jax_primitives.quantum_kernel_p
472+
assert converted.eqns[0].params["qnode"] is circuit
473+
474+
catalyst_res = catalyst_execute_jaxpr(converted)(basis_state)
475+
assert len(catalyst_res) == 1
476+
assert qml.math.allclose(catalyst_res[0], expected_state_vector)
477+
478+
qjit_obj = qjit(circuit)
479+
qjit_obj(basis_state)
480+
catalxpr = qjit_obj.jaxpr
481+
call_jaxpr_pl = get_call_jaxpr(converted)
482+
call_jaxpr_c = get_call_jaxpr(catalxpr)
483+
484+
# Ignore ordering of eqns when comparing jaxpr since Catalyst performs sorting
485+
compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c, ignore_order=True)
486+
487+
def test_state_prep(self, disable_capture):
488+
"""Test comparison and execution of a jaxpr containing StatePrep."""
489+
dev = qml.device("lightning.qubit", wires=1)
490+
491+
@qml.qnode(dev)
492+
def circuit(_init_state):
493+
# NOTE: Require validate_norm=False here otherwise Catalyst jaxpr contains
494+
# unused function that computes norm
495+
qml.StatePrep(_init_state, wires=0, validate_norm=False)
496+
return qml.state()
497+
498+
init_state = np.array([1, 1], dtype=np.complex128) / np.sqrt(2)
499+
500+
qml.capture.enable()
501+
plxpr = jax.make_jaxpr(circuit)(init_state)
502+
converted = from_plxpr(plxpr)(init_state)
503+
qml.capture.disable()
504+
505+
assert converted.eqns[0].primitive == catalyst.jax_primitives.quantum_kernel_p
506+
assert converted.eqns[0].params["qnode"] is circuit
507+
508+
catalyst_res = catalyst_execute_jaxpr(converted)(init_state)
509+
assert len(catalyst_res) == 1
510+
assert qml.math.allclose(catalyst_res[0], init_state)
511+
512+
qjit_obj = qjit(circuit)
513+
qjit_obj(init_state)
514+
catalxpr = qjit_obj.jaxpr
515+
call_jaxpr_pl = get_call_jaxpr(converted)
516+
call_jaxpr_c = get_call_jaxpr(catalxpr)
517+
518+
# Ignore ordering of eqns when comparing jaxpr since Catalyst performs sorting
519+
compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c, ignore_order=True)
520+
431521
def test_multiple_measurements(self, disable_capture):
432522
"""Test that we can convert a circuit with multiple measurement returns."""
433523

0 commit comments

Comments
 (0)