Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(self,
*,
uniqueId=None,
kernelModuleName=None,
parentVariables=None,
locationOffset=('', 0),
verbose=False):
"""
Expand Down Expand Up @@ -426,6 +427,7 @@ def node_error(msg):
self.isSubscriptRoot = False
self.verbose = verbose
self.currentNode = None
self.parentVariables = parentVariables or {}

def debug_msg(self, msg, node=None):
if self.verbose:
Expand Down Expand Up @@ -2533,6 +2535,11 @@ def processDecorator(name, path=None):
decorator = resolve_qualified_symbol(name)
else:
decorator = recover_kernel_decorator(name)
if decorator is None and name in self.parentVariables:
from .kernel_decorator import isa_kernel_decorator
var = self.parentVariables[name]
if isa_kernel_decorator(var):
decorator = var

if decorator and not name in self.symbolTable:
callableTy = decorator.signature.get_callable_type()
Expand Down Expand Up @@ -2576,21 +2583,6 @@ def processDecoratorCall(symName):
# of eliminating unnecessary copies.
return self.__migrateLists(result, copy_list_to_stack)

def processFunctionCall(kernel):
nrArgs = len(kernel.type.inputs)
values = self.__groupValues(node.args, [(nrArgs, nrArgs)])
values = convertArguments([t for t in kernel.type.inputs], values)
if len(kernel.type.results) == 0:
func.CallOp(kernel, values)
return

# The logic for calls that return values must match the logic in
# `visit_Return`; anything copied to the heap during return must be
# copied back to the stack. Compiler optimizations should take care
# of eliminating unnecessary copies.
result = func.CallOp(kernel, values).result
return self.__migrateLists(result, copy_list_to_stack)

def resolveQualifiedName(pyVal):
if isinstance(pyVal, ast.Name):
return None, pyVal.id
Expand Down Expand Up @@ -5394,25 +5386,24 @@ def compile_to_mlir(uniqueId, astModule, signature: KernelSignature, **kwargs):

verbose = 'verbose' in kwargs and kwargs['verbose']
lineNumberOffset = kwargs['location'] if 'location' in kwargs else ('', 0)
preCompile = kwargs['preCompile'] if 'preCompile' in kwargs else False
kernelModuleName = kwargs[
'kernelModuleName'] if 'kernelModuleName' in kwargs else None
parentVariables = kwargs[
'parentVariables'] if 'parentVariables' in kwargs else None

# Initialize the captured arguments list to be populated by the AST Bridge.
signature.captured_args = []
# Create the AST Bridge
bridge = PyASTBridge(signature,
uniqueId=uniqueId,
verbose=verbose,
parentVariables=parentVariables,
locationOffset=lineNumberOffset,
kernelModuleName=kernelModuleName)

ValidateArgumentAnnotations(bridge).visit(astModule)
ValidateReturnStatements(bridge).visit(astModule)

if not preCompile:
raise RuntimeError("must be precompile mode")

# Build the AOT Quake Module for this kernel.
bridge.visit(astModule)

Expand Down
10 changes: 9 additions & 1 deletion python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,14 @@ def apply_call(self, target, *target_arguments):
```
"""
if isa_kernel_decorator(target):
if not target.is_compiled():
name = target.name
emitFatalError(
f"Kernel '{name}' must be compiled to be used in the kernel builder. "
f"Call `{name}.compile()` before initializing the kernel builder, "
f"or deactivate deferred compilation:\n\n"
f" @cudaq.kernel(defer_compilation=False)\n"
f" def {name}(...): ...\n")
target = self.resolve_callable_arg(self.insertPoint, target)
self.__applyControlOrAdjoint(target, False, [], *target_arguments)

Expand All @@ -1361,8 +1369,8 @@ def resolve_callable_arg(self, insPt, target):
Returns a `CreateLambdaOp` closure.
"""
# Add the target kernel to the current module.
cudaq_runtime.updateModule(self.uniqName, self.module, target.qkeModule)
fulluniq = nvqppPrefix + target.uniqName
cudaq_runtime.updateModule(fulluniq, self.module, target.qkeModule)
fn = recover_func_op(self.module, fulluniq)

# build the closure to capture the lifted `args`
Expand Down
176 changes: 124 additions & 52 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ast
import inspect
import json
from functools import wraps
from cudaq.kernel.utils import emitWarning
import numpy as np
import sys
Expand All @@ -33,6 +34,42 @@
# representation and ultimately executable code.


def ensure_compiled(method):
"""
Decorator for `PyKernelDecorator` methods that ensures the kernel
is compiled before the method body executes.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
self._ensure_compiled()
return method(self, *args, **kwargs)

