diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 550eda6798..52c8026dc2 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -25,7 +25,7 @@ import pennylane as qml from jax._src.sharding_impls import UNSPECIFIED from jax._src.tree_util import tree_flatten -from jax.extend.core import ClosedJaxpr +from jax.extend.core import ClosedJaxpr, JaxprEqn from jax.interpreters.partial_eval import convert_constvars_jaxpr from pennylane.capture import PlxprInterpreter, pause from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim @@ -63,6 +63,7 @@ set_basis_state_p, set_state_p, state_p, + template_p, tensorobs_p, unitary_p, var_p, @@ -129,6 +130,35 @@ def __init__( super().__init__() + def interpret_operation_eqn(self, eqn): + """Interpret an equation corresponding to an operator.""" + if getattr(eqn.primitive, "prototype_op", False) is True: + return self.interpret_operation2_eqn(eqn) + + return super().interpret_operation_eqn(eqn) + + def interpret_operation2_eqn(self, eqn: JaxprEqn): + """Interpret Operator2.""" + self.init_qreg.insert_all_dangling_qubits() + invals = (self.read(invar) for invar in eqn.invars) + + eqn_params = dict(eqn.params) + dyn_argnames = eqn_params.pop("dyn_argnames", ()) + wire_argnames = eqn_params.pop("wire_argnames", ()) + + template_params = { + "template_name": eqn.primitive.name, + "dyn_argnames": dyn_argnames, + "wire_argnames": wire_argnames, + "ctrl": False, + "adjoint": False, + **eqn_params, + } + + out = template_p.bind(*invals, self.init_qreg.get(), **template_params) + self.init_qreg.set(out) + return out + def interpret_operation(self, op, is_adjoint=False, control_values=(), control_wires=()): """Re-bind a pennylane operation as a catalyst instruction. @@ -233,16 +263,24 @@ def _check_measurement_with_dynamic_allocation(self, measurement): if len(measurement.wires) == 0 and not isinstance( measurement, qml.measurements.StateMP ): - raise CompileError(textwrap.dedent(""" + raise CompileError( + textwrap.dedent( + """ Terminal measurements must take in an explicit list of wires when dynamically allocated wires are present in the program. - """)) + """ + ) + ) if any(is_dynamically_allocated_wire(w) for w in measurement.wires): - raise CompileError(textwrap.dedent(""" + raise CompileError( + textwrap.dedent( + """ Terminal measurements cannot take in dynamically allocated wires since they must be temporary. - """)) + """ + ) + ) # pylint: disable=too-many-branches def interpret_measurement(self, measurement): diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index d6718033f9..a952fdc611 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -207,10 +207,14 @@ def get_mlir_attribute_from_pyval(value): elif 0 <= value < 18446744073709551616: # = 2**64 attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) else: - raise CompileError(textwrap.dedent(""" + raise CompileError( + textwrap.dedent( + """ Large interger attributes currently not supported in MLIR, see https://github.com/llvm/llvm-project/issues/128072 - """)) + """ + ) + ) case float(): attr = ir.FloatAttr.get(ir.F64Type.get(), value) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 82a6db7d43..2b13b3e6a0 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -57,6 +57,7 @@ # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed # once JAX updates to a compatible MLIR version # pylint: disable=ungrouped-imports +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval from catalyst.jax_extras.patches import mock_attributes from catalyst.utils.patching import Patcher @@ -118,6 +119,7 @@ SetBasisStateOp, SetStateOp, StateOp, + TemplateOp, TensorOp, VarianceOp, ) @@ -354,6 +356,64 @@ class MeasurementPlane(Enum): measure_in_basis_p.multiple_results = True decomprule_p = core.Primitive("decomposition_rule") decomprule_p.multiple_results = True +template_p = Primitive("template") + + +@template_p.def_impl +def _template_impl(*args, template_name, dyn_argnames, wire_argnames, ctrl, adjoint, **kwargs): + raise ValueError("No impl.") + + +@template_p.def_abstract_eval +def _template_aval(*_, **__): + return AbstractQreg() + + +def _template_lowering( + ctx, *args, template_name, dyn_argnames, wire_argnames, ctrl=False, adjoint=False, **kwargs +): + inputs = args[: len(dyn_argnames)] + wire_args = args[len(dyn_argnames) : len(dyn_argnames) + len(wire_argnames)] + + ctrl_inds = args[-3] if ctrl else None + ctrl_vals = args[-2] if ctrl else None + in_qreg = args[-1] + + out_qreg_type = in_qreg.type + + param_map = {} + for i, d in enumerate(dyn_argnames): + param_map[d] = (i,) + + in_qubits_map = {} + for i, w in enumerate(wire_argnames): + in_qubits_map[w] = (i,) + + static_data = {k: v for k, v in kwargs.items() if v is not None} + + template_name = get_mlir_attribute_from_pyval(template_name) + param_map = get_mlir_attribute_from_pyval(param_map) + in_qubits_map = get_mlir_attribute_from_pyval(in_qubits_map) + static_data = get_mlir_attribute_from_pyval(static_data) + + attrs = {} + if static_data: + attrs["static_data"] = static_data + if param_map: + attrs["param_map"] = param_map + + return TemplateOp( + template_name=template_name, + inputs=inputs, + in_qreg=in_qreg, + qubit_inds=wire_args, + in_ctrl_inds=ctrl_inds, + in_ctrl_vals=ctrl_vals, + adjoint=adjoint, + in_qubits_map=in_qubits_map, + out_qreg=out_qreg_type, + **attrs, + ).results def decomposition_rule(func=None, *, is_qreg=True, num_params=0, pauli_word=None, op_type=None): @@ -2904,10 +2964,13 @@ def subroutine_lowering(*args, **kwargs): retval = _pjit_lowering(*args, **kwargs) except NotImplementedError as e: if "MLIR translation rule for primitive" in str(e): - msg = str(e) + """ + msg = ( + str(e) + + """ This error sometimes occurs when using quantum operations inside subroutines but calling them outside a qnode """ + ) raise NotImplementedError(msg) from e raise e @@ -2967,6 +3030,7 @@ def subroutine_lowering(*args, **kwargs): (quantum_subroutine_prim, subroutine_lowering), (measure_in_basis_p, _measure_in_basis_lowering), (decomprule_p, _decomposition_rule_lowering), + (template_p, _template_lowering), ) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 54de5522ef..29594e9445 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -17,7 +17,9 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" include "Quantum/IR/QuantumAttrDefs.td" include "Quantum/IR/QuantumDialect.td" @@ -739,6 +741,33 @@ def QubitUnitaryOp : UnitaryGate_Op<"unitary", [ParametrizedGate, NoMemoryEffect // ----- +def TemplateOp : Quantum_Op<"template", [NoMemoryEffect, AttrSizedOperandSegments, AllTypesMatch<["in_qreg", "out_qreg"]>]> { + + let summary = "Operation to represent quantum gates with arbitrary signatures"; + let description = [{}]; + + let arguments = (ins + StrAttr:$template_name, + Variadic:$inputs, + QuregType:$in_qreg, + Variadic<1DTensorOf<[I64]>>:$qubit_inds, + Optional<1DTensorOf<[I64]>>:$in_ctrl_inds, + Optional<1DTensorOf<[I1]>>:$in_ctrl_vals, + UnitAttr:$adjoint, + OptionalAttr:$param_map, + DictionaryAttr:$in_qubits_map, + OptionalAttr:$static_data + ); + + let results = (outs + QuregType:$out_qreg + ); + + let hasCustomAssemblyFormat = 1; +} + +// ----- + class Region_Op traits = []> : Quantum_Op; diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 0fbec67099..b7b7e1d232 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -568,3 +568,225 @@ LogicalResult AdjointOp::verify() return success(); } + +// ----- + +void TemplateOp::print(OpAsmPrinter &p) +{ + // 1. Template Name + p << " "; + p.printAttribute(getTemplateNameAttr()); + + // 2. Variadic Inputs: (%arg0 : type, ...) + p << "("; + llvm::interleaveComma(llvm::zip(getInputs(), getInputs().getTypes()), p, + [&](auto pair) { p << std::get<0>(pair) << " : " << std::get<1>(pair); }); + p << ")"; + + p.increaseIndent(); + p.printNewline(); + + // 3. qreg and qubit_inds + p << "in_qreg (" << getInQreg() << ") qubit_inds ("; + llvm::interleaveComma(llvm::zip(getQubitInds(), getQubitInds().getTypes()), p, + [&](auto pair) { p << std::get<0>(pair) << " : " << std::get<1>(pair); }); + p << ")"; + + // 4. Conditional Line: adjoint, ctrls, ctrl_vals + bool hasAdjoint = getAdjoint(); + bool hasCtrls = getInCtrlInds() != nullptr; + bool hasCtrlVals = getInCtrlVals() != nullptr; + + if (hasAdjoint || hasCtrls || hasCtrlVals) { + p.printNewline(); + if (hasAdjoint) + p << "adjoint"; + if (hasCtrls) { + if (hasAdjoint) + p << " "; + p << "ctrls (" << getInCtrlInds() << " : " << getInCtrlInds().getType() << ")"; + } + if (hasCtrlVals) { + if (hasAdjoint || hasCtrls) + p << " "; + p << "ctrl_vals (" << getInCtrlVals() << " : " << getInCtrlVals().getType() << ")"; + } + } + + // 5. Results + p.printNewline(); + p << "-> " << getOutQreg().getType(); + + // 6. Indented Properties + if (auto paramMap = getParamMap()) { + p.printNewline(); + p << "param_map = " << paramMap; + } + + p.printNewline(); + p << "in_qubits_map = " << getInQubitsMap(); + + if (auto staticData = getStaticData()) { + p.printNewline(); + p << "static_data = " << staticData; + } + + // 7. Attribute Dictionary + SmallVector elidedAttrs = { + "template_name", "param_map", "in_qubits_map", "static_data", + "operandSegmentSizes", "adjoint" + }; + + // Check if there's anything left to print + bool hasExtraAttrs = false; + for (auto attr : getOperation()->getAttrs()) { + if (!llvm::is_contained(elidedAttrs, attr.getName().getValue())) { + hasExtraAttrs = true; + break; + } + } + + if (hasExtraAttrs) { + p.printNewline(); + p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); + } + + p.decreaseIndent(); +} + +ParseResult TemplateOp::parse(OpAsmParser &parser, OperationState &result) +{ + // 1. Parse template Name + StringAttr templateName; + if (parser.parseAttribute(templateName, "template_name", result.attributes)) + return failure(); + + // 2. Parse Variadic Inputs + SmallVector inputs; + SmallVector inputTypes; + auto parseInputAndType = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + Type type; + if (parser.parseOperand(operand) || parser.parseColon() || parser.parseType(type)) + return failure(); + inputs.push_back(operand); + inputTypes.push_back(type); + return success(); + }; + + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseInputAndType)) + return failure(); + + // 3. Parse single qreg + OpAsmParser::UnresolvedOperand inQreg; + if (parser.parseKeyword("in_qreg") || parser.parseLParen() || parser.parseOperand(inQreg) || + parser.parseRParen()) + return failure(); + + // 4. Parse qubit_inds + SmallVector qubitInds; + SmallVector qubitIndTypes; + auto parseQubitAndType = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + Type type; + if (parser.parseOperand(operand) || parser.parseColon() || parser.parseType(type)) + return failure(); + qubitInds.push_back(operand); + qubitIndTypes.push_back(type); + return success(); + }; + + if (parser.parseKeyword("qubit_inds") || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseQubitAndType)) + return failure(); + + // 5. Parse Adjoint, Ctrls, and Ctrl Vals (All Optional) + if (succeeded(parser.parseOptionalKeyword("adjoint"))) { + result.addAttribute("adjoint", parser.getBuilder().getUnitAttr()); + } + + OpAsmParser::UnresolvedOperand ctrlInds; + Type ctrlIndsType; + bool hasCtrls = false; + if (succeeded(parser.parseOptionalKeyword("ctrls"))) { + hasCtrls = true; + if (parser.parseLParen() || parser.parseOperand(ctrlInds) || parser.parseColon() || + parser.parseType(ctrlIndsType) || parser.parseRParen()) + return failure(); + } + + OpAsmParser::UnresolvedOperand ctrlVals; + Type ctrlValsType; + bool hasCtrlVals = false; + if (succeeded(parser.parseOptionalKeyword("ctrl_vals"))) { + hasCtrlVals = true; + if (parser.parseLParen() || parser.parseOperand(ctrlVals) || parser.parseColon() || + parser.parseType(ctrlValsType) || parser.parseRParen()) + return failure(); + } + + // 6. Resolve all Operands + Type inQregType = QuregType::get(parser.getContext()); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands) || + parser.resolveOperand(inQreg, inQregType, result.operands) || + parser.resolveOperands(qubitInds, qubitIndTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + if (hasCtrls) { + if (parser.resolveOperand(ctrlInds, ctrlIndsType, result.operands)) + return failure(); + } + if (hasCtrlVals) { + if (parser.resolveOperand(ctrlVals, ctrlValsType, result.operands)) + return failure(); + } + + // 7. Set AttrSizedSegments + // Use the generated getter here too! + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({ + static_cast(inputs.size()), + 1, // in_qreg is strictly 1 + static_cast(qubitInds.size()), + hasCtrls ? 1 : 0, // 0 or 1 + hasCtrlVals ? 1 : 0 // 0 or 1 + })); + + // 8. Parse Return Types + if (parser.parseArrow()) + return failure(); + + Type outQregType; + if (parser.parseType(outQregType)) + return failure(); + + result.addTypes(outQregType); + + // 9. Parse Indented Properties + DictionaryAttr inQubitsMap; + + // Parse optional param_map + if (succeeded(parser.parseOptionalKeyword("param_map"))) { + DictionaryAttr paramMap; + if (parser.parseEqual() || parser.parseAttribute(paramMap, "param_map", result.attributes)) + return failure(); + } + + if (parser.parseKeyword("in_qubits_map") || parser.parseEqual() || + parser.parseAttribute(inQubitsMap, "in_qubits_map", result.attributes)) + return failure(); + + // Parse optional static_data + if (succeeded(parser.parseOptionalKeyword("static_data"))) { + DictionaryAttr staticData; + if (parser.parseEqual() || parser.parseAttribute(staticData, "static_data", result.attributes)) + return failure(); + } + + // 10. Parse Attribute Dictionary + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +}