Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def ApplySpecialization : Pass<"apply-op-specialization", "mlir::ModuleOp"> {
];
}

def ArgumentSynthesis : Pass<"argument-synthesis", "mlir::func::FuncOp"> {
def ArgumentSynthesis : Pass<"argument-synthesis", "mlir::ModuleOp"> {
let summary = "Specialize a function by replacing arguments with constants";
let description = [{
This pass takes a list of functions and argument substitutions. For each
Expand Down
207 changes: 103 additions & 104 deletions lib/Optimizer/Transforms/ArgumentSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,121 +31,120 @@ class ArgumentSynthesisPass
using ArgumentSynthesisBase::ArgumentSynthesisBase;

void runOnOperation() override {
func::FuncOp func = getOperation();
StringRef funcName = func.getName();
std::string text;
if (std::find_if(funcList.begin(), funcList.end(),
[&](const std::string &item) {
auto pos = item.find(':');
if (pos == std::string::npos)
return false;
std::string itemName = item.substr(0, pos);
bool result = itemName == funcName;
if (result)
text = item.substr(pos + 1);
return result;
}) == funcList.end()) {
// If the function isn't on the list, do nothing.
LLVM_DEBUG(llvm::dbgs() << funcName << " not in list.\n");
return;
}
ModuleOp moduleOp = getOperation();
for (auto item : funcList) {
auto pos = item.find(':');
if (pos == std::string::npos)
continue;

// If there are no substitutions, we're done.
if (text.empty()) {
LLVM_DEBUG(llvm::dbgs() << funcName << " has no substitutions.");
return;
}
std::string funcName = item.substr(0, pos);
std::string text = item.substr(pos + 1);

// If we're here, we have a FuncOp and we have substitutions that can be
// applied.
//
// 1. Create a Module with the substitutions that we'll be making.
auto *ctx = func.getContext();
LLVM_DEBUG(llvm::dbgs() << "substitution pattern: '" << text << "'\n");
auto substMod = [&]() -> OwningOpRef<ModuleOp> {
if (text.front() == '*') {
// Substitutions are a raw string after the '*' character.
return parseSourceString<ModuleOp>(text.substr(1), ctx);
}
// Substitutions are in a text file (command-line usage).
return parseSourceFile<ModuleOp>(text, ctx);
}();
assert(*substMod && "module must have been created");

// 2. Go through the Module and process each substitution.
SmallVector<bool> processedArgs(func.getFunctionType().getNumInputs());
SmallVector<std::tuple<unsigned, Value, Value>> replacements;
BitVector replacedArgs(processedArgs.size());
for (auto &op : *substMod) {
auto subst = dyn_cast<cudaq::cc::ArgumentSubstitutionOp>(op);
if (!subst) {
if (auto symInterface = dyn_cast<SymbolOpInterface>(op)) {
auto name = symInterface.getName();
auto srcMod = func->getParentOfType<ModuleOp>();
auto obj = srcMod.lookupSymbol(name);
if (!obj)
srcMod.getBody()->push_back(op.clone());
}
auto *op = moduleOp.lookupSymbol(funcName);
func::FuncOp func = dyn_cast_if_present<func::FuncOp>(op);

if (!func) {
LLVM_DEBUG(llvm::dbgs() << funcName << " is not in the module.");
continue;
}
auto pos = subst.getPosition();
if (pos >= processedArgs.size()) {
func.emitError("Argument " + std::to_string(pos) + " is invalid.");
signalPassFailure();
return;
}
if (processedArgs[pos]) {
func.emitError("Argument " + std::to_string(pos) +
" was already substituted.");
signalPassFailure();
return;
}

// OK, substitute the code for the argument.
Block &entry = func.getRegion().front();
processedArgs[pos] = true;
if (subst.getBody().front().empty()) {
// No code is present. Erase the argument if it is not used.
const auto numUses =
std::distance(entry.getArgument(pos).getUses().begin(),
entry.getArgument(pos).getUses().end());
LLVM_DEBUG(llvm::dbgs() << "maybe erasing an unused argument ("
<< std::to_string(numUses) << ")\n");
if (numUses == 0)
replacedArgs.set(pos);
// If there are no substitutions, we're done.
if (text.empty()) {
LLVM_DEBUG(llvm::dbgs() << funcName << " has no substitutions.");
continue;
}
OpBuilder builder{ctx};
Block *splitBlock = entry.splitBlock(entry.begin());
builder.setInsertionPointToEnd(&entry);
builder.create<cf::BranchOp>(func.getLoc(), &subst.getBody().front());
Operation *lastOp = &subst.getBody().front().back();
builder.setInsertionPointToEnd(&subst.getBody().front());
builder.create<cf::BranchOp>(func.getLoc(), splitBlock);
func.getBlocks().splice(Region::iterator{splitBlock},
subst.getBody().getBlocks());
if (lastOp &&
lastOp->getResult(0).getType() == entry.getArgument(pos).getType()) {
LLVM_DEBUG(llvm::dbgs()
<< funcName << " argument " << std::to_string(pos)
<< " was substituted.\n");
replacements.emplace_back(pos, entry.getArgument(pos),
lastOp->getResult(0));

// If we're here, we have a FuncOp and we have substitutions that can be
// applied.
//
// 1. Create a Module with the substitutions that we'll be making.
auto *ctx = func.getContext();
LLVM_DEBUG(llvm::dbgs() << "substitution pattern: '" << text << "'\n");
auto substMod = [&]() -> OwningOpRef<ModuleOp> {
if (text.front() == '*') {
// Substitutions are a raw string after the '*' character.
return parseSourceString<ModuleOp>(text.substr(1), ctx);
}
// Substitutions are in a text file (command-line usage).
return parseSourceFile<ModuleOp>(text, ctx);
}();
assert(*substMod && "module must have been created");

// 2. Go through the Module and process each substitution.
SmallVector<bool> processedArgs(func.getFunctionType().getNumInputs());
SmallVector<std::tuple<unsigned, Value, Value>> replacements;
BitVector replacedArgs(processedArgs.size());
for (auto &op : *substMod) {
auto subst = dyn_cast<cudaq::cc::ArgumentSubstitutionOp>(op);
if (!subst) {
if (auto symInterface = dyn_cast<SymbolOpInterface>(op)) {
auto name = symInterface.getName();
auto obj = moduleOp.lookupSymbol(name);
if (!obj)
moduleOp.getBody()->push_back(op.clone());
}
continue;
}
auto pos = subst.getPosition();
if (pos >= processedArgs.size()) {
func.emitError("Argument " + std::to_string(pos) + " is invalid.");
signalPassFailure();
return;
}
if (processedArgs[pos]) {
func.emitError("Argument " + std::to_string(pos) +
" was already substituted.");
signalPassFailure();
return;
}

// OK, substitute the code for the argument.
Block &entry = func.getRegion().front();
processedArgs[pos] = true;
if (subst.getBody().front().empty()) {
// No code is present. Erase the argument if it is not used.
const auto numUses =
std::distance(entry.getArgument(pos).getUses().begin(),
entry.getArgument(pos).getUses().end());
LLVM_DEBUG(llvm::dbgs() << "maybe erasing an unused argument ("
<< std::to_string(numUses) << ")\n");
if (numUses == 0)
replacedArgs.set(pos);
continue;
}
OpBuilder builder{ctx};
Block *splitBlock = entry.splitBlock(entry.begin());
builder.setInsertionPointToEnd(&entry);
builder.create<cf::BranchOp>(func.getLoc(), &subst.getBody().front());
Operation *lastOp = &subst.getBody().front().back();
builder.setInsertionPointToEnd(&subst.getBody().front());
builder.create<cf::BranchOp>(func.getLoc(), splitBlock);
func.getBlocks().splice(Region::iterator{splitBlock},
subst.getBody().getBlocks());
if (lastOp && lastOp->getResult(0).getType() ==
entry.getArgument(pos).getType()) {
LLVM_DEBUG(llvm::dbgs()
<< funcName << " argument " << std::to_string(pos)
<< " was substituted.\n");
replacements.emplace_back(pos, entry.getArgument(pos),
lastOp->getResult(0));
}
}
}

// Note: if we exited before here, any code that was cloned into the
// function is still dead and can be removed by a DCE.
// Note: if we exited before here, any code that was cloned into the
// function is still dead and can be removed by a DCE.

// 3. Replace the block argument values with the freshly inserted new code.
for (auto [pos, fromVal, toVal] : replacements) {
replacedArgs.set(pos);
fromVal.replaceAllUsesWith(toVal);
}
// 3. Replace the block argument values with the freshly inserted new
// code.
for (auto [pos, fromVal, toVal] : replacements) {
replacedArgs.set(pos);
fromVal.replaceAllUsesWith(toVal);
}

// 4. Finish specializing func and erase any of func's arguments that were
// substituted.
func.eraseArguments(replacedArgs);
// 4. Finish specializing func and erase any of func's arguments that were
// substituted.
func.eraseArguments(replacedArgs);
}
}
};
} // namespace
Expand Down
3 changes: 1 addition & 2 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ MlirModule synthesizeKernel(const std::string &name, MlirModule module,
ss << argCon.getSubstitutionModule();
SmallVector<StringRef> substs = {substBuff};
PassManager pm(context);
pm.addNestedPass<func::FuncOp>(
cudaq::opt::createArgumentSynthesisPass(kernels, substs));
pm.addPass(opt::createArgumentSynthesisPass(kernels, substs));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addPass(opt::createDeleteStates());

Expand Down
3 changes: 1 addition & 2 deletions runtime/common/BaseRemoteRESTQPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,7 @@ class BaseRemoteRESTQPU : public cudaq::QPU {
llvm::raw_string_ostream ss(substBuff);
ss << argCon.getSubstitutionModule();
mlir::SmallVector<mlir::StringRef> substs = {substBuff};
pm.addNestedPass<mlir::func::FuncOp>(
opt::createArgumentSynthesisPass(kernels, substs));
pm.addPass(opt::createArgumentSynthesisPass(kernels, substs));
pm.addPass(opt::createDeleteStates());
} else if (updatedArgs) {
cudaq::info("Run Quake Synth.\n");
Expand Down
3 changes: 1 addition & 2 deletions runtime/common/BaseRestRemoteClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,7 @@ class BaseRemoteRestRuntimeClient : public cudaq::RemoteRuntimeClient {
llvm::raw_string_ostream ss(substBuff);
ss << argCon.getSubstitutionModule();
mlir::SmallVector<mlir::StringRef> substs = {substBuff};
pm.addNestedPass<mlir::func::FuncOp>(
opt::createArgumentSynthesisPass(kernels, substs));
pm.addPass(opt::createArgumentSynthesisPass(kernels, substs));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(opt::createDeleteStates());
} else if (args) {
Expand Down
16 changes: 16 additions & 0 deletions test/Quake/arg_subst-5.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// ========================================================================== //
// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

module {
cc.arg_subst[0] {
%0 = arith.constant 2 : i32
}
func.func private @callee5(%arg0: i32) -> (i32) {
return %arg0: i32
}
}
11 changes: 11 additions & 0 deletions test/Quake/arg_subst-6.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// ========================================================================== //
// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

cc.arg_subst[0] {
%c4_i64 = arith.constant 4 : i32
}
16 changes: 14 additions & 2 deletions test/Quake/arg_subst_func.qke
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt --argument-synthesis=functions=foo:%S/arg_subst.txt,blink:%S/arg_subst.txt,testy1:%S/arg_subst-1.txt,testy2:%S/arg_subst-2.txt,testy3:%S/arg_subst-3.txt,testy4:%S/arg_subst-4.txt --canonicalize %s | FileCheck %s

// RUN: cudaq-opt --argument-synthesis=functions=foo:%S/arg_subst.txt,blink:%S/arg_subst.txt,testy1:%S/arg_subst-1.txt,testy2:%S/arg_subst-2.txt,testy3:%S/arg_subst-3.txt,testy4:%S/arg_subst-4.txt,testy5:%S/arg_subst-5.txt,callee5:%S/arg_subst-6.txt --canonicalize %s | FileCheck %s
func.func private @bar(i32)
func.func private @baz(f32)

Expand Down Expand Up @@ -146,3 +145,16 @@ func.func @testy4(%arg0: !cc.stdvec<!cc.struct<{i32, f64, i8, i16}>>) {
// CHECK: call @callee4(%[[VAL_32]]) : (!cc.stdvec<!cc.struct<{i32, f64, i8, i16}>>) -> ()
// CHECK: return
// CHECK: }

func.func @testy5(%arg0: i32) -> i32 {
return %arg0: i32
}

// CHECK-LABEL: func.func @testy5() -> i32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
// CHECK: return %[[VAL_0]] : i32
// CHECK: }
// CHECK-LABEL: func.func private @callee5() -> i32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 4 : i32
// CHECK: return %[[VAL_0]] : i32
// CHECK: }
Loading