Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions frontend/catalyst/from_plxpr/qfunc_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pennylane as qml
from jax._src.sharding_impls import UNSPECIFIED
from jax._src.tree_util import tree_flatten
from jax.extend.core import ClosedJaxpr
from jax.extend.core import ClosedJaxpr, JaxprEqn
from jax.interpreters.partial_eval import convert_constvars_jaxpr
from pennylane.capture import PlxprInterpreter, pause
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
Expand Down Expand Up @@ -63,6 +63,7 @@
set_basis_state_p,
set_state_p,
state_p,
template_p,
tensorobs_p,
unitary_p,
var_p,
Expand Down Expand Up @@ -129,6 +130,35 @@ def __init__(

super().__init__()

def interpret_operation_eqn(self, eqn):
"""Interpret an equation corresponding to an operator."""
if getattr(eqn.primitive, "prototype_op", False) is True:
return self.interpret_operation2_eqn(eqn)

return super().interpret_operation_eqn(eqn)

def interpret_operation2_eqn(self, eqn: JaxprEqn):
"""Interpret Operator2."""
self.init_qreg.insert_all_dangling_qubits()
invals = (self.read(invar) for invar in eqn.invars)

eqn_params = dict(eqn.params)
dyn_argnames = eqn_params.pop("dyn_argnames", ())
wire_argnames = eqn_params.pop("wire_argnames", ())

template_params = {
"template_name": eqn.primitive.name,
"dyn_argnames": dyn_argnames,
"wire_argnames": wire_argnames,
"ctrl": False,
"adjoint": False,
**eqn_params,
}

out = template_p.bind(*invals, self.init_qreg.get(), **template_params)
self.init_qreg.set(out)
return out

def interpret_operation(self, op, is_adjoint=False, control_values=(), control_wires=()):
"""Re-bind a pennylane operation as a catalyst instruction.

Expand Down Expand Up @@ -233,16 +263,24 @@ def _check_measurement_with_dynamic_allocation(self, measurement):
if len(measurement.wires) == 0 and not isinstance(
measurement, qml.measurements.StateMP
):
raise CompileError(textwrap.dedent("""
raise CompileError(
textwrap.dedent(
"""
Terminal measurements must take in an explicit list of wires when
dynamically allocated wires are present in the program.
"""))
"""
)
)

if any(is_dynamically_allocated_wire(w) for w in measurement.wires):
raise CompileError(textwrap.dedent("""
raise CompileError(
textwrap.dedent(
"""
Terminal measurements cannot take in dynamically allocated wires
since they must be temporary.
"""))
"""
)
)

# pylint: disable=too-many-branches
def interpret_measurement(self, measurement):
Expand Down
8 changes: 6 additions & 2 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,14 @@ def get_mlir_attribute_from_pyval(value):
elif 0 <= value < 18446744073709551616: # = 2**64
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
else:
raise CompileError(textwrap.dedent("""
raise CompileError(
textwrap.dedent(
"""
Large interger attributes currently not supported in MLIR,
see https://github.com/llvm/llvm-project/issues/128072
"""))
"""
)
)

case float():
attr = ir.FloatAttr.get(ir.F64Type.get(), value)
Expand Down
66 changes: 65 additions & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed
# once JAX updates to a compatible MLIR version
# pylint: disable=ungrouped-imports
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
from catalyst.jax_extras.patches import mock_attributes
from catalyst.utils.patching import Patcher

Expand Down Expand Up @@ -118,6 +119,7 @@
SetBasisStateOp,
SetStateOp,
StateOp,
TemplateOp,
TensorOp,
VarianceOp,
)
Expand Down Expand Up @@ -354,6 +356,64 @@
measure_in_basis_p.multiple_results = True
decomprule_p = core.Primitive("decomposition_rule")
decomprule_p.multiple_results = True
template_p = Primitive("template")


@template_p.def_impl
def _template_impl(*args, template_name, dyn_argnames, wire_argnames, ctrl, adjoint, **kwargs):
raise ValueError("No impl.")


@template_p.def_abstract_eval
def _template_aval(*_, **__):
return AbstractQreg()


def _template_lowering(

Check notice on line 372 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L372

Too many arguments (6/5) (too-many-arguments)
ctx, *args, template_name, dyn_argnames, wire_argnames, ctrl=False, adjoint=False, **kwargs
):
inputs = args[: len(dyn_argnames)]
wire_args = args[len(dyn_argnames) : len(dyn_argnames) + len(wire_argnames)]

ctrl_inds = args[-3] if ctrl else None
ctrl_vals = args[-2] if ctrl else None
in_qreg = args[-1]

out_qreg_type = in_qreg.type

param_map = {}
for i, d in enumerate(dyn_argnames):
param_map[d] = (i,)

in_qubits_map = {}
for i, w in enumerate(wire_argnames):
in_qubits_map[w] = (i,)

static_data = {k: v for k, v in kwargs.items() if v is not None}

template_name = get_mlir_attribute_from_pyval(template_name)
param_map = get_mlir_attribute_from_pyval(param_map)
in_qubits_map = get_mlir_attribute_from_pyval(in_qubits_map)
static_data = get_mlir_attribute_from_pyval(static_data)

attrs = {}
if static_data:
attrs["static_data"] = static_data
if param_map:
attrs["param_map"] = param_map

return TemplateOp(
template_name=template_name,
inputs=inputs,
in_qreg=in_qreg,
qubit_inds=wire_args,
in_ctrl_inds=ctrl_inds,
in_ctrl_vals=ctrl_vals,
adjoint=adjoint,
in_qubits_map=in_qubits_map,
out_qreg=out_qreg_type,
**attrs,
).results


def decomposition_rule(func=None, *, is_qreg=True, num_params=0, pauli_word=None, op_type=None):
Expand Down Expand Up @@ -2904,10 +2964,13 @@
retval = _pjit_lowering(*args, **kwargs)
except NotImplementedError as e:
if "MLIR translation rule for primitive" in str(e):
msg = str(e) + """
msg = (
str(e)
+ """
This error sometimes occurs when using quantum operations
inside subroutines but calling them outside a qnode
"""
)
raise NotImplementedError(msg) from e
raise e

Expand Down Expand Up @@ -2967,6 +3030,7 @@
(quantum_subroutine_prim, subroutine_lowering),
(measure_in_basis_p, _measure_in_basis_lowering),
(decomprule_p, _decomposition_rule_lowering),
(template_p, _template_lowering),
)


