
capture + dynamic shots + sample + applying a pass fails to create plxpr #9054

@paul0403

Description


There is a mysterious failure when all of the following conditions are met:

  • program capture is enabled (qml.capture.enable())
  • a pass is applied to the qnode
  • the qnode returns qml.sample()
  • the number of shots is dynamic (passed in as a traced argument)

When these conditions are met, JAX leaks a variable from the inner jaxpr of the PennyLane transform primitive into the outvars of the top-level jaxpr, where it serves as the dynamically-shaped shot dimension of the samples.

See https://xanaduhq.slack.com/archives/C06CNPQLK2T/p1770325357111409

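import pennylane as qml
from catalyst import qjit

qml.capture.enable()
dev = qml.device("lightning.qubit", wires=2)

@qjit
def workflow(shots: int):
    # Apply the "cancel-inverses" pass to a QNode whose shot count is a
    # traced (dynamic) value rather than a Python constant.
    @qml.transform(pass_name="cancel-inverses")
    @qml.qnode(dev, shots=shots)
    def aloha():
        qml.Hadamard(wires=0)
        return qml.sample()

    return aloha()

x = workflow(5)  # raises KeyError
print(x)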

Running this raises (traceback truncated):

  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 735, in capture
    return trace_from_pennylane(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py", line 536, in trace_from_pennylane
    jaxpr = from_plxpr(plxpr)(*plxpr.in_avals)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py", line 204, in wrapped_fn
    return jax.make_jaxpr(original_fn)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/cat_env/lib/python3.12/site-packages/pennylane/capture/base_interpreter.py", line 389, in eval
    outval = self.read(var)
             ^^^^^^^^^^^^^^
  File "/home/paul.wang/catalyst_new/catalyst/cat_env/lib/python3.12/site-packages/pennylane/capture/base_interpreter.py", line 267, in read
    return var.val if isinstance(var, jax.extend.core.Literal) else self._env[var]
                                                                    ~~~~~~~~~^^^^^
KeyError: Var(id=134741294431552):int64[]

The cause is that the top-level jaxpr tries to return a Var that belongs to an inner jaxpr:

{ lambda ; a:i64[]. let
    b:i64[c,2] = transform[
      args_slice=(0, 0, None)
      consts_slice=(0, 1, None)
      inner_jaxpr={ lambda c:i64[]; . let
          d:i64[c,2] = qnode[
            device=<lightning.qubit device (wires=2) at 0x7f1a30552f00>
            execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
            n_consts=1
            qfunc_jaxpr={ lambda e:i64[]; . let
                f:f64[] = convert_element_type[new_dtype=float64 weak_type=True] e
                g:f64[] = mul 0.1:f64[] f
                _:AbstractOperator() = RX[n_wires=1] g 1:i64[]
                _:AbstractOperator() = Hadamard[n_wires=1] 0:i64[]
                h:AbstractMeasurement(n_wires=0) = sample_wires[dtype=None] 
              in (h,) }
            qnode=<QNode: device='<lightning.qubit device (wires=2) at 0x7f1a30552f00>', interface='jax', diff_method='best', shots='Shots(total=JitTracer<~int64[]>)'>
            shots_len=1
          ] c c
        in (d,) }
      targs_slice=(1, None, None)
      tkwargs=()
      transform=<transform: cancel-inverses>
    ] a
  in (c, b) }

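The outvar c is leaked from the inner jaxpr: c is bound only as a constvar of the transform's inner_jaxpr (lambda c:i64[]; .), yet the top-level jaxpr returns it and uses it as the leading dimension of b:i64[c,2].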
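To make "leaked" concrete, here is a minimal sketch of a diagnostic that flags top-level outvars which are not bound in the jaxpr's own scope, i.e. that are not a constvar, an invar, or an output of a top-level equation. find_leaked_outvars is a hypothetical helper, not part of JAX, PennyLane, or Catalyst. It only inspects the top level and assumes a JAX version that exposes jax.extend.core.Literal (as used in the traceback above). Applied to the top-level jaxpr above, it would be expected to report c, since c is bound only inside inner_jaxpr.

```python
import jax
import jax.numpy as jnp
from jax.extend.core import Literal


def find_leaked_outvars(jaxpr):
    """Return outvars of `jaxpr` that are not bound in its own top-level scope.

    Hypothetical diagnostic helper. A var is in scope at the top level if it
    is a constvar, an invar, or an output of a top-level equation. Vars bound
    only inside a sub-jaxpr (such as the inner_jaxpr param of a higher-order
    primitive) are not in scope here, so returning them is a leak.
    """
    bound = set(jaxpr.constvars) | set(jaxpr.invars)
    for eqn in jaxpr.eqns:
        bound.update(eqn.outvars)
    # Literals are constants, not variables, so they can never leak.
    return [v for v in jaxpr.outvars if not isinstance(v, Literal) and v not in bound]


def f(x):
    # jax.jit inside a traced function creates an eqn with its own inner jaxpr.
    return jax.jit(lambda y: y * 2.0)(x) + 1.0


closed = jax.make_jaxpr(f)(jnp.float32(3.0))
print(find_leaked_outvars(closed.jaxpr))  # → []
```

The helper only relies on the constvars, invars, eqns, and outvars attributes, so it can be pointed at any jaxpr, including the .jaxpr of a captured plxpr, to check whether an outvar escaped from a nested scope.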

In jax/_src/interpreters/partial_eval.py, the function _add_implicit_outputs adds the vars for the dynamic dimensions of a returned value to the jaxpr outvars: https://github.com/jax-ml/jax/blob/5712de44e97c455faed1fd45532e821ca66d025a/jax/_src/interpreters/partial_eval.py#L1843. Somehow, for the transform primitive, the added Var is one leaked from an inner jaxpr.
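As a rough illustration only, the following toy model mimics the behavior described above: every dynamic dimension in an output's shape that is not already a top-level input or an earlier output is prepended to the outvars. It is a simplification written for this issue, not JAX's actual code; add_implicit_outputs here is a hypothetical name and vars are plain strings. Fed the shapes from the top-level jaxpr, it reproduces the outvar order (c, b), which is consistent with how a dimension var that belongs only to the inner jaxpr ends up returned at the top level.

```python
def add_implicit_outputs(invars, outvars, shapes):
    """Toy model of adding implicit outputs for dynamic dimensions.

    invars:  names of the jaxpr's top-level input vars (constvars included).
    outvars: names of the explicit output vars.
    shapes:  maps each outvar name to its shape; each entry is an int (static
             dimension) or a var name (dynamic dimension).
    Returns the final outvars: implicit dimension vars first, then outvars.
    """
    seen = set(invars)
    implicit = []
    for v in outvars:
        seen.add(v)
        for d in shapes[v]:
            # A dynamic dimension that is not an input and not already
            # returned becomes an extra (implicit) output.
            if isinstance(d, str) and d not in seen:
                seen.add(d)
                implicit.append(d)
    return implicit + list(outvars)


# Shapes from the top-level jaxpr: input a, explicit output b:i64[c,2].
print(add_implicit_outputs(["a"], ["b"], {"b": ("c", 2)}))  # → ['c', 'b']
# If b's dimension referred to the top-level input a instead, nothing is added.
print(add_implicit_outputs(["a"], ["b"], {"b": ("a", 2)}))  # → ['b']
```

In the model, whether a dimension var gets promoted to an output depends only on whether it is already in scope at the top level, which is why a dimension var taken from the inner jaxpr is the one that surfaces in the top-level outvars.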
