forked from NVIDIA/cuda-quantum
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_decorator.py
More file actions
532 lines (462 loc) · 21.2 KB
/
kernel_decorator.py
File metadata and controls
532 lines (462 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
# ============================================================================ #
# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
import ast, sys, traceback
import importlib
import inspect
import json
from typing import Callable
from ..mlir.ir import *
from ..mlir.passmanager import *
from ..mlir.dialects import quake, cc, func
from .ast_bridge import compile_to_mlir, PyASTBridge
from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli, globalRegisteredTypes
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime
from .captured_data import CapturedDataStorage
from ..handlers import PhotonicsHandler
import numpy as np
# This file implements the decorator mechanism needed to
# JIT compile CUDA-Q kernels. It exposes the cudaq.kernel()
# decorator which hooks us into the JIT compilation infrastructure
# which maps the AST representation to an MLIR representation and ultimately
# executable code.
class PyKernelDecorator(object):
    """
    The `PyKernelDecorator` serves as a standard Python decorator that takes
    the decorated function as input and optionally lowers its AST
    representation to executable code via MLIR. This decorator enables full JIT
    compilation mode, where the function is lowered to an MLIR representation.
    This decorator exposes a call overload that executes the code via the
    MLIR `ExecutionEngine` for the MLIR mode.
    """

    def __init__(self,
                 function,
                 verbose=False,
                 module=None,
                 kernelName=None,
                 funcSrc=None,
                 signature=None,
                 location=None,
                 overrideGlobalScopedVars=None):
        # A `str` first argument signals deserialization (see `from_json`):
        # the kernel is rebuilt from `funcSrc` and the keyword arguments
        # rather than from a live Python function object.
        is_deserializing = isinstance(function, str)
        # When initializing with a provided `funcSrc`, we cannot use inspect
        # because we only have a string for the function source. That is - the
        # "function" isn't actually a concrete Python Function object in memory
        # that we can "inspect". Hence, use alternate approaches when
        # initializing from `funcSrc`.
        if is_deserializing:
            self.kernelFunction = None
            self.name = kernelName
            self.location = location
            self.signature = signature
        else:
            self.kernelFunction = function
            self.name = kernelName if kernelName != None else self.kernelFunction.__name__
            # (file, first line number) of the decorated function, used for
            # error diagnostics later on.
            self.location = (inspect.getfile(self.kernelFunction),
                             inspect.getsourcelines(self.kernelFunction)[1]
                            ) if self.kernelFunction is not None else ('', 0)
        self.capturedDataStorage = None
        self.module = module
        self.verbose = verbose
        self.argTypes = None
        # Get any global variables from parent scope.
        # We filter only types we accept: integers and floats.
        # Note here we assume that the parent scope is 2 stack frames up
        self.parentFrame = inspect.stack()[2].frame
        if overrideGlobalScopedVars:
            self.globalScopedVars = {
                k: v for k, v in overrideGlobalScopedVars.items()
            }
        else:
            self.globalScopedVars = {
                k: v for k, v in dict(inspect.getmembers(self.parentFrame))
                ['f_locals'].items()
            }
        # Register any external class types that may be used
        # in the kernel definition
        for name, var in self.globalScopedVars.items():
            if isinstance(var, type) and hasattr(var, '__annotations__'):
                globalRegisteredTypes[name] = (var, var.__annotations__)
        # Once the kernel is compiled to MLIR, we
        # want to know what capture variables, if any, were
        # used in the kernel. We need to track these.
        self.dependentCaptures = None
        if self.kernelFunction is None and not is_deserializing:
            if self.module is not None:
                # Could be that we don't have a function
                # but someone has provided an external Module.
                # If we want this new decorator to be callable
                # we'll need to set the `argTypes`
                symbols = SymbolTable(self.module.operation)
                if nvqppPrefix + self.name in symbols:
                    function = symbols[nvqppPrefix + self.name]
                    entryBlock = function.entry_block
                    self.argTypes = [v.type for v in entryBlock.arguments]
                    # Synthesize a positional signature ('arg0', 'arg1', ...)
                    # from the MLIR entry block since no Python source exists.
                    self.signature = {
                        'arg{}'.format(i): mlirTypeToPyType(v)
                        for i, v in enumerate(self.argTypes)
                    }
                    self.returnType = self.signature[
                        'return'] if 'return' in self.signature else None
                return
            else:
                emitFatalError(
                    "Invalid kernel decorator. Module and function are both None."
                )
        if is_deserializing:
            self.funcSrc = funcSrc
        else:
            # Get the function source
            src = inspect.getsource(self.kernelFunction)
            # Strip off the extra tabs
            leadingSpaces = len(src) - len(src.lstrip())
            self.funcSrc = '\n'.join(
                [line[leadingSpaces:] for line in src.split('\n')])
        # Create the AST
        self.astModule = ast.parse(self.funcSrc)
        if verbose and importlib.util.find_spec('astpretty') is not None:
            import astpretty
            astpretty.pprint(self.astModule.body[0])
        # Assign the signature for use later and
        # keep a list of arguments (used for validation in the runtime)
        if not is_deserializing:
            self.signature = inspect.getfullargspec(
                self.kernelFunction).annotations
        # NOTE(review): when deserializing, `self.signature` values are the
        # type-name strings produced by `type_to_str`, not type objects.
        self.arguments = [
            (k, v) for k, v in self.signature.items() if k != 'return'
        ]
        self.returnType = self.signature[
            'return'] if 'return' in self.signature else None
        # Validate that we have a return type annotation if necessary
        hasRetNodeVis = HasReturnNodeVisitor()
        hasRetNodeVis.visit(self.astModule)
        if hasRetNodeVis.hasReturnNode and 'return' not in self.signature:
            emitFatalError(
                'CUDA-Q kernel has return statement but no return type annotation.'
            )
        # Run analyzers and attach metadata (only have 1 right now)
        analyzer = MidCircuitMeasurementAnalyzer()
        analyzer.visit(self.astModule)
        self.metadata = {'conditionalOnMeasure': analyzer.hasMidCircuitMeasures}
        # Store the AST for this kernel, it is needed for
        # building up call graphs. We also must retain
        # the source code location for error diagnostics
        globalAstRegistry[self.name] = (self.astModule, self.location)
def compile(self):
"""
Compile the Python function AST to MLIR. This is a no-op
if the kernel is already compiled.
"""
# Before we can execute, we need to make sure
# variables from the parent frame that we captured
# have not changed. If they have changed, we need to
# recompile with the new values.
s = inspect.currentframe()
while s:
if s == self.parentFrame:
# We found the parent frame, now
# see if any of the variables we depend
# on have changed.
self.globalScopedVars = {
k: v
for k, v in dict(inspect.getmembers(s))['f_locals'].items()
}
if self.dependentCaptures != None:
for k, v in self.dependentCaptures.items():
if (isinstance(v, (list, np.ndarray))):
if not all(a == b for a, b in zip(
self.globalScopedVars[k], v)):
# Recompile if values in the list have changed.
self.module = None
break
elif self.globalScopedVars[k] != v:
# Need to recompile
self.module = None
break
break
s = s.f_back
if self.module != None:
return
self.module, self.argTypes, extraMetadata = compile_to_mlir(
self.astModule,
self.metadata,
self.capturedDataStorage,
verbose=self.verbose,
returnType=self.returnType,
location=self.location,
parentVariables=self.globalScopedVars)
# Grab the dependent capture variables, if any
self.dependentCaptures = extraMetadata[
'dependent_captures'] if 'dependent_captures' in extraMetadata else None
    def merge_kernel(self, otherMod):
        """
        Merge the kernel in this PyKernelDecorator (the ModuleOp) with
        the provided ModuleOp.
        """
        # Make sure our own module exists before merging.
        self.compile()
        if not isinstance(otherMod, str):
            otherMod = str(otherMod)
        newMod = cudaq_runtime.mergeExternalMLIR(self.module, otherMod)
        # Get the name of the kernel entry point
        name = self.name
        for op in newMod.body:
            if isinstance(op, func.FuncOp):
                for attr in op.attributes:
                    if 'cudaq-entrypoint' == attr.name:
                        # Strip the nvq++ mangling prefix to recover the
                        # user-facing kernel name.
                        name = op.name.value.replace(nvqppPrefix, '')
                        # NOTE(review): this break exits only the attribute
                        # loop; if several FuncOps carry the entrypoint
                        # attribute, the last one visited wins — confirm
                        # that is intended.
                        break
        # Wrap the merged module in a fresh decorator (no Python function).
        return PyKernelDecorator(None, kernelName=name, module=newMod)
def synthesize_callable_arguments(self, funcNames):
"""
Given this Kernel has callable block arguments, synthesize away these
callable arguments with the in-module FuncOps with given names. The
name at index 0 in the list corresponds to the first callable block
argument, index 1 to the second callable block argument, etc.
"""
self.compile()
cudaq_runtime.synthPyCallable(self.module, funcNames)
# Reset the argument types by removing the Callable
self.argTypes = [
a for a in self.argTypes if not cc.CallableType.isinstance(a)
]
def extract_c_function_pointer(self, name=None):
"""
Return the C function pointer for the function with given name, or
with the name of this kernel if not provided.
"""
self.compile()
return cudaq_runtime.jitAndGetFunctionPointer(
self.module, nvqppPrefix + self.name if name is None else name)
    def __str__(self):
        """
        Return the MLIR Module string representation for this kernel.
        """
        # Compile first (no-op if already compiled) so the module exists.
        self.compile()
        return str(self.module)
    def _repr_svg_(self):
        """
        Return the SVG representation of `self` (:class:`PyKernelDecorator`).
        This assumes no arguments are required to execute the kernel,
        and `latex` (with `quantikz` package) and `dvisvgm` are installed,
        and the temporary directory is writable.
        If any of these assumptions fail, returns None.
        """
        self.compile()  # compile if not yet compiled
        # Only argument-free kernels can be drawn without user input.
        if self.argTypes is None or len(self.argTypes) != 0:
            return None
        from cudaq import getSVGstring
        try:
            from subprocess import CalledProcessError
            try:
                # CalledProcessError covers latex/dvisvgm tool failures;
                # any ImportError raised within is caught by the outer
                # handler below.
                return getSVGstring(self)
            except CalledProcessError:
                return None
        except ImportError:
            return None
def isCastable(self, fromTy, toTy):
if F64Type.isinstance(toTy):
return F32Type.isinstance(fromTy) or IntegerType.isinstance(fromTy)
if F32Type.isinstance(toTy):
return F64Type.isinstance(fromTy) or IntegerType.isinstance(fromTy)
if ComplexType.isinstance(toTy):
floatToType = ComplexType(toTy).element_type
if ComplexType.isinstance(fromTy):
floatFromType = ComplexType(fromTy).element_type
return self.isCastable(floatFromType, floatToType)
return fromTy == floatToType or self.isCastable(fromTy, floatToType)
return False
def castPyList(self, fromEleTy, toEleTy, list):
if self.isCastable(fromEleTy, toEleTy):
if F64Type.isinstance(toEleTy):
return [float(i) for i in list]
if F32Type.isinstance(toEleTy):
return [np.float32(i) for i in list]
if ComplexType.isinstance(toEleTy):
floatToType = ComplexType(toEleTy).element_type
if F64Type.isinstance(floatToType):
return [complex(i) for i in list]
return [np.complex64(i) for i in list]
return list
def createStorage(self):
ctx = None if self.module == None else self.module.context
return CapturedDataStorage(ctx=ctx,
loc=self.location,
name=self.name,
module=self.module)
@staticmethod
def type_to_str(t):
"""
This converts types to strings in a clean JSON-compatible way.
int -> 'int'
list[float] -> 'list[float]'
List[float] -> 'list[float]'
"""
if hasattr(t, '__origin__') and t.__origin__ is not None:
# Handle generic types from typing
origin = t.__origin__
args = t.__args__
args_str = ', '.join(
PyKernelDecorator.type_to_str(arg) for arg in args)
return f'{origin.__name__}[{args_str}]'
elif hasattr(t, '__name__'):
return t.__name__
else:
return str(t)
def to_json(self):
"""
Convert `self` to a JSON-serialized version of the kernel such that
`from_json` can reconstruct it elsewhere.
"""
obj = dict()
obj['name'] = self.name
obj['location'] = self.location
obj['funcSrc'] = self.funcSrc
obj['signature'] = {
k: PyKernelDecorator.type_to_str(v)
for k, v in self.signature.items()
}
return json.dumps(obj)
@staticmethod
def from_json(jStr, overrideDict=None):
"""
Convert a JSON string into a new PyKernelDecorator object.
"""
j = json.loads(jStr)
return PyKernelDecorator(
'kernel', # just set to any string
verbose=False,
module=None,
kernelName=j['name'],
funcSrc=j['funcSrc'],
signature=j['signature'],
location=j['location'],
overrideGlobalScopedVars=overrideDict)
    def __call__(self, *args):
        """
        Invoke the CUDA-Q kernel. JIT compilation of the kernel AST to MLIR
        will occur here if it has not already occurred, except when the target
        requires custom handling.
        """
        # Check if target is set
        try:
            target_name = cudaq_runtime.get_target().name
        except RuntimeError:
            target_name = None
        # The photonics target bypasses MLIR entirely and interprets the
        # original Python function directly.
        if 'orca-photonics' == target_name:
            if self.kernelFunction is None:
                raise RuntimeError(
                    "The 'orca-photonics' target must be used with a valid function."
                )
            # NOTE: Since this handler does not support MLIR mode (yet), just
            # invoke the kernel. If calling from a bound function, need to
            # unpack the arguments, for example, see `pyGetStateLibraryMode`
            try:
                context_name = cudaq_runtime.getExecutionContextName()
            except RuntimeError:
                context_name = None
            callable_args = args
            if "extract-state" == context_name and len(args) == 1:
                callable_args = args[0]
            PhotonicsHandler(self.kernelFunction)(*callable_args)
            return
        # Prepare captured state storage for the run
        self.capturedDataStorage = self.createStorage()
        # Compile, no-op if the module is not None
        self.compile()
        if len(args) != len(self.argTypes):
            emitFatalError(
                f"Incorrect number of runtime arguments provided to kernel `{self.name}` ({len(self.argTypes)} required, {len(args)} provided)"
            )
        # validate the argument types
        processedArgs = []
        callableNames = []
        for i, arg in enumerate(args):
            if isinstance(arg, PyKernelDecorator):
                arg.compile()
            if isinstance(arg, str):
                # Only allow `pauli_word` as string input
                emitErrorIfInvalidPauli(arg)
                arg = cudaq_runtime.pauli_word(arg)
            if issubclass(type(arg), list):
                # A list of all-strings is treated as a list of pauli words.
                if all(isinstance(a, str) for a in arg):
                    [emitErrorIfInvalidPauli(a) for a in arg]
                    arg = [cudaq_runtime.pauli_word(a) for a in arg]
            mlirType = mlirTypeFromPyType(type(arg),
                                          self.module.context,
                                          argInstance=arg,
                                          argTypeToCompareTo=self.argTypes[i])
            # Support passing `list[int]` to a `list[float]` argument
            # Support passing `list[int]` or `list[float]` to a `list[complex]` argument
            if cc.StdvecType.isinstance(mlirType):
                if cc.StdvecType.isinstance(self.argTypes[i]):
                    argEleTy = cc.StdvecType.getElementType(mlirType)  # actual
                    eleTy = cc.StdvecType.getElementType(
                        self.argTypes[i])  # formal
                    if self.isCastable(argEleTy, eleTy):
                        processedArgs.append(
                            self.castPyList(argEleTy, eleTy, arg))
                        mlirType = self.argTypes[i]
                        continue
            if not cc.CallableType.isinstance(
                    mlirType) and mlirType != self.argTypes[i]:
                emitFatalError(
                    f"Invalid runtime argument type. Argument of type {mlirTypeToPyType(mlirType)} was provided, but {mlirTypeToPyType(self.argTypes[i])} was expected."
                )
            if cc.CallableType.isinstance(mlirType):
                # Assume this is a PyKernelDecorator
                callableNames.append(arg.name)
                # It may be that the provided input callable kernel
                # is not currently in the ModuleOp. Need to add it
                # if that is the case, we have to use the AST
                # so that it shares self.module's MLIR Context
                symbols = SymbolTable(self.module.operation)
                if nvqppPrefix + arg.name not in symbols:
                    tmpBridge = PyASTBridge(self.capturedDataStorage,
                                            existingModule=self.module,
                                            disableEntryPointTag=True)
                    tmpBridge.visit(globalAstRegistry[arg.name][0])
            # Convert `numpy` arrays to lists
            if cc.StdvecType.isinstance(mlirType) and hasattr(arg, "tolist"):
                if arg.ndim != 1:
                    emitFatalError(
                        f"CUDA-Q kernels only support array arguments from NumPy that are one dimensional (input argument {i} has shape = {arg.shape})."
                    )
                processedArgs.append(arg.tolist())
            else:
                processedArgs.append(arg)
        if self.returnType == None:
            cudaq_runtime.pyAltLaunchKernel(self.name,
                                            self.module,
                                            *processedArgs,
                                            callable_names=callableNames)
            # Explicitly release the captured-state storage for this run
            # before clearing the reference.
            self.capturedDataStorage.__del__()
            self.capturedDataStorage = None
        else:
            result = cudaq_runtime.pyAltLaunchKernelR(
                self.name,
                self.module,
                mlirTypeFromPyType(self.returnType, self.module.context),
                *processedArgs,
                callable_names=callableNames)
            self.capturedDataStorage.__del__()
            self.capturedDataStorage = None
            return result
def kernel(function=None, **kwargs):
    """
    The `cudaq.kernel` represents the CUDA-Q language function
    attribute that programmers leverage to indicate the following function
    is a CUDA-Q kernel and should be compiled and executed on
    an available quantum coprocessor.
    Verbose logging can be enabled via `verbose=True`.
    """
    # Bare usage (`@cudaq.kernel`) hands us the function directly, while
    # parameterized usage (`@cudaq.kernel(verbose=True)`) must first return
    # a decorator that captures the keyword arguments.
    if not function:

        def wrapper(f):
            return PyKernelDecorator(f, **kwargs)

        return wrapper
    return PyKernelDecorator(function)