Skip to content

Commit ebd48b3

Browse files
committed
Start to add return to output logging pass.
Add test. Add struct test. Signed-off-by: Eric Schweitz <eschweitz@nvidia.com>
1 parent ce146cc commit ebd48b3

File tree

6 files changed

+533
-0
lines changed

6 files changed

+533
-0
lines changed

include/cudaq/Optimizer/CodeGen/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,19 @@ def QuakeToQIRAPIPrep : Pass<"quake-to-qir-api-prep", "mlir::ModuleOp"> {
271271
];
272272
}
273273

274+
// Rewrites entry-point kernel return values into calls to the QIR output
// recording functions so that results can be collected per shot when the
// kernel is launched via cudaq::run().
def ReturnToOutputLog : Pass<"return-to-output-log", "mlir::ModuleOp"> {
  let summary = "Convert a kernel to be compatible with cudaq::run().";
  let description = [{
    When the target supports the cudaq::run() launch function, the kernel's
    return value(s) are translated into calls to the QIR output logging
    functions. This conversion allows the kernel to be executed as a group of
    shots on the QPU, with a log record produced for the data from each kernel
    execution. Effectively, this provides the benefit of running the kernel as
    a batch of executions, eliminating the overhead of launching the kernel
    one execution at a time with all the attendant interprocessor
    communication.
  }];
  let dependentDialects = ["cudaq::cc::CCDialect", "mlir::func::FuncDialect"];
}
287+
274288

275289
#endif // CUDAQ_OPT_OPTIMIZER_CODEGEN_PASSES

include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,16 @@ static constexpr const char QIRRecordOutput[] =
9494
static constexpr const char QIRClearResultMaps[] =
9595
"__quantum__rt__clear_result_maps";
9696

97+
// Output logging function names.
// QIR "record output" runtime entry points used to log a kernel's return
// value(s). Each takes the value to record (or, for aggregates, the element
// count as an i64) plus an i8* label string describing the value.
static constexpr const char QIRBoolRecordOutput[] =
    "__quantum__rt__bool_record_output";
static constexpr const char QIRIntegerRecordOutput[] =
    "__quantum__rt__integer_record_output";
static constexpr const char QIRDoubleRecordOutput[] =
    "__quantum__rt__double_record_output";
// Tuple/array records log only the member/element count; the members are then
// logged individually by subsequent record calls.
static constexpr const char QIRTupleRecordOutput[] =
    "__quantum__rt__tuple_record_output";
static constexpr const char QIRArrayRecordOutput[] =
    "__quantum__rt__array_record_output";
108+
97109
} // namespace cudaq::opt

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,17 @@ static constexpr IntrinsicCode intrinsicTable[] = {
519519
!qir_llvmptr = !llvm.ptr<i8>
520520
)#"},
521521

522+
// The QIR defined output logging functions.
// Loadable as the "qir_output_logging" intrinsic: declares the five QIR
// record-output functions (bool/integer/double/tuple/array). Each takes the
// value to record — or an i64 element count for tuple/array — and an i8*
// label pointer.
{"qir_output_logging",
 {},
 R"#(
func.func private @__quantum__rt__bool_record_output(i1, !cc.ptr<i8>)
func.func private @__quantum__rt__integer_record_output(i64, !cc.ptr<i8>)
func.func private @__quantum__rt__double_record_output(f64, !cc.ptr<i8>)
func.func private @__quantum__rt__tuple_record_output(i64, !cc.ptr<i8>)
func.func private @__quantum__rt__array_record_output(i64, !cc.ptr<i8>)
)#"},
532+
522533
// streamlinedLaunchKernel(kernelName, vectorArgPtrs)
523534
{cudaq::runtime::launchKernelStreamlinedFuncName,
524535
{},

lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_cudaq_library(OptCodeGen
2626
QuakeToExecMgr.cpp
2727
QuakeToLLVM.cpp
2828
RemoveMeasurements.cpp
29+
ReturnToOutputLog.cpp
2930
TranslateToIQMJson.cpp
3031
TranslateToOpenQASM.cpp
3132
VerifyNVQIRCalls.cpp
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "PassDetails.h"
10+
#include "cudaq/Optimizer/Builder/Factory.h"
11+
#include "cudaq/Optimizer/Builder/Intrinsics.h"
12+
#include "cudaq/Optimizer/CodeGen/Passes.h"
13+
#include "cudaq/Optimizer/CodeGen/QIRAttributeNames.h"
14+
#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h"
15+
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
16+
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
17+
#include "llvm/ADT/TypeSwitch.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/Passes.h"
20+
21+
#define DEBUG_TYPE "return-to-output-log"
22+
23+
namespace cudaq::opt {
24+
#define GEN_PASS_DEF_RETURNTOOUTPUTLOG
25+
#include "cudaq/Optimizer/CodeGen/Passes.h.inc"
26+
} // namespace cudaq::opt
27+
28+
using namespace mlir;
29+
30+
namespace {
31+
class FuncSignature : public OpRewritePattern<func::FuncOp> {
32+
public:
33+
using OpRewritePattern::OpRewritePattern;
34+
35+
// Simple type conversion: drop the result type on the floor.
36+
LogicalResult matchAndRewrite(func::FuncOp fn,
37+
PatternRewriter &rewriter) const override {
38+
auto *ctx = rewriter.getContext();
39+
auto inputTys = fn.getFunctionType().getInputs();
40+
auto funcTy = FunctionType::get(ctx, inputTys, {});
41+
rewriter.updateRootInPlace(fn, [&]() { fn.setFunctionType(funcTy); });
42+
return success();
43+
}
44+
};
45+
46+
/// Rewrite a call to an entry-point function so it no longer expects results.
class CallRewrite : public OpRewritePattern<func::CallOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // It should be a violation of the CUDA-Q spec to call an entry-point function
  // that returns a value from another entry-point function and use the result
  // value(s). Under a run context, no entry-point kernel will actually return a
  // value.
  LogicalResult matchAndRewrite(func::CallOp call,
                                PatternRewriter &rewriter) const override {
    Location loc = call.getLoc();
    // Re-issue the call with no results, then stand in poison values for any
    // uses of the original call's results.
    rewriter.create<func::CallOp>(loc, TypeRange{}, call.getCallee(),
                                  call.getOperands());
    SmallVector<Value> replacements;
    for (Type resTy : call.getResultTypes())
      replacements.push_back(rewriter.create<cudaq::cc::PoisonOp>(loc, resTy));
    rewriter.replaceOp(call, replacements);
    return success();
  }
};
66+
67+
/// Replace a value-returning func.return in an entry-point kernel with calls
/// to the QIR output logging functions, one (or more, for aggregates) per
/// returned value, followed by an empty func.return.
class ReturnRewrite : public OpRewritePattern<func::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // This is where the heavy lifting is done. We take the return op's operand(s)
  // and convert them to calls to the QIR output logging functions with the
  // appropriate label information.
  LogicalResult matchAndRewrite(func::ReturnOp ret,
                                PatternRewriter &rewriter) const override {
    auto loc = ret.getLoc();
    // For each operand:
    for (auto operand : ret.getOperands())
      genOutputLog(loc, rewriter, operand, std::nullopt);
    // Finally, drop the operands from the return itself.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(ret);
    return success();
  }

  /// Emit the output-logging call(s) for \p val. \p prefix, when present, is
  /// the label to use instead of the type-derived one; it carries the member
  /// offset path (e.g. ".0", "[2]") when recursing into aggregates.
  static void genOutputLog(Location loc, PatternRewriter &rewriter, Value val,
                           std::optional<StringRef> prefix) {
    Type valTy = val.getType();
    // Note: types not handled by any Case below produce no logging at all.
    TypeSwitch<Type>(valTy)
        .Case([&](IntegerType intTy) {
          std::string labelStr{"i" + std::to_string(intTy.getWidth())};
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          // i1 is logged via the bool record function; all other widths go
          // through the integer record function as i64.
          if (intTy.getWidth() == 1) {
            rewriter.create<func::CallOp>(loc, TypeRange{},
                                          cudaq::opt::QIRBoolRecordOutput,
                                          ArrayRef<Value>{val, label});
            return;
          }
          // Integer: convert to (signed) i64. The decoder *must* lop off any
          // higher-order bits added by the sign-extension to get this to 64
          // bits by examining the real integer type.
          Value castVal = val;
          if (intTy.getWidth() < 64)
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getI64Type(), val, cudaq::cc::CastOpMode::Signed);
          else if (intTy.getWidth() > 64)
            // Wider than 64 bits: truncating cast (no mode).
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getI64Type(), val);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRIntegerRecordOutput,
                                        ArrayRef<Value>{castVal, label});
        })
        .Case([&](FloatType fltTy) {
          std::string labelStr{"f" + std::to_string(fltTy.getWidth())};
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          // Floating point: convert it to double, whatever it actually is.
          Value castVal = val;
          if (fltTy != rewriter.getF64Type())
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getF64Type(), val);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRDoubleRecordOutput,
                                        ArrayRef<Value>{castVal, label});
        })
        .Case([&](cudaq::cc::StructType strTy) {
          // Tuple: log the member count, then recurse member by member with a
          // ".i" suffix appended to the label prefix.
          auto labelStr = translateType(strTy);
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          std::int32_t sz = strTy.getNumMembers();
          Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRTupleRecordOutput,
                                        ArrayRef<Value>{size, label});
          std::string preStr = prefix ? prefix->str() : std::string{};
          for (std::int32_t i = 0; i < sz; ++i) {
            std::string offset = preStr + '.' + std::to_string(i);
            Value w = rewriter.create<cudaq::cc::ExtractValueOp>(
                loc, strTy.getMember(i), val,
                ArrayRef<cudaq::cc::ExtractValueArg>{i});
            genOutputLog(loc, rewriter, w, offset);
          }
        })
        .Case([&](cudaq::cc::ArrayType arrTy) {
          // Array: log the element count, then recurse element by element with
          // a "[i]" suffix appended to the label prefix.
          // NOTE(review): unlike the other cases, this one never overrides
          // labelStr with `prefix` for the header record — confirm whether
          // that asymmetry is intentional.
          auto labelStr = translateType(arrTy);
          Value label = makeLabel(loc, rewriter, labelStr);
          std::int32_t sz = arrTy.getSize();
          Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRArrayRecordOutput,
                                        ArrayRef<Value>{size, label});
          std::string preStr = prefix ? prefix->str() : std::string{};
          for (std::int32_t i = 0; i < sz; ++i) {
            std::string offset = preStr + '[' + std::to_string(i) + ']';
            Value w = rewriter.create<cudaq::cc::ExtractValueOp>(
                loc, arrTy.getElementType(), val,
                ArrayRef<cudaq::cc::ExtractValueArg>{i});
            genOutputLog(loc, rewriter, w, offset);
          }
        })
        .Case([&](cudaq::cc::StdvecType vecTy) {
          // For this type, we expect a cc.stdvec_init operation as the input.
          // The data will be in a variable.
          // If we reach here and we cannot determine the constant size of the
          // buffer, then we will not generate any output logging.
          if (auto vecInit = val.getDefiningOp<cudaq::cc::StdvecInitOp>())
            if (auto maybeLen = cudaq::opt::factory::maybeValueOfIntConstant(
                    vecInit.getLength())) {
              std::int32_t sz = *maybeLen;
              auto labelStr = translateType(vecTy, sz);
              Value label = makeLabel(loc, rewriter, labelStr);
              Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
              rewriter.create<func::CallOp>(loc, TypeRange{},
                                            cudaq::opt::QIRArrayRecordOutput,
                                            ArrayRef<Value>{size, label});
              std::string preStr = prefix ? prefix->str() : std::string{};
              // View the backing buffer as a pointer to an array so we can
              // index each element.
              auto rawBuffer = vecInit.getBuffer();
              auto buffTy = cast<cudaq::cc::PointerType>(rawBuffer.getType());
              Type ptrArrTy = buffTy;
              if (!isa<cudaq::cc::ArrayType>(buffTy.getElementType()))
                ptrArrTy = cudaq::cc::PointerType::get(
                    cudaq::cc::ArrayType::get(buffTy.getElementType()));
              Value buffer =
                  rewriter.create<cudaq::cc::CastOp>(loc, ptrArrTy, rawBuffer);
              for (std::int32_t i = 0; i < sz; ++i) {
                std::string offset = preStr + '[' + std::to_string(i) + ']';
                // NOTE(review): the result type passed here is `buffTy`. That
                // is a pointer to an *element* only when the raw buffer was
                // not already ptr<array<T>>; if it was, this looks like it
                // should be ptr<element-type> instead — confirm.
                auto v = rewriter.create<cudaq::cc::ComputePtrOp>(
                    loc, buffTy, buffer, ArrayRef<cudaq::cc::ComputePtrArg>{i});
                Value w = rewriter.create<cudaq::cc::LoadOp>(loc, v);
                genOutputLog(loc, rewriter, w, offset);
              }
            }
        });
  }

  /// Render \p ty as a human-readable label string ("i32", "f64",
  /// "{i32, f64}", "[4 x i32]", ...). \p vecSz supplies the element count for
  /// cc.stdvec types, whose size is not part of the type.
  // NOTE(review): the StdvecType branch dereferences *vecSz unconditionally;
  // callers must always pass vecSz for a stdvec (the one caller in this file
  // does) — otherwise this is UB.
  static std::string
  translateType(Type ty, std::optional<std::int32_t> vecSz = std::nullopt) {
    if (auto intTy = dyn_cast<IntegerType>(ty))
      return "i" + std::to_string(intTy.getWidth());
    if (auto fltTy = dyn_cast<FloatType>(ty))
      return "f" + std::to_string(fltTy.getWidth());
    if (auto strTy = dyn_cast<cudaq::cc::StructType>(ty)) {
      if (strTy.getMembers().empty())
        return "{}";
      std::string result = "{" + translateType(strTy.getMembers().front());
      for (auto memTy : strTy.getMembers().drop_front())
        result += ", " + translateType(memTy);
      return result + "}";
    }
    if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ty)) {
      return "[" + std::to_string(arrTy.getSize()) + " x " +
             translateType(arrTy.getElementType()) + "]";
    }
    if (auto arrTy = dyn_cast<cudaq::cc::StdvecType>(ty)) {
      return "[" + std::to_string(*vecSz) + " x " +
             translateType(arrTy.getElementType()) + "]";
    }
    // Unrecognized type: emit a sentinel label rather than failing.
    return "error";
  }

  /// Materialize \p label as a NUL-terminated string literal and return it
  /// cast to an i8 pointer, as expected by the record-output functions.
  static Value makeLabel(Location loc, PatternRewriter &rewriter,
                         StringRef label) {
    // +1 for the implicit NUL terminator.
    auto strLitTy = cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(
        rewriter.getContext(), rewriter.getI8Type(), label.size() + 1));
    Value lit = rewriter.create<cudaq::cc::CreateStringLiteralOp>(
        loc, strLitTy, rewriter.getStringAttr(label));
    auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
    return rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, lit);
  }
};
233+
234+
/// Pass driver: loads the QIR output-logging declarations into the module,
/// then converts every value-returning entry-point kernel (its signature, its
/// callers, and its returns) via partial dialect conversion.
struct ReturnToOutputLogPass
    : public cudaq::opt::impl::ReturnToOutputLogBase<ReturnToOutputLogPass> {
  using ReturnToOutputLogBase::ReturnToOutputLogBase;

  void runOnOperation() override {
    auto module = getOperation();
    auto *ctx = &getContext();
    // Declare the __quantum__rt__*_record_output functions in this module so
    // the rewrite patterns can emit calls to them.
    auto irBuilder = cudaq::IRBuilder::atBlockEnd(module.getBody());
    if (failed(irBuilder.loadIntrinsic(module, "qir_output_logging"))) {
      module.emitError("could not load QIR output logging declarations.");
      signalPassFailure();
      return;
    }

    RewritePatternSet patterns(ctx);
    patterns.insert<CallRewrite, FuncSignature, ReturnRewrite>(ctx);
    LLVM_DEBUG(llvm::dbgs() << "Before return to output logging:\n" << module);
    // Partial conversion: only ops the dynamic legality callbacks below mark
    // illegal are rewritten; everything else is left untouched.
    ConversionTarget target(*ctx);
    target.addLegalDialect<arith::ArithDialect, cudaq::cc::CCDialect,
                           func::FuncDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp fn) {
      // Legal unless an entry-point function, with a body, that returns a
      // value.
      return fn.getBody().empty() || !fn->hasAttr(cudaq::entryPointAttrName) ||
             fn.getFunctionType().getResults().empty();
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp call) {
      // Legal unless calling an entry-point function with a result.
      if (auto module = call->getParentOfType<ModuleOp>()) {
        auto callee = call.getCallee();
        if (auto fn = module.lookupSymbol<func::FuncOp>(callee)) {
          // Same predicate as the FuncOp legality above, applied to the
          // callee.
          return fn.getBody().empty() ||
                 !fn->hasAttr(cudaq::entryPointAttrName) ||
                 fn.getFunctionType().getResults().empty();
        }
      }
      // Callee not resolvable in this module: leave the call alone.
      return true;
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp ret) {
      // Legal if return is not in an entry-point or does not return a value.
      if (auto fn = ret->getParentOfType<func::FuncOp>())
        return !fn->hasAttr(cudaq::entryPointAttrName) ||
               ret.getOperands().empty();
      return true;
    });
    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();
    LLVM_DEBUG(llvm::dbgs() << "After return to output logging:\n" << module);
  }
};
284+
285+
} // namespace

0 commit comments

Comments
 (0)