return wrapper


def ensure_not_recursive(method):
"""
Decorator for `PyKernelDecorator.resolve_decorator_at_callsite` method that
ensures the method is not called recursively.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
if self._resolving_arguments:
self._resolving_arguments = False
emitFatalError(
f"Not supported: recursive kernel call detected in {self.name}."
)

self._resolving_arguments = True
ret = method(self, *args, **kwargs)
self._resolving_arguments = False
return ret

return wrapper


class DecoratorCapture:

def __init__(self, decorator, values):
Expand Down Expand Up @@ -66,17 +103,19 @@ def __repr__(self):
class PyKernelDecorator(object):
"""
The `PyKernelDecorator` serves as a standard Python decorator that takes
the decorated function as input and optionally lowers its AST
representation to executable code via MLIR. This decorator enables full JIT
compilation mode, where the function is lowered to an MLIR representation.
the decorated function as input. The function AST is parsed and converted to
a Quake MLIR representation. This is passed on to the CUDAQ runtime for
execution at kernel call time.

This decorator exposes a call overload that executes the code via the
MLIR `ExecutionEngine` for the MLIR mode.
By default, MLIR compilation is deferred until the first call to the kernel.
If `defer_compilation` is set to `False`, the kernel will be compiled at
declaration time instead.
"""

def __init__(self,
function,
verbose=False,
defer_compilation=True,
module=None,
kernelName=None,
signature=None,
Expand All @@ -89,11 +128,12 @@ def __init__(self,
self.kernelModuleName = None
self.name = kernelName
self.verbose = verbose
# The `qkeModule` will be the quake target independent ModuleOp
self.qkeModule = None
# The `nvqModule` will be (if present) the default simulation ModuleOp
self.nvqModule = None
# Caches the `qkeModule` property once compiled
self._cached_qkeModule = None
self.defModule = _recover_module('cudaq.kernel.kernel_decorator')
# Whether we are currently resolving arguments to self. Used to detect
# (and prevent) recursive kernel calls.
self._resolving_arguments = False

if isinstance(function, str):
self.kernelFunction = None
Expand All @@ -117,12 +157,14 @@ def __init__(self,

if decorator is not None:
# shallow copy attributes from `decorator`
self.__dict__.update(vars(decorator))
self.uniqueId = decorator.uniqueId
self.uniqName = decorator.uniqName
else:
self.uniqueId = int(kernelName.split("..0x")[1], 16)
self.uniqName = kernelName

self.qkeModule = module
self._cached_qkeModule = module
self.astModule = None
self.signature = KernelSignature.parse_from_mlir(
self.qkeModule, self.uniqName)
else:
Expand All @@ -140,15 +182,24 @@ def __init__(self,
self.astModule = _parse_ast(self.funcSrc, self.verbose)
self.signature = KernelSignature.parse_from_ast(
self.astModule, self.name)
self.uniqueId = id(self)
self.uniqName = self.name + ".." + hex(self.uniqueId)

self.pre_compile()
if not defer_compilation:
self.compile()

def __del__(self):
# explicitly call `del` on the MLIR `ModuleOp` wrappers.
if self.qkeModule:
del self.qkeModule
if self.nvqModule:
del self.nvqModule
if self._cached_qkeModule:
del self._cached_qkeModule

@property
@ensure_compiled
def qkeModule(self):
"""
A target independent Quake MLIR representation of the kernel.
"""
return self._cached_qkeModule

def signatureWithCallables(self):
"""
Expand All @@ -174,40 +225,53 @@ def captured_variables(self):
"""The list of variables captured by the kernel."""
return self.signature.captured_variables()

def pre_compile(self):
def _ensure_compiled(self):
"""
Compile the Python AST to portable Quake.
Ensure that the kernel is compiled.
"""
if self._cached_qkeModule is None:
self.compile()

# If this target requires library mode, do not compile it to MLIR.
# TODO: this should always compile to MLIR.
handler = get_target_handler()
if handler.skip_compilation():
return
def is_compiled(self):
"""Whether the kernel has already been compiled."""
return self._cached_qkeModule is not None

