Description
There is a mysterious failure when all of the following conditions are met:
- use capture
- apply a pass
- return qml.sample() from a qnode
- use a dynamic number of shots
When these conditions are met, JAX leaks an argument of the inner jaxpr of the PennyLane transform primitive into the outvars of the top-level jaxpr, where it serves as the dynamically shaped shot dimension of the samples.
See https://xanaduhq.slack.com/archives/C06CNPQLK2T/p1770325357111409
import pennylane as qml
from catalyst import qjit

qml.capture.enable()
dev = qml.device("lightning.qubit", wires=2)

@qjit
def workflow(shots: int):
    @qml.transform(pass_name="cancel-inverses")
    @qml.qnode(dev, shots=shots)
    def aloha():
        qml.Hadamard(wires=0)
        return qml.sample()

    return aloha()

x = workflow(5)
print(x)

This fails with:

  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 735, in capture
    return trace_from_pennylane(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py", line 536, in trace_from_pennylane
    jaxpr = from_plxpr(plxpr)(*plxpr.in_avals)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py", line 204, in wrapped_fn
    return jax.make_jaxpr(original_fn)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/cat_env/lib/python3.12/site-packages/pennylane/capture/base_interpreter.py", line 389, in eval
    outval = self.read(var)
             ^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/cat_env/lib/python3.12/site-packages/pennylane/capture/base_interpreter.py", line 267, in read
    return var.val if isinstance(var, jax.extend.core.Literal) else self._env[var]
                                                                    ~~~~~~~~~^^^^^
KeyError: Var(id=134741294431552):int64[]
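The KeyError is raised by the interpreter's environment lookup: every non-literal Var it reads must already have been written into self._env for the scope being evaluated, and a Var that only exists inside an inner jaxpr never is. The sketch below reproduces that pattern with hypothetical stand-ins (Var, Literal, ToyInterpreter, eval_top_level) rather than the real jax.extend.core and PennyLane classes, so it only illustrates the failure mode, not PennyLane's actual implementation.

```python
from dataclasses import dataclass, field
from itertools import count

_ids = count()


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a jaxpr variable; each instance is distinct."""
    name: str
    id: int = field(default_factory=lambda: next(_ids))


@dataclass(frozen=True)
class Literal:
    """Hypothetical stand-in for a literal constant in a jaxpr."""
    val: object


class ToyInterpreter:
    """Environment-based evaluator reduced to its read/write logic."""

    def __init__(self):
        self._env = {}

    def write(self, var, value):
        self._env[var] = value

    def read(self, var):
        # Literals carry their value; every other Var must be in this env.
        return var.val if isinstance(var, Literal) else self._env[var]


def eval_top_level(invars, args, eqn_results, outvars):
    """Bind inputs, record pre-computed equation results, then read the outvars."""
    interp = ToyInterpreter()
    for var, arg in zip(invars, args):
        interp.write(var, arg)
    for var, value in eqn_results.items():
        interp.write(var, value)
    return [interp.read(v) for v in outvars]


a = Var("a")  # top-level input: the number of shots
b = Var("b")  # result of the transform equation: the samples
c = Var("c")  # constvar of inner_jaxpr; never written at top level
samples = [[0, 0]] * 5

print(eval_top_level([a], [5], {b: samples}, [a, b]))  # well-scoped outvars
try:
    eval_top_level([a], [5], {b: samples}, [c, b])  # outvars as in the issue
except KeyError as err:
    print("KeyError:", err.args[0])
```

Reading a, which was bound at top level, succeeds; reading c fails exactly where the real read() does, because nothing at top level ever wrote it.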
The cause is that the top-level jaxpr tries to return a Var that belongs to an inner jaxpr:
{ lambda ; a:i64[]. let
    b:i64[c,2] = transform[
      args_slice=(0, 0, None)
      consts_slice=(0, 1, None)
      inner_jaxpr={ lambda c:i64[]; . let
          d:i64[c,2] = qnode[
            device=<lightning.qubit device (wires=2) at 0x7f1a30552f00>
            execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
            n_consts=1
            qfunc_jaxpr={ lambda e:i64[]; . let
                f:f64[] = convert_element_type[new_dtype=float64 weak_type=True] e
                g:f64[] = mul 0.1:f64[] f
                _:AbstractOperator() = RX[n_wires=1] g 1:i64[]
                _:AbstractOperator() = Hadamard[n_wires=1] 0:i64[]
                h:AbstractMeasurement(n_wires=0) = sample_wires[dtype=None]
              in (h,) }
            qnode=<QNode: device='<lightning.qubit device (wires=2) at 0x7f1a30552f00>', interface='jax', diff_method='best', shots='Shots(total=JitTracer<~int64[]>)'>
            shots_len=1
          ] c c
        in (d,) }
      targs_slice=(1, None, None)
      tkwargs=()
      transform=<transform: cancel-inverses>
    ] a
  in (c, b) }
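The outvar c is leaked from the inner jaxpr. c is bound only as the constvar of inner_jaxpr (which receives the top-level input a, per consts_slice=(0, 1, None)), yet at the top level it appears both in the type of b (i64[c,2]) and in the outvars (c, b).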
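Leaks like this can be spotted by checking that every name a scope uses, including symbolic dimensions in output types, is bound in that same scope. Here is a minimal sketch of such a check over a hypothetical, string-based model of the jaxpr above; ToyJaxpr, ToyEqn, find_leaks and build are illustrative names, not JAX APIs, and the real jaxpr classes carry considerably more structure.

```python
from dataclasses import dataclass, field


@dataclass
class ToyEqn:
    """Hypothetical equation: variable names are plain strings."""
    name: str
    invars: list
    outvars: dict  # outvar name -> shape (tuple of ints and/or dimension names)
    sub_jaxprs: list = field(default_factory=list)


@dataclass
class ToyJaxpr:
    """Hypothetical jaxpr: each instance is its own variable scope."""
    constvars: list
    invars: list
    eqns: list
    outvars: list


def find_leaks(jaxpr, scope="top"):
    """Return (scope, name) pairs for names used in a scope that never binds them."""
    bound = set(jaxpr.constvars) | set(jaxpr.invars)
    leaks = []
    for eqn in jaxpr.eqns:
        # Inner jaxprs are separate scopes and are checked recursively.
        for sub in eqn.sub_jaxprs:
            leaks += find_leaks(sub, f"{scope}/{eqn.name}")
        leaks += [(scope, v) for v in eqn.invars if v not in bound]
        bound |= set(eqn.outvars)
        # Symbolic dimensions in output types must also be visible here.
        for shape in eqn.outvars.values():
            leaks += [(scope, d) for d in shape if isinstance(d, str) and d not in bound]
    leaks += [(scope, v) for v in jaxpr.outvars if v not in bound]
    return leaks


def build(top_dim, top_outvars):
    """Model the jaxpr from the issue; top_dim is the dimension name used for b."""
    qfunc = ToyJaxpr(["e"], [], [ToyEqn("sample_wires", [], {"h": ()})], ["h"])
    inner = ToyJaxpr(["c"], [], [ToyEqn("qnode", ["c", "c"], {"d": ("c", 2)}, [qfunc])], ["d"])
    transform = ToyEqn("transform", ["a"], {"b": (top_dim, 2)}, [inner])
    return ToyJaxpr([], ["a"], [transform], top_outvars)


print(find_leaks(build("c", ["c", "b"])))  # as emitted in the issue
print(find_leaks(build("a", ["a", "b"])))  # dimension expressed via the top-level input
```

In this model the inner scopes are clean; the only violations are the two top-level uses of c, in b's type and in the outvars.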
In _add_implicit_outputs (jax/_src/interpreters/partial_eval.py), the Vars for the dynamic dimensions of a returned value are added to the jaxpr outvars: https://github.com/jax-ml/jax/blob/5712de44e97c455faed1fd45532e821ca66d025a/jax/_src/interpreters/partial_eval.py#L1843. For the transform primitive, however, the Var that gets added is the one leaked from the inner jaxpr; it is not yet clear why.
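A simplified model of the implicit-output step may help frame the question. It collects the symbolic dimensions of each returned value's shape and prepends those not already returned, matching the order seen in (c, b). If the transform's output type still names the inner variable c, c is what gets emitted. The helpers add_implicit_outputs and lift_shape are hypothetical and do not reproduce JAX's code. lift_shape only shows what a type expressed in terms of the outer operand a would look like in this model; it is not a confirmed fix.

```python
def add_implicit_outputs(outvars, shapes):
    """Return outvars with symbolic dimension names prepended.

    Simplified model: names are strings, shapes are tuples mixing ints and names.
    """
    implicit = []
    for v in outvars:
        for d in shapes[v]:
            # Only symbolic (named) dims become implicit outputs, once each.
            if isinstance(d, str) and d not in implicit and d not in outvars:
                implicit.append(d)
    return implicit + list(outvars)


def lift_shape(shape, inner_to_outer):
    """Hypothetical renaming of inner-scope dimension names to outer operands."""
    return tuple(inner_to_outer.get(d, d) if isinstance(d, str) else d for d in shape)


# The transform's output type as it appears in the issue: b:i64[c,2].
leaky = add_implicit_outputs(["b"], {"b": ("c", 2)})
# The same type after renaming c to the operand a it is bound to.
scoped = add_implicit_outputs(["b"], {"b": lift_shape(("c", 2), {"c": "a"})})
print(leaky)   # ['c', 'b']
print(scoped)  # ['a', 'b']
```

The model suggests that the step itself only does what the output type tells it to. It would then be worth checking where b's type acquires a reference to c rather than a.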