Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 77 additions & 59 deletions runtime/cudaq/platform/default/python/QPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,65 @@ static void updateExecutionContext(ModuleOp module) {
}
}

static std::optional<cudaq::JitEngine> alreadyBuiltJITCode() {
static std::optional<cudaq::JitEngine>
alreadyBuiltJITCode(const std::string &name,
const std::vector<void *> &rawArgs) {
auto *currentExecCtx = cudaq::getExecutionContext();
if (!currentExecCtx || !currentExecCtx->allowJitEngineCaching)
return std::nullopt;
if (currentExecCtx->jitEng)

auto jit = currentExecCtx->jitEng;
if (jit && cudaq::compiler_artifact::isPersistingJITEngine()) {
CUDAQ_INFO("Loading previously compiled JIT engine for {}. This will "
"re-run the previous job, discarding any changes to the kernel, "
"arguments or launch configuration.",
currentExecCtx->kernelName);
return currentExecCtx->jitEng;

// Ensure the arguments are the same as the previous launch.
auto argsCreatorThunk = [&jit, &name]() {
return (void *)jit->lookupRawNameOrFail(name + ".argsCreator");
};
cudaq::compiler_artifact::checkArtifactReuse(name, rawArgs, jit.value(),
argsCreatorThunk);
}

return jit;
}

/// Execute the JIT-compiled kernel \p name.
///
/// \param jit                The compiled engine providing the kernel's entry
///                           points (`.thunk`, `.argsCreator`, and the kernel
///                           itself).
/// \param name               The kernel name.
/// \param rawArgs            Raw argument pointers; when \p hasResult is true
///                           the last element is the caller-owned result
///                           buffer.
/// \param hasResult          True when the caller appended a result buffer.
/// \param hasVariationalArgs True when arguments must be marshaled via the
///                           `.argsCreator` entry point.
/// \returns The thunk's result descriptor, or `{nullptr, 0}` when there is no
///          result to report.
static cudaq::KernelThunkResultType
executeKernel(cudaq::JitEngine jit, const std::string &name,
              const std::vector<void *> &rawArgs, bool hasResult,
              bool hasVariationalArgs) {
  cudaq::KernelThunkResultType result{nullptr, 0};
  void *buff = nullptr;
  // True only when `buff` was heap-allocated by the `.argsCreator` entry
  // point and therefore must be freed here. The result buffer
  // (rawArgs.back()) is owned by the caller and must never be freed.
  bool ownsBuff = false;
  if (hasResult) {
    // The caller appended a result buffer as the last raw argument.
    buff = const_cast<void *>(rawArgs.back());
  } else if (hasVariationalArgs) {
    // Marshal the raw arguments into a thunk-compatible buffer. The
    // `.argsCreator` entry point allocates `buff`; we own it from here on.
    auto argsCreatorFn = reinterpret_cast<int64_t (*)(const void *, void **)>(
        jit.lookupRawNameOrFail(name + ".argsCreator"));
    argsCreatorFn(static_cast<const void *>(rawArgs.data()), &buff);
    ownsBuff = true;
  }

  if (buff) {
    // Proceed to call the .thunk function so that the result value will be
    // properly marshaled into the buffer we allocated in
    // appendTheResultBuffer().
    // FIXME: Python ought to set up the call stack so that a legit C++ entry
    // point can be called instead of winging it and duplicating what the core
    // compiler already does.
    auto funcPtr = jit.lookupRawNameOrFail(name + ".thunk");
    result = reinterpret_cast<cudaq::KernelThunkResultType (*)(void *, bool)>(
        funcPtr)(buff, /*client_server=*/false);
  } else {
    // No buffer to marshal: run the kernel entry point directly.
    jit.run(name);
  }

  if (ownsBuff) {
    // Free only the buffer we allocated via `.argsCreator`. Keying this off
    // `hasVariationalArgs` alone would free the caller-owned result buffer
    // (undefined behavior) and discard the result whenever a kernel had both
    // a result and variational arguments.
    std::free(buff);
    return {nullptr, 0};
  }

  return result;
}

/// In a sample launch context, the (`JIT` compiled) execution engine may be
Expand Down Expand Up @@ -218,7 +267,6 @@ struct PythonLauncher : public cudaq::ModuleLauncher {
cudaq::getEnvBool("CUDAQ_PYTHON_CODEGEN_DUMP", false);

std::string fullName = cudaq::runtime::cudaqGenPrefixName + name;
cudaq::KernelThunkResultType result{nullptr, 0};

auto funcOp = module.lookupSymbol<func::FuncOp>(fullName);
if (!funcOp)
Expand Down Expand Up @@ -249,70 +297,40 @@ struct PythonLauncher : public cudaq::ModuleLauncher {
varArgIndices.clear();
}
const bool hasVariationalArgs = !varArgIndices.empty();
const bool hasResult = !!resultTy;

auto jit = alreadyBuiltJITCode();
if (!jit) {
// 1. Check that this call is sane.
if (enablePythonCodegenDump)
module.dump();
if (auto jit = alreadyBuiltJITCode(name, rawArgs)) {
return executeKernel(*jit, name, rawArgs, hasResult, hasVariationalArgs);
}

// 2. Merge other modules (e.g., if there are device kernel calls).
cudaq::detail::mergeAllCallableClosures(module, name, rawArgs);
// 1. Check that this call is sane.
if (enablePythonCodegenDump)
module.dump();

// Mark all newly merged kernels private.
for (auto &op : module)
if (auto f = dyn_cast<func::FuncOp>(op))
if (f != funcOp)
f.setPrivate();
// 2. Merge other modules (e.g., if there are device kernel calls).
cudaq::detail::mergeAllCallableClosures(module, name, rawArgs);

updateExecutionContext(module);
// Mark all newly merged kernels private.
for (auto &op : module)
if (auto f = dyn_cast<func::FuncOp>(op))
if (f != funcOp)
f.setPrivate();

// 3. LLVM JIT the code so we can execute it.
CUDAQ_INFO("Run Argument Synth.\n");
if (enablePythonCodegenDump)
module.dump();
specializeKernel(name, module, rawArgs, resultTy, enablePythonCodegenDump,
/*isEntryPoint=*/true, varArgIndices);
updateExecutionContext(module);

// 4. Execute the code right here, right now.
jit = cudaq::createQIRJITEngine(module, "qir:");
}
// 3. LLVM JIT the code so we can execute it.
CUDAQ_INFO("Run Argument Synth.\n");
if (enablePythonCodegenDump)
module.dump();
specializeKernel(name, module, rawArgs, resultTy, enablePythonCodegenDump,
/*isEntryPoint=*/true, varArgIndices);

if (cudaq::compiler_artifact::isPersistingJITEngine()) {
auto argsCreatorThunk = [&jit, &name]() {
return (void *)jit->lookupRawNameOrFail(name + ".argsCreator");
};
cudaq::compiler_artifact::checkArtifactReuse(name, rawArgs, jit.value(),
argsCreatorThunk);
}
auto jit = cudaq::createQIRJITEngine(module, "qir:");
cacheJITForPerformance(jit);

if (resultTy) {
// Proceed to call the .thunk function so that the result value will be
// properly marshaled into the buffer we allocated in
// appendTheResultBuffer().
// FIXME: Python ought to set up the call stack so that a legit C++ entry
// point can be called instead of winging it and duplicating what the core
// compiler already does.
auto funcPtr = jit->lookupRawNameOrFail(name + ".thunk");
void *buff = const_cast<void *>(rawArgs.back());
result = reinterpret_cast<cudaq::KernelThunkResultType (*)(void *, bool)>(
*funcPtr)(buff, /*client_server=*/false);
} else if (hasVariationalArgs) {
auto argsCreatorFn = reinterpret_cast<int64_t (*)(const void *, void **)>(
*jit->lookupRawNameOrFail(name + ".argsCreator"));
void *argsBuffer = nullptr;
argsCreatorFn(static_cast<const void *>(rawArgs.data()), &argsBuffer);
auto thunkFn =
reinterpret_cast<cudaq::KernelThunkResultType (*)(void *, bool)>(
*jit->lookupRawNameOrFail(name + ".thunk"));
thunkFn(argsBuffer, /*client_server=*/false);
std::free(argsBuffer);
} else {
jit->run(name);
}
cacheJITForPerformance(jit.value());
// FIXME: actually handle results
return result;
// 4. Execute the code right here, right now.
return executeKernel(jit, name, rawArgs, hasResult, hasVariationalArgs);
}

void *specializeModule(const std::string &name, ModuleOp module,
Expand Down
Loading