# Otherwise, `precompile` the kernel to portable MLIR.
if self.qkeModule:
raise RuntimeError(self.name + " was already compiled")
self.uniqueId = id(self)
self.uniqName = self.name + ".." + hex(self.uniqueId)
self.qkeModule = compile_to_mlir(id(self),
self.astModule,
self.signature,
verbose=self.verbose,
location=self.location,
parentVariables=self.globalScopedVars,
preCompile=True,
kernelName=self.name,
kernelModuleName=self.kernelModuleName)

if (cudaq_runtime.is_current_target_full_qir() and
not self.signatureWithCallables()):
resMod = self.convert_to_full_qir([])
if not self.nvqModule:
self.nvqModule = resMod
def supports_compilation(self):
"""Whether the kernel can be compiled for the current target."""
handler = get_target_handler()
return not handler.skip_compilation()

def compile(self):
return
"""
Compile the Python AST to portable Quake.
"""
if not self.astModule:
emitFatalError(
f"Cannot compile kernel {self.name}: no AST module available")

if not self.supports_compilation():
emitFatalError(
f"Cannot compile kernel '{self.name}': target handler "
f"'{cudaq_runtime.get_target().name}' does not support compilation"
)

self._cached_qkeModule = compile_to_mlir(
id(self),
self.astModule,
self.signature,
verbose=self.verbose,
location=self.location,
parentVariables=self.globalScopedVars,
kernelName=self.name,
kernelModuleName=self.kernelModuleName)

# recursively compile any captured kernels if required
for captured_arg in self.signature.captured_args:
if isinstance(captured_arg, CapturedVariable
) and captured_arg.name in self.globalScopedVars:
var = self.globalScopedVars[captured_arg.name]
if isa_kernel_decorator(var):
var._ensure_compiled()

def convert_to_full_qir(self, vals):
return self.lower_quake_to_codegen(vals)
Expand All @@ -218,8 +282,6 @@ def lower_quake_to_codegen(self, argValues):
generation. If argument values are provided, we run argument synthesis
and specialize this instance of the kernel.
"""
if not self.qkeModule:
emitFatalError(f"no module in kernel decorator {self.name}")
result = cudaq_runtime.cloneModule(self.qkeModule)

if argValues:
Expand Down Expand Up @@ -276,19 +338,21 @@ def merge_quake_source(self, quakeText):

def __str__(self):
"""
Return the MLIR Module string representation for this kernel.
Return a string representation for this kernel as MLIR.
"""
if self.qkeModule:
return str(self.qkeModule)
return "The decorator " + hex(id(self)) + " is malformed"
return str(self.qkeModule)

def enable_return_to_log(self):
"""
Enable translation from `return` statements to QIR output log
"""
self.qkeModule.operation.attributes.__setitem__(
if self._cached_qkeModule is None:
emitFatalError(
f"kernel decorator {self.name} has not been compiled")
self._cached_qkeModule.operation.attributes.__setitem__(
'quake.cudaq_run', UnitAttr.get(context=self.qkeModule.context))

@ensure_compiled
def _repr_svg_(self):
"""
Return the SVG representation of `self` (:class:`PyKernelDecorator`).
Expand Down Expand Up @@ -445,6 +509,7 @@ def convertStringsToPauli(self, arg):
def formal_arity(self):
return len(self.arg_types())

@ensure_compiled
def handle_call_arguments(self, *args, allow_no_args=False):
"""
Resolve all the arguments at the call site for this decorator.
Expand Down Expand Up @@ -486,13 +551,18 @@ def handle_call_arguments(self, *args, allow_no_args=False):
return specialized_module, processedArgs

def get_none_type(self):
return NoneType.get(self.qkeModule.context)
if self._cached_qkeModule:
context = self._cached_qkeModule.context
else:
context = getMLIRContext()
return NoneType.get(context)

def handle_call_results(self):
if not self.return_type:
return self.get_none_type()
return self.return_type

@ensure_compiled
def launch_args_required(self):
"""
This is a deeper query on the quake module. The quake module may have
Expand Down Expand Up @@ -554,6 +624,8 @@ def delete_cache_execution_engine(self, key):
cudaq_runtime.delete_cache_execution_engine is not None):
cudaq_runtime.delete_cache_execution_engine(key)

@ensure_compiled
@ensure_not_recursive
def resolve_decorator_at_callsite(self, callingMod):
# Resolve all lifted arguments for `self`.
processedArgs = []
Expand Down
Loading
Loading