Expand Down
29 changes: 29 additions & 0 deletions mlir/include/Quantum/IR/QuantumOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

include "mlir/IR/OpBase.td"
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"

include "Quantum/IR/QuantumAttrDefs.td"
include "Quantum/IR/QuantumDialect.td"
Expand Down Expand Up @@ -739,6 +741,33 @@ def QubitUnitaryOp : UnitaryGate_Op<"unitary", [ParametrizedGate, NoMemoryEffect

// -----

def TemplateOp : Quantum_Op<"template", [NoMemoryEffect, AttrSizedOperandSegments, AllTypesMatch<["in_qreg", "out_qreg"]>]> {

let summary = "Operation to represent quantum gates with arbitrary signatures";
let description = [{}];

let arguments = (ins
StrAttr:$template_name,
Variadic<AnyType>:$inputs,
QuregType:$in_qreg,
Variadic<1DTensorOf<[I64]>>:$qubit_inds,
Optional<1DTensorOf<[I64]>>:$in_ctrl_inds,
Optional<1DTensorOf<[I1]>>:$in_ctrl_vals,
UnitAttr:$adjoint,
OptionalAttr<DictionaryAttr>:$param_map,
DictionaryAttr:$in_qubits_map,
OptionalAttr<DictionaryAttr>:$static_data
);

let results = (outs
QuregType:$out_qreg
);

let hasCustomAssemblyFormat = 1;
}

// -----

class Region_Op<string mnemonic, list<Trait> traits = []> :
Quantum_Op<mnemonic, traits # [NoMemoryEffect]>;

Expand Down
Loading
Loading