|
| 1 | +/******************************************************************************* |
| 2 | + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * |
| 3 | + * All rights reserved. * |
| 4 | + * * |
| 5 | + * This source code and the accompanying materials are made available under * |
| 6 | + * the terms of the Apache License 2.0 which accompanies this distribution. * |
| 7 | + ******************************************************************************/ |
| 8 | + |
| 9 | +#include "PassDetails.h" |
| 10 | +#include "cudaq/Optimizer/Builder/Factory.h" |
| 11 | +#include "cudaq/Optimizer/Builder/Intrinsics.h" |
| 12 | +#include "cudaq/Optimizer/CodeGen/Passes.h" |
| 13 | +#include "cudaq/Optimizer/CodeGen/QIRAttributeNames.h" |
| 14 | +#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h" |
| 15 | +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" |
| 16 | +#include "cudaq/Optimizer/Dialect/CC/CCTypes.h" |
| 17 | +#include "llvm/ADT/TypeSwitch.h" |
| 18 | +#include "mlir/Transforms/DialectConversion.h" |
| 19 | +#include "mlir/Transforms/Passes.h" |
| 20 | + |
| 21 | +#define DEBUG_TYPE "return-to-output-log" |
| 22 | + |
| 23 | +namespace cudaq::opt { |
| 24 | +#define GEN_PASS_DEF_RETURNTOOUTPUTLOG |
| 25 | +#include "cudaq/Optimizer/CodeGen/Passes.h.inc" |
| 26 | +} // namespace cudaq::opt |
| 27 | + |
| 28 | +using namespace mlir; |
| 29 | + |
| 30 | +namespace { |
| 31 | +class FuncSignature : public OpRewritePattern<func::FuncOp> { |
| 32 | +public: |
| 33 | + using OpRewritePattern::OpRewritePattern; |
| 34 | + |
| 35 | + // Simple type conversion: drop the result type on the floor. |
| 36 | + LogicalResult matchAndRewrite(func::FuncOp fn, |
| 37 | + PatternRewriter &rewriter) const override { |
| 38 | + auto *ctx = rewriter.getContext(); |
| 39 | + auto inputTys = fn.getFunctionType().getInputs(); |
| 40 | + auto funcTy = FunctionType::get(ctx, inputTys, {}); |
| 41 | + rewriter.updateRootInPlace(fn, [&]() { fn.setFunctionType(funcTy); }); |
| 42 | + return success(); |
| 43 | + } |
| 44 | +}; |
| 45 | + |
| 46 | +class CallRewrite : public OpRewritePattern<func::CallOp> { |
| 47 | +public: |
| 48 | + using OpRewritePattern::OpRewritePattern; |
| 49 | + |
| 50 | + // It should be a violation of the CUDA-Q spec to call an entry-point function |
| 51 | + // that returns a value from another entry-point function and use the result |
| 52 | + // value(s). Under a run context, no entry-point kernel will actually return a |
| 53 | + // value. |
| 54 | + LogicalResult matchAndRewrite(func::CallOp call, |
| 55 | + PatternRewriter &rewriter) const override { |
| 56 | + auto loc = call.getLoc(); |
| 57 | + rewriter.create<func::CallOp>(loc, TypeRange{}, call.getCallee(), |
| 58 | + call.getOperands()); |
| 59 | + SmallVector<Value> poisons; |
| 60 | + for (auto ty : call.getResultTypes()) |
| 61 | + poisons.push_back(rewriter.create<cudaq::cc::PoisonOp>(loc, ty)); |
| 62 | + rewriter.replaceOp(call, poisons); |
| 63 | + return success(); |
| 64 | + } |
| 65 | +}; |
| 66 | + |
class ReturnRewrite : public OpRewritePattern<func::ReturnOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // This is where the heavy lifting is done. We take the return op's
  // operand(s) and convert them to calls to the QIR output logging functions
  // with the appropriate label information. The original return is replaced
  // with a value-less return.
  LogicalResult matchAndRewrite(func::ReturnOp ret,
                                PatternRewriter &rewriter) const override {
    auto loc = ret.getLoc();
    // For each operand: emit the output-log calls in place of returning it.
    for (auto operand : ret.getOperands())
      genOutputLog(loc, rewriter, operand, std::nullopt);
    rewriter.replaceOpWithNewOp<func::ReturnOp>(ret);
    return success();
  }

  /// Emit QIR record-output calls that describe \p val. \p prefix, when
  /// present, overrides the type-derived label and encodes the value's
  /// position inside an enclosing aggregate (e.g. ".0" or "[2]"). Recurses
  /// for aggregate members. Unsupported types fall through the TypeSwitch
  /// and produce no logging.
  static void genOutputLog(Location loc, PatternRewriter &rewriter, Value val,
                           std::optional<StringRef> prefix) {
    Type valTy = val.getType();
    TypeSwitch<Type>(valTy)
        .Case([&](IntegerType intTy) {
          // Label is the type spelling ("iN") unless a positional prefix was
          // supplied by an enclosing aggregate.
          std::string labelStr{"i" + std::to_string(intTy.getWidth())};
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          // i1 has a dedicated bool recording entry point.
          if (intTy.getWidth() == 1) {
            rewriter.create<func::CallOp>(loc, TypeRange{},
                                          cudaq::opt::QIRBoolRecordOutput,
                                          ArrayRef<Value>{val, label});
            return;
          }
          // Integer: convert to (signed) i64. The decoder *must* lop off any
          // higher-order bits added by the sign-extension to get this to 64
          // bits by examining the real integer type.
          Value castVal = val;
          if (intTy.getWidth() < 64)
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getI64Type(), val, cudaq::cc::CastOpMode::Signed);
          else if (intTy.getWidth() > 64)
            // Wider than 64 bits: truncate (no mode argument).
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getI64Type(), val);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRIntegerRecordOutput,
                                        ArrayRef<Value>{castVal, label});
        })
        .Case([&](FloatType fltTy) {
          std::string labelStr{"f" + std::to_string(fltTy.getWidth())};
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          // Floating point: convert it to double, whatever it actually is.
          Value castVal = val;
          if (fltTy != rewriter.getF64Type())
            castVal = rewriter.create<cudaq::cc::CastOp>(
                loc, rewriter.getF64Type(), val);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRDoubleRecordOutput,
                                        ArrayRef<Value>{castVal, label});
        })
        .Case([&](cudaq::cc::StructType strTy) {
          // Tuple: record the member count, then recurse on each member with
          // a ".i" positional label suffix.
          auto labelStr = translateType(strTy);
          if (prefix)
            labelStr = prefix->str();
          Value label = makeLabel(loc, rewriter, labelStr);
          std::int32_t sz = strTy.getNumMembers();
          Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRTupleRecordOutput,
                                        ArrayRef<Value>{size, label});
          std::string preStr = prefix ? prefix->str() : std::string{};
          for (std::int32_t i = 0; i < sz; ++i) {
            std::string offset = preStr + '.' + std::to_string(i);
            Value w = rewriter.create<cudaq::cc::ExtractValueOp>(
                loc, strTy.getMember(i), val,
                ArrayRef<cudaq::cc::ExtractValueArg>{i});
            genOutputLog(loc, rewriter, w, offset);
          }
        })
        .Case([&](cudaq::cc::ArrayType arrTy) {
          // Fixed-size array: record the element count, then recurse on each
          // element with an "[i]" positional label suffix. NOTE: unlike the
          // struct case, an incoming prefix does not replace the array's own
          // type label here — TODO confirm that asymmetry is intentional.
          auto labelStr = translateType(arrTy);
          Value label = makeLabel(loc, rewriter, labelStr);
          std::int32_t sz = arrTy.getSize();
          Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
          rewriter.create<func::CallOp>(loc, TypeRange{},
                                        cudaq::opt::QIRArrayRecordOutput,
                                        ArrayRef<Value>{size, label});
          std::string preStr = prefix ? prefix->str() : std::string{};
          for (std::int32_t i = 0; i < sz; ++i) {
            std::string offset = preStr + '[' + std::to_string(i) + ']';
            Value w = rewriter.create<cudaq::cc::ExtractValueOp>(
                loc, arrTy.getElementType(), val,
                ArrayRef<cudaq::cc::ExtractValueArg>{i});
            genOutputLog(loc, rewriter, w, offset);
          }
        })
        .Case([&](cudaq::cc::StdvecType vecTy) {
          // For this type, we expect a cc.stdvec_init operation as the input.
          // The data will be in a variable.
          // If we reach here and we cannot determine the constant size of the
          // buffer, then we will not generate any output logging.
          if (auto vecInit = val.getDefiningOp<cudaq::cc::StdvecInitOp>())
            if (auto maybeLen = cudaq::opt::factory::maybeValueOfIntConstant(
                    vecInit.getLength())) {
              std::int32_t sz = *maybeLen;
              auto labelStr = translateType(vecTy, sz);
              Value label = makeLabel(loc, rewriter, labelStr);
              Value size = rewriter.create<arith::ConstantIntOp>(loc, sz, 64);
              rewriter.create<func::CallOp>(loc, TypeRange{},
                                            cudaq::opt::QIRArrayRecordOutput,
                                            ArrayRef<Value>{size, label});
              std::string preStr = prefix ? prefix->str() : std::string{};
              // View the backing store as a pointer-to-array so elements can
              // be addressed with compute_ptr below.
              auto rawBuffer = vecInit.getBuffer();
              auto buffTy = cast<cudaq::cc::PointerType>(rawBuffer.getType());
              Type ptrArrTy = buffTy;
              if (!isa<cudaq::cc::ArrayType>(buffTy.getElementType()))
                ptrArrTy = cudaq::cc::PointerType::get(
                    cudaq::cc::ArrayType::get(buffTy.getElementType()));
              Value buffer =
                  rewriter.create<cudaq::cc::CastOp>(loc, ptrArrTy, rawBuffer);
              for (std::int32_t i = 0; i < sz; ++i) {
                std::string offset = preStr + '[' + std::to_string(i) + ']';
                // NOTE(review): the compute_ptr result type is buffTy, the
                // *original* pointer type. When rawBuffer is ptr<T> this is
                // the element pointer and looks right; when rawBuffer is
                // already ptr<array<T>>, indexing should presumably yield
                // ptr<T> instead — TODO confirm against cc.compute_ptr
                // semantics.
                auto v = rewriter.create<cudaq::cc::ComputePtrOp>(
                    loc, buffTy, buffer, ArrayRef<cudaq::cc::ComputePtrArg>{i});
                Value w = rewriter.create<cudaq::cc::LoadOp>(loc, v);
                genOutputLog(loc, rewriter, w, offset);
              }
            }
        });
  }

  /// Render \p ty as a label string: "iN"/"fN" for scalars, "{a, b}" for
  /// structs, "[N x T]" for arrays. For cc.stdvec the size is not part of
  /// the type, so the caller must supply \p vecSz (dereferenced unchecked —
  /// only pass a stdvec type together with a size). Unsupported types yield
  /// the literal string "error".
  static std::string
  translateType(Type ty, std::optional<std::int32_t> vecSz = std::nullopt) {
    if (auto intTy = dyn_cast<IntegerType>(ty))
      return "i" + std::to_string(intTy.getWidth());
    if (auto fltTy = dyn_cast<FloatType>(ty))
      return "f" + std::to_string(fltTy.getWidth());
    if (auto strTy = dyn_cast<cudaq::cc::StructType>(ty)) {
      if (strTy.getMembers().empty())
        return "{}";
      std::string result = "{" + translateType(strTy.getMembers().front());
      for (auto memTy : strTy.getMembers().drop_front())
        result += ", " + translateType(memTy);
      return result + "}";
    }
    if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ty)) {
      return "[" + std::to_string(arrTy.getSize()) + " x " +
             translateType(arrTy.getElementType()) + "]";
    }
    if (auto arrTy = dyn_cast<cudaq::cc::StdvecType>(ty)) {
      return "[" + std::to_string(*vecSz) + " x " +
             translateType(arrTy.getElementType()) + "]";
    }
    return "error";
  }

  /// Materialize \p label as a NUL-terminated string literal and return it
  /// cast to an i8* value, suitable as the label argument of the QIR
  /// record-output functions.
  static Value makeLabel(Location loc, PatternRewriter &rewriter,
                         StringRef label) {
    auto strLitTy = cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(
        rewriter.getContext(), rewriter.getI8Type(), label.size() + 1));
    Value lit = rewriter.create<cudaq::cc::CreateStringLiteralOp>(
        loc, strLitTy, rewriter.getStringAttr(label));
    auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
    return rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, lit);
  }
};
| 233 | + |
| 234 | +struct ReturnToOutputLogPass |
| 235 | + : public cudaq::opt::impl::ReturnToOutputLogBase<ReturnToOutputLogPass> { |
| 236 | + using ReturnToOutputLogBase::ReturnToOutputLogBase; |
| 237 | + |
| 238 | + void runOnOperation() override { |
| 239 | + auto module = getOperation(); |
| 240 | + auto *ctx = &getContext(); |
| 241 | + auto irBuilder = cudaq::IRBuilder::atBlockEnd(module.getBody()); |
| 242 | + if (failed(irBuilder.loadIntrinsic(module, "qir_output_logging"))) { |
| 243 | + module.emitError("could not load QIR output logging declarations."); |
| 244 | + signalPassFailure(); |
| 245 | + return; |
| 246 | + } |
| 247 | + |
| 248 | + RewritePatternSet patterns(ctx); |
| 249 | + patterns.insert<CallRewrite, FuncSignature, ReturnRewrite>(ctx); |
| 250 | + LLVM_DEBUG(llvm::dbgs() << "Before return to output logging:\n" << module); |
| 251 | + ConversionTarget target(*ctx); |
| 252 | + target.addLegalDialect<arith::ArithDialect, cudaq::cc::CCDialect, |
| 253 | + func::FuncDialect>(); |
| 254 | + target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp fn) { |
| 255 | + // Legal unless an entry-point function, with a body, that returns a |
| 256 | + // value. |
| 257 | + return fn.getBody().empty() || !fn->hasAttr(cudaq::entryPointAttrName) || |
| 258 | + fn.getFunctionType().getResults().empty(); |
| 259 | + }); |
| 260 | + target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp call) { |
| 261 | + // Legal unless calling an entry-point function with a result. |
| 262 | + if (auto module = call->getParentOfType<ModuleOp>()) { |
| 263 | + auto callee = call.getCallee(); |
| 264 | + if (auto fn = module.lookupSymbol<func::FuncOp>(callee)) { |
| 265 | + return fn.getBody().empty() || |
| 266 | + !fn->hasAttr(cudaq::entryPointAttrName) || |
| 267 | + fn.getFunctionType().getResults().empty(); |
| 268 | + } |
| 269 | + } |
| 270 | + return true; |
| 271 | + }); |
| 272 | + target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp ret) { |
| 273 | + // Legal if return is not in an entry-point or does not return a value. |
| 274 | + if (auto fn = ret->getParentOfType<func::FuncOp>()) |
| 275 | + return !fn->hasAttr(cudaq::entryPointAttrName) || |
| 276 | + ret.getOperands().empty(); |
| 277 | + return true; |
| 278 | + }); |
| 279 | + if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
| 280 | + signalPassFailure(); |
| 281 | + LLVM_DEBUG(llvm::dbgs() << "After return to output logging:\n" << module); |
| 282 | + } |
| 283 | +}; |
| 284 | + |
| 285 | +} // namespace |
0 commit comments