diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 1ccec4a4278..98dd645729b 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -623,6 +623,8 @@ def LoopUnroll : Pass<"cc-loop-unroll"> { Option<"signalFailure", "signal-failure-if-any-loop-cannot-be-completely-unrolled", "bool", /*default=*/"false", "Signal failure if pass can't unroll all loops.">, + Option<"allowClosedInterval", "allow-closed-interval", "bool", + /*default=*/"true", "Allow loop iterations on a closed interval.">, Option<"allowBreak", "allow-early-exit", "bool", /*default=*/"false", "Allow unrolling of loop with early exit (i.e. break statement)."> ]; diff --git a/lib/Optimizer/Transforms/LiftArrayAlloc.cpp b/lib/Optimizer/Transforms/LiftArrayAlloc.cpp index 0f7647b5796..70113cf3e9f 100644 --- a/lib/Optimizer/Transforms/LiftArrayAlloc.cpp +++ b/lib/Optimizer/Transforms/LiftArrayAlloc.cpp @@ -27,263 +27,9 @@ namespace cudaq::opt { using namespace mlir; -namespace { -class AllocaPattern : public OpRewritePattern { -public: - explicit AllocaPattern(MLIRContext *ctx, DominanceInfo &di, StringRef fn) - : OpRewritePattern(ctx), dom(di), funcName(fn) {} - - LogicalResult matchAndRewrite(cudaq::cc::AllocaOp alloc, - PatternRewriter &rewriter) const override { - SmallVector stores; - if (!isGoodCandidate(alloc, stores, dom)) - return failure(); - - LLVM_DEBUG(llvm::dbgs() << "Candidate was found\n"); - auto allocTy = alloc.getElementType(); - auto arrTy = cast(allocTy); - auto eleTy = arrTy.getElementType(); - - SmallVector values; - - // Every element of `stores` must be a cc::StoreOp with a ConstantOp as the - // value argument. Build the array attr to attach to a cc.const_array. 
- for (auto *op : stores) { - auto store = cast(op); - auto *valOp = store.getValue().getDefiningOp(); - if (auto con = dyn_cast(valOp)) - values.push_back(con.getValueAttr()); - else if (auto con = dyn_cast(valOp)) - values.push_back(con.getValueAttr()); - else - return alloc.emitOpError("could not fold"); - } - - // Create the cc.const_array. - auto valuesAttr = rewriter.getArrayAttr(values); - auto loc = alloc.getLoc(); - Value conArr = - rewriter.create(loc, arrTy, valuesAttr); - - assert(conArr && "must have created the constant array"); - LLVM_DEBUG(llvm::dbgs() << "constant array is:\n" << conArr << '\n'); - bool cannotEraseAlloc = false; - - // Collect all the stores, casts, and compute_ptr to be erased safely and in - // topological order. - SmallVector opsToErase; - auto insertOpToErase = [&](Operation *op) { - auto iter = std::find(opsToErase.begin(), opsToErase.end(), op); - if (iter == opsToErase.end()) - opsToErase.push_back(op); - }; - - // Rewalk all the uses of alloc, u, which must be cc.cast or cc.compute_ptr. - // For each u remove a store and replace a load with a cc.extract_value. 
- for (auto *user : alloc->getUsers()) { - if (!user) - continue; - std::int32_t offset = 0; - if (auto cptr = dyn_cast(user)) - offset = cptr.getRawConstantIndices()[0]; - bool isLive = false; - if (!isa(user)) { - cannotEraseAlloc = isLive = true; - } else { - for (auto *useuser : user->getUsers()) { - if (!useuser) - continue; - if (auto load = dyn_cast(useuser)) { - rewriter.setInsertionPointAfter(useuser); - LLVM_DEBUG(llvm::dbgs() << "replaced load\n"); - rewriter.replaceOpWithNewOp( - load, eleTy, conArr, - ArrayRef{offset}); - continue; - } - if (isa(useuser)) { - insertOpToErase(useuser); - continue; - } - LLVM_DEBUG(llvm::dbgs() << "alloc is live\n"); - cannotEraseAlloc = isLive = true; - } - } - if (!isLive) - insertOpToErase(user); - } - - for (auto *e : opsToErase) - rewriter.eraseOp(e); - - if (cannotEraseAlloc) { - rewriter.setInsertionPointAfter(alloc); - rewriter.create(loc, conArr, alloc); - return success(); - } - rewriter.eraseOp(alloc); - return success(); - } - - // Determine if \p alloc is a legit candidate for promotion to a constant - // array value. \p scoreboard is a vector of store operations. Each element of - // the allocated array must be written to exactly 1 time, and the scoreboard - // is used to track these stores. \p dom is the dominance info for this - // function (to ensure the stores happen before uses). - static bool isGoodCandidate(cudaq::cc::AllocaOp alloc, - SmallVectorImpl &scoreboard, - DominanceInfo &dom) { - LLVM_DEBUG(llvm::dbgs() << "checking candidate\n"); - if (alloc.getSeqSize()) - return false; - auto arrTy = dyn_cast(alloc.getElementType()); - if (!arrTy || arrTy.isUnknownSize()) - return false; - auto arrEleTy = arrTy.getElementType(); - if (!isa(arrEleTy)) - return false; - - // There must be at least `size` uses to initialize the entire array. 
- auto size = arrTy.getSize(); - if (std::distance(alloc->getUses().begin(), alloc->getUses().end()) < size) - return false; - - // Keep a scoreboard for every element in the array. Every element *must* be - // stored to with a constant exactly one time. - scoreboard.resize(size); - for (int i = 0; i < size; i++) - scoreboard[i] = nullptr; - - SmallVector toGlobalUses; - SmallVector> loadSets(size); - - auto getWriteOp = [&](auto op, std::int32_t index) -> Operation * { - Operation *theStore = nullptr; - for (auto &use : op->getUses()) { - Operation *u = use.getOwner(); - if (!u) - return nullptr; - if (auto store = dyn_cast(u)) { - if (op.getOperation() == store.getPtrvalue().getDefiningOp()) { - if (theStore) { - LLVM_DEBUG(llvm::dbgs() - << "more than 1 store to element of array\n"); - return nullptr; - } - LLVM_DEBUG(llvm::dbgs() << "found store: " << store << "\n"); - theStore = u; - } - continue; - } - if (isa(u)) { - toGlobalUses.push_back(u); - continue; - } - if (isa(u)) { - loadSets[index].insert(u); - continue; - } - return nullptr; - } - return theStore && - isa_and_present( - dyn_cast(theStore) - .getValue() - .getDefiningOp()) - ? theStore - : nullptr; - }; - - auto unsizedArrTy = cudaq::cc::ArrayType::get(arrEleTy); - auto ptrUnsizedArrTy = cudaq::cc::PointerType::get(unsizedArrTy); - auto ptrArrEleTy = cudaq::cc::PointerType::get(arrEleTy); - for (auto &use : alloc->getUses()) { - // All uses *must* be a degenerate cc.cast, cc.compute_ptr, or - // cc.init_state. - auto *op = use.getOwner(); - if (!op) { - LLVM_DEBUG(llvm::dbgs() << "use was not an op\n"); - return false; - } - if (auto cptr = dyn_cast(op)) { - if (auto index = cptr.getConstantIndex(0)) - if (auto w = getWriteOp(cptr, *index)) - if (!scoreboard[*index]) { - scoreboard[*index] = w; - continue; - } - return false; - } - if (auto cast = dyn_cast(op)) { - // Process casts that are used in store ops. 
- if (cast.getType() == ptrArrEleTy) { - if (auto w = getWriteOp(cast, 0)) - if (!scoreboard[0]) { - scoreboard[0] = w; - continue; - } - return false; - } - // Process casts that are used in quake.init_state. - if (cast.getType() == ptrUnsizedArrTy) { - if (cast->hasOneUse()) { - auto &use = *cast->getUses().begin(); - Operation *u = use.getOwner(); - if (isa_and_present(u)) { - toGlobalUses.push_back(op); - continue; - } - } - return false; - } - LLVM_DEBUG(llvm::dbgs() << "unexpected cast: " << *op << '\n'); - toGlobalUses.push_back(op); - continue; - } - if (isa(op)) { - toGlobalUses.push_back(op); - continue; - } - LLVM_DEBUG(llvm::dbgs() << "unexpected use: " << *op << '\n'); - toGlobalUses.push_back(op); - } - - bool ok = std::all_of(scoreboard.begin(), scoreboard.end(), - [](bool b) { return b; }); - LLVM_DEBUG(llvm::dbgs() << "all elements of array are set: " << ok << '\n'); - if (ok) { - // Verify dominance relations. - - // For all stores, the store of an element $e$ must dominate all loads of - // $e$. - for (int i = 0; i < size; ++i) { - for (auto *load : loadSets[i]) - if (!dom.dominates(scoreboard[i], load)) { - LLVM_DEBUG(llvm::dbgs() - << "store " << scoreboard[i] - << " doesn't dominate load: " << *load << '\n'); - return false; - } - } - - // For all global uses, all of the stores must dominate every use. 
- for (auto *glob : toGlobalUses) { - for (auto *store : scoreboard) - if (!dom.dominates(store, glob)) { - LLVM_DEBUG(llvm::dbgs() - << "store " << store << " doesn't dominate op: " << *glob - << '\n'); - return false; - } - } - } - return ok; - } - - DominanceInfo &dom; - StringRef funcName; -}; +#include "LiftArrayAllocPatterns.inc" +namespace { class LiftArrayAllocPass : public cudaq::opt::impl::LiftArrayAllocBase { public: diff --git a/lib/Optimizer/Transforms/LiftArrayAllocPatterns.inc b/lib/Optimizer/Transforms/LiftArrayAllocPatterns.inc new file mode 100644 index 00000000000..db020a3aa39 --- /dev/null +++ b/lib/Optimizer/Transforms/LiftArrayAllocPatterns.inc @@ -0,0 +1,276 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// These patterns are used by the lift-array-alloc and cc-loop-unroll passes. + +// This file must be included after a `using namespace mlir;` as it uses bare +// identifiers from that namespace. + +namespace { +class AllocaPattern : public OpRewritePattern { +public: + explicit AllocaPattern(MLIRContext *ctx, DominanceInfo &di, StringRef fn) + : OpRewritePattern(ctx), dom(di), funcName(fn) {} + + LogicalResult matchAndRewrite(cudaq::cc::AllocaOp alloc, + PatternRewriter &rewriter) const override { + SmallVector stores; + if (!isGoodCandidate(alloc, stores, dom)) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Candidate was found\n"); + auto allocTy = alloc.getElementType(); + auto arrTy = cast(allocTy); + auto eleTy = arrTy.getElementType(); + + SmallVector values; + + // Every element of `stores` must be a cc::StoreOp with a ConstantOp as the + // value argument. 
Build the array attr to attach to a cc.const_array. + for (auto *op : stores) { + auto store = cast(op); + auto *valOp = store.getValue().getDefiningOp(); + if (auto con = dyn_cast(valOp)) + values.push_back(con.getValueAttr()); + else if (auto con = dyn_cast(valOp)) + values.push_back(con.getValueAttr()); + else + return alloc.emitOpError("could not fold"); + } + + // Create the cc.const_array. + auto valuesAttr = rewriter.getArrayAttr(values); + auto loc = alloc.getLoc(); + Value conArr = + rewriter.create(loc, arrTy, valuesAttr); + + assert(conArr && "must have created the constant array"); + LLVM_DEBUG(llvm::dbgs() << "constant array is:\n" << conArr << '\n'); + bool cannotEraseAlloc = false; + + // Collect all the stores, casts, and compute_ptr to be erased safely and in + // topological order. + SmallVector opsToErase; + auto insertOpToErase = [&](Operation *op) { + auto iter = std::find(opsToErase.begin(), opsToErase.end(), op); + if (iter == opsToErase.end()) + opsToErase.push_back(op); + }; + + // Rewalk all the uses of alloc, u, which must be cc.cast or cc.compute_ptr. + // For each u remove a store and replace a load with a cc.extract_value. 
+ for (auto *user : alloc->getUsers()) { + if (!user) + continue; + std::int32_t offset = 0; + if (auto cptr = dyn_cast(user)) + offset = cptr.getRawConstantIndices()[0]; + bool isLive = false; + if (!isa(user)) { + cannotEraseAlloc = isLive = true; + } else { + for (auto *useuser : user->getUsers()) { + if (!useuser) + continue; + if (auto load = dyn_cast(useuser)) { + rewriter.setInsertionPointAfter(useuser); + LLVM_DEBUG(llvm::dbgs() << "replaced load\n"); + // rewriter.replaceOpWithNewOp( + // load, eleTy, conArr, + // ArrayRef{offset}); + + auto extractValue = rewriter.create( + loc, eleTy, conArr, + ArrayRef{offset}); + rewriter.replaceAllUsesWith(load, extractValue); + insertOpToErase(load); + continue; + } + if (isa(useuser)) { + insertOpToErase(useuser); + continue; + } + LLVM_DEBUG(llvm::dbgs() << "alloc is live\n"); + cannotEraseAlloc = isLive = true; + } + } + if (!isLive) + insertOpToErase(user); + } + + for (auto *e : opsToErase) + rewriter.eraseOp(e); + + if (cannotEraseAlloc) { + rewriter.setInsertionPointAfter(alloc); + rewriter.create(loc, conArr, alloc); + return success(); + } + rewriter.eraseOp(alloc); + return success(); + } + + // Determine if \p alloc is a legit candidate for promotion to a constant + // array value. \p scoreboard is a vector of store operations. Each element of + // the allocated array must be written to exactly 1 time, and the scoreboard + // is used to track these stores. \p dom is the dominance info for this + // function (to ensure the stores happen before uses). + static bool isGoodCandidate(cudaq::cc::AllocaOp alloc, + SmallVectorImpl &scoreboard, + DominanceInfo &dom) { + if (alloc.getSeqSize()) + return false; + auto arrTy = dyn_cast(alloc.getElementType()); + if (!arrTy || arrTy.isUnknownSize()) + return false; + auto arrEleTy = arrTy.getElementType(); + if (!isa(arrEleTy)) + return false; + + // There must be at least `size` uses to initialize the entire array. 
+ auto size = arrTy.getSize(); + if (std::distance(alloc->getUses().begin(), alloc->getUses().end()) < size) + return false; + + // Keep a scoreboard for every element in the array. Every element *must* be + // stored to with a constant exactly one time. + scoreboard.resize(size); + for (int i = 0; i < size; i++) + scoreboard[i] = nullptr; + + SmallVector toGlobalUses; + SmallVector> loadSets(size); + + auto getWriteOp = [&](auto op, std::int32_t index) -> Operation * { + Operation *theStore = nullptr; + for (auto &use : op->getUses()) { + Operation *u = use.getOwner(); + if (!u) + return nullptr; + + if (auto store = dyn_cast(u)) { + if (op.getOperation() == store.getPtrvalue().getDefiningOp()) { + if (theStore) { + LLVM_DEBUG(llvm::dbgs() + << "more than 1 store to element of array\n"); + return nullptr; + } + LLVM_DEBUG(llvm::dbgs() << "found store: " << store << "\n"); + theStore = u; + } + continue; + } + if (isa(u)) { + toGlobalUses.push_back(u); + continue; + } + if (isa(u)) { + loadSets[index].insert(u); + continue; + } + return nullptr; + } + return theStore && + isa_and_present( + dyn_cast(theStore) + .getValue() + .getDefiningOp()) + ? theStore + : nullptr; + }; + + auto unsizedArrTy = cudaq::cc::ArrayType::get(arrEleTy); + auto ptrUnsizedArrTy = cudaq::cc::PointerType::get(unsizedArrTy); + auto ptrArrEleTy = cudaq::cc::PointerType::get(arrEleTy); + for (auto &use : alloc->getUses()) { + // All uses *must* be a degenerate cc.cast, cc.compute_ptr, or + // cc.init_state. + auto *op = use.getOwner(); + if (!op) { + LLVM_DEBUG(llvm::dbgs() << "use was not an op\n"); + return false; + } + if (auto cptr = dyn_cast(op)) { + if (auto index = cptr.getConstantIndex(0)) + if (auto w = getWriteOp(cptr, *index)) + if (!scoreboard[*index]) { + scoreboard[*index] = w; + continue; + } + return false; + } + if (auto cast = dyn_cast(op)) { + // Process casts that are used in store ops. 
+ if (cast.getType() == ptrArrEleTy) { + if (auto w = getWriteOp(cast, 0)) + if (!scoreboard[0]) { + scoreboard[0] = w; + continue; + } + return false; + } + // Process casts that are used in quake.init_state. + if (cast.getType() == ptrUnsizedArrTy) { + if (cast->hasOneUse()) { + auto &use = *cast->getUses().begin(); + Operation *u = use.getOwner(); + if (isa_and_present(u)) { + toGlobalUses.push_back(op); + continue; + } + } + return false; + } + LLVM_DEBUG(llvm::dbgs() << "unexpected cast: " << *op << '\n'); + toGlobalUses.push_back(op); + continue; + } + if (isa(op)) { + toGlobalUses.push_back(op); + continue; + } + LLVM_DEBUG(llvm::dbgs() << "unexpected use: " << *op << '\n'); + toGlobalUses.push_back(op); + } + + bool ok = std::all_of(scoreboard.begin(), scoreboard.end(), + [](bool b) { return b; }); + LLVM_DEBUG(llvm::dbgs() << "all elements of array are set: " << ok << '\n'); + if (ok) { + // Verify dominance relations. + + // For all stores, the store of an element $e$ must dominate all loads of + // $e$. + for (int i = 0; i < size; ++i) { + for (auto *load : loadSets[i]) + if (!dom.dominates(scoreboard[i], load)) { + LLVM_DEBUG(llvm::dbgs() + << "store " << scoreboard[i] + << " doesn't dominate load: " << *load << '\n'); + return false; + } + } + + // For all global uses, all of the stores must dominate every use. 
+ for (auto *glob : toGlobalUses) { + for (auto *store : scoreboard) + if (!dom.dominates(store, glob)) { + LLVM_DEBUG(llvm::dbgs() + << "store " << store << " doesn't dominate op: " << *glob + << '\n'); + return false; + } + } + } + return ok; + } + + DominanceInfo &dom; + StringRef funcName; +}; +} // namespace diff --git a/lib/Optimizer/Transforms/LoopNormalize.cpp b/lib/Optimizer/Transforms/LoopNormalize.cpp index faf0b3ea648..2e15bbb069d 100644 --- a/lib/Optimizer/Transforms/LoopNormalize.cpp +++ b/lib/Optimizer/Transforms/LoopNormalize.cpp @@ -23,145 +23,9 @@ namespace cudaq::opt { using namespace mlir; -// Return true if \p loop is not monotonic or it is an invariant loop. -// Normalization is to be done on any loop that is monotonic and not -// invariant (which includes loops that are already in counted form). -static bool isNotMonotonicOrInvariant(cudaq::cc::LoopOp loop, - bool allowClosedInterval, - bool allowEarlyExit) { - cudaq::opt::LoopComponents c; - return !cudaq::opt::isaMonotonicLoop(loop, allowEarlyExit, &c) || - (cudaq::opt::isaInvariantLoop(c, allowClosedInterval) && - !c.isLinearExpr()); -} +#include "LoopNormalizePatterns.inc" namespace { -class LoopPat : public OpRewritePattern { -public: - explicit LoopPat(MLIRContext *ctx, bool aci, bool ab) - : OpRewritePattern(ctx), allowClosedInterval(aci), allowEarlyExit(ab) {} - - LogicalResult matchAndRewrite(cudaq::cc::LoopOp loop, - PatternRewriter &rewriter) const override { - if (loop->hasAttr(cudaq::opt::NormalizedLoopAttr)) - return failure(); - if (isNotMonotonicOrInvariant(loop, allowClosedInterval, allowEarlyExit)) - return failure(); - - // loop is monotonic but not invariant. - LLVM_DEBUG(llvm::dbgs() << "loop before normalization: " << loop << '\n'); - auto componentsOpt = cudaq::opt::getLoopComponents(loop); - assert(componentsOpt && "loop must have components"); - auto c = *componentsOpt; - auto loc = loop.getLoc(); - - // 1) Set initial value to 0. 
- auto ty = c.initialValue.getType(); - rewriter.startRootUpdate(loop); - auto createConstantOp = [&](std::int64_t val) -> Value { - return rewriter.create(loc, val, ty); - }; - auto zero = createConstantOp(0); - loop->setOperand(c.induction, zero); - - // 2) Compute the number of iterations as an invariant. `iterations = max(0, - // (upper - lower + step) / step)`. - Value upper = c.compareValue; - auto one = createConstantOp(1); - Value step = c.stepValue; - Value lower = c.initialValue; - if (!c.stepIsAnAddOp()) - step = rewriter.create(loc, zero, step); - if (c.isLinearExpr()) { - // Induction is part of a linear expression. Deal with the terms of the - // equation. `m` scales the step. `b` is an addend to the lower bound. - if (c.addendValue) { - if (c.negatedAddend) { - // `m * i - b`, u += `b`. - upper = rewriter.create(loc, upper, c.addendValue); - } else { - // `m * i + b`, u -= `b`. - upper = rewriter.create(loc, upper, c.addendValue); - } - } - if (c.minusOneMult) { - // `b - m * i` (b eliminated), multiply lower and step by `-1` (`m` - // follows). - auto negOne = createConstantOp(-1); - lower = rewriter.create(loc, lower, negOne); - step = rewriter.create(loc, step, negOne); - } - if (c.scaleValue) { - if (c.reciprocalScale) { - // `1/m * i + b` (b eliminated), multiply upper by `m`. - upper = rewriter.create(loc, upper, c.scaleValue); - } else { - // `m * i + b` (b eliminated), multiple lower and step by `m`. - lower = rewriter.create(loc, lower, c.scaleValue); - step = rewriter.create(loc, step, c.scaleValue); - } - } - } - if (!c.isClosedIntervalForm()) { - // Note: treating the step as a signed value to process countdown loops as - // well as countup loops. 
- Value negStepCond = rewriter.create( - loc, arith::CmpIPredicate::slt, step, zero); - auto negOne = createConstantOp(-1); - Value adj = - rewriter.create(loc, ty, negStepCond, negOne, one); - upper = rewriter.create(loc, upper, adj); - } - Value diff = rewriter.create(loc, upper, lower); - Value disp = rewriter.create(loc, diff, step); - auto cmpOp = cast(c.compareOp); - Value up1 = rewriter.create(loc, disp, step); - Value noLoopCond = rewriter.create( - loc, arith::CmpIPredicate::sgt, up1, zero); - Value newUpper = - rewriter.create(loc, ty, noLoopCond, up1, zero); - - // 3) Rewrite the comparison (!=) and step operations (+1). - Value v1 = c.getCompareInduction(); - rewriter.setInsertionPoint(cmpOp); - Value newCmp = rewriter.create( - cmpOp.getLoc(), arith::CmpIPredicate::ne, v1, newUpper); - cmpOp->replaceAllUsesWith(ValueRange{newCmp}); - auto v2 = c.stepOp->getOperand( - c.stepIsAnAddOp() && c.shouldCommuteStepOp() ? 1 : 0); - rewriter.setInsertionPoint(c.stepOp); - auto newStep = rewriter.create(c.stepOp->getLoc(), v2, one); - c.stepOp->replaceAllUsesWith(ValueRange{newStep.getResult()}); - - // 4) Compute original induction value as a loop variant and replace the - // uses. `lower + step * i`. Careful to not replace the new induction. 
- if (!loop.getBodyRegion().empty()) { - Block *entry = &loop.getBodyRegion().front(); - rewriter.setInsertionPointToStart(entry); - Value induct = entry->getArgument(c.induction); - auto mul = rewriter.create(loc, induct, c.stepValue); - Value newInd; - if (c.stepIsAnAddOp()) - newInd = rewriter.create(loc, c.initialValue, mul); - else - newInd = rewriter.create(loc, c.initialValue, mul); - induct.replaceUsesWithIf(newInd, [&](OpOperand &opnd) { - auto *op = opnd.getOwner(); - return op != newStep.getOperation() && op != mul && - !isa(op); - }); - } - loop->setAttr(cudaq::opt::NormalizedLoopAttr, rewriter.getUnitAttr()); - - rewriter.finalizeRootUpdate(loop); - LLVM_DEBUG(llvm::dbgs() << "loop after normalization: " << loop << '\n'); - return success(); - } - - bool allowClosedInterval; - bool allowEarlyExit; -}; - class LoopNormalizePass : public cudaq::opt::impl::LoopNormalizeBase { public: diff --git a/lib/Optimizer/Transforms/LoopNormalizePatterns.inc b/lib/Optimizer/Transforms/LoopNormalizePatterns.inc new file mode 100644 index 00000000000..3191d877a4b --- /dev/null +++ b/lib/Optimizer/Transforms/LoopNormalizePatterns.inc @@ -0,0 +1,153 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// These loop normalization patterns are used by the cc-loop-normalize pass +// and cc-loop-unroll pass + +// This file must be included after a `using namespace mlir;` as it uses bare +// identifiers from that namespace. + +// Return true if \p loop is not monotonic or it is an invariant loop. 
+// Normalization is to be done on any loop that is monotonic and not +// invariant (which includes loops that are already in counted form). +static bool isNotMonotonicOrInvariant(cudaq::cc::LoopOp loop, + bool allowClosedInterval, + bool allowEarlyExit) { + cudaq::opt::LoopComponents c; + return !cudaq::opt::isaMonotonicLoop(loop, allowEarlyExit, &c) || + (cudaq::opt::isaInvariantLoop(c, allowClosedInterval) && + !c.isLinearExpr()); +} + +namespace { +class LoopPat : public OpRewritePattern { +public: + explicit LoopPat(MLIRContext *ctx, bool aci, bool ab) + : OpRewritePattern(ctx), allowClosedInterval(aci), allowEarlyExit(ab) {} + + LogicalResult matchAndRewrite(cudaq::cc::LoopOp loop, + PatternRewriter &rewriter) const override { + if (loop->hasAttr(cudaq::opt::NormalizedLoopAttr)) + return failure(); + if (isNotMonotonicOrInvariant(loop, allowClosedInterval, allowEarlyExit)) + return failure(); + + // loop is monotonic but not invariant. + LLVM_DEBUG(llvm::dbgs() << "loop before normalization: " << loop << '\n'); + auto componentsOpt = cudaq::opt::getLoopComponents(loop); + assert(componentsOpt && "loop must have components"); + auto c = *componentsOpt; + auto loc = loop.getLoc(); + + // 1) Set initial value to 0. + auto ty = c.initialValue.getType(); + rewriter.startRootUpdate(loop); + auto createConstantOp = [&](std::int64_t val) -> Value { + return rewriter.create(loc, val, ty); + }; + auto zero = createConstantOp(0); + loop->setOperand(c.induction, zero); + + // 2) Compute the number of iterations as an invariant. `iterations = max(0, + // (upper - lower + step) / step)`. + Value upper = c.compareValue; + auto one = createConstantOp(1); + Value step = c.stepValue; + Value lower = c.initialValue; + if (!c.stepIsAnAddOp()) + step = rewriter.create(loc, zero, step); + if (c.isLinearExpr()) { + // Induction is part of a linear expression. Deal with the terms of the + // equation. `m` scales the step. `b` is an addend to the lower bound. 
+ if (c.addendValue) { + if (c.negatedAddend) { + // `m * i - b`, u += `b`. + upper = rewriter.create(loc, upper, c.addendValue); + } else { + // `m * i + b`, u -= `b`. + upper = rewriter.create(loc, upper, c.addendValue); + } + } + if (c.minusOneMult) { + // `b - m * i` (b eliminated), multiply lower and step by `-1` (`m` + // follows). + auto negOne = createConstantOp(-1); + lower = rewriter.create(loc, lower, negOne); + step = rewriter.create(loc, step, negOne); + } + if (c.scaleValue) { + if (c.reciprocalScale) { + // `1/m * i + b` (b eliminated), multiply upper by `m`. + upper = rewriter.create(loc, upper, c.scaleValue); + } else { + // `m * i + b` (b eliminated), multiply lower and step by `m`. + lower = rewriter.create(loc, lower, c.scaleValue); + step = rewriter.create(loc, step, c.scaleValue); + } + } + } + if (!c.isClosedIntervalForm()) { + // Note: treating the step as a signed value to process countdown loops as + // well as countup loops. + Value negStepCond = rewriter.create( + loc, arith::CmpIPredicate::slt, step, zero); + auto negOne = createConstantOp(-1); + Value adj = + rewriter.create(loc, ty, negStepCond, negOne, one); + upper = rewriter.create(loc, upper, adj); + } + Value diff = rewriter.create(loc, upper, lower); + Value disp = rewriter.create(loc, diff, step); + auto cmpOp = cast(c.compareOp); + Value up1 = rewriter.create(loc, disp, step); + Value noLoopCond = rewriter.create( + loc, arith::CmpIPredicate::sgt, up1, zero); + Value newUpper = + rewriter.create(loc, ty, noLoopCond, up1, zero); + + // 3) Rewrite the comparison (!=) and step operations (+1). + Value v1 = c.getCompareInduction(); + rewriter.setInsertionPoint(cmpOp); + Value newCmp = rewriter.create( + cmpOp.getLoc(), arith::CmpIPredicate::ne, v1, newUpper); + cmpOp->replaceAllUsesWith(ValueRange{newCmp}); + auto v2 = c.stepOp->getOperand( + c.stepIsAnAddOp() && c.shouldCommuteStepOp() ? 
1 : 0); + rewriter.setInsertionPoint(c.stepOp); + auto newStep = rewriter.create(c.stepOp->getLoc(), v2, one); + c.stepOp->replaceAllUsesWith(ValueRange{newStep.getResult()}); + + // 4) Compute original induction value as a loop variant and replace the + // uses. `lower + step * i`. Careful to not replace the new induction. + if (!loop.getBodyRegion().empty()) { + Block *entry = &loop.getBodyRegion().front(); + rewriter.setInsertionPointToStart(entry); + Value induct = entry->getArgument(c.induction); + auto mul = rewriter.create(loc, induct, c.stepValue); + Value newInd; + if (c.stepIsAnAddOp()) + newInd = rewriter.create(loc, c.initialValue, mul); + else + newInd = rewriter.create(loc, c.initialValue, mul); + induct.replaceUsesWithIf(newInd, [&](OpOperand &opnd) { + auto *op = opnd.getOwner(); + return op != newStep.getOperation() && op != mul && + !isa(op); + }); + } + loop->setAttr(cudaq::opt::NormalizedLoopAttr, rewriter.getUnitAttr()); + + rewriter.finalizeRootUpdate(loop); + LLVM_DEBUG(llvm::dbgs() << "loop after normalization: " << loop << '\n'); + return success(); + } + + bool allowClosedInterval; + bool allowEarlyExit; +}; +} // namespace diff --git a/lib/Optimizer/Transforms/LoopUnroll.cpp b/lib/Optimizer/Transforms/LoopUnroll.cpp index b7f0053f263..bc6f06b6b7d 100644 --- a/lib/Optimizer/Transforms/LoopUnroll.cpp +++ b/lib/Optimizer/Transforms/LoopUnroll.cpp @@ -9,9 +9,11 @@ #include "LoopAnalysis.h" #include "PassDetails.h" #include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" namespace cudaq::opt { #define GEN_PASS_DEF_LOOPUNROLL @@ -23,6 +25,11 @@ namespace cudaq::opt { using namespace mlir; +#include "LiftArrayAllocPatterns.inc" +#include "LoopNormalizePatterns.inc" +#include "LowerToCFGPatterns.inc" +#include "WriteAfterWriteEliminationPatterns.inc" + inline 
std::pair findCloneRange(Block *first, Block *last) { return {first->getNextNode(), last->getPrevNode()}; } @@ -144,7 +151,7 @@ struct UnrollCountedLoop : public OpRewritePattern { contBlock = rewriter.createBlock(endBlock, argTys, argLocs); } // Replace any continue and (possibly) break ops in the body region. They - // are repalced with branches to the continue block or exit block, resp. + // are replaced with branches to the continue block or exit block, resp. for (Block *b = cloneRange.first; b != contBlock; b = b->getNextNode()) { auto *term = b->getTerminator(); if (auto cont = dyn_cast(term)) { @@ -230,6 +237,8 @@ class LoopUnrollPass : public cudaq::opt::impl::LoopUnrollBase { void runOnOperation() override { auto *ctx = &getContext(); auto *op = getOperation(); + DominanceInfo domInfo(op); + auto func = dyn_cast(op); auto numLoops = countLoopOps(op); unsigned progress = 0; if (numLoops) { @@ -238,14 +247,32 @@ class LoopUnrollPass : public cudaq::opt::impl::LoopUnrollBase { dialect->getCanonicalizationPatterns(patterns); for (RegisteredOperationName op : ctx->getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, ctx); + + // Add patterns that help const prop loop boundaries computed + // in conditional statements, other loops, or arrays. + patterns.insert(ctx, /*rewriteOnlyIfConst=*/true); + patterns.insert(ctx, allowClosedInterval, allowBreak); + patterns.insert( + ctx, domInfo, func == nullptr ? "unknown" : func.getName()); patterns.insert(ctx, threshold, /*signalFailure=*/false, allowBreak, progress); + FrozenRewritePatternSet frozen(std::move(patterns)); // Iterate over the loops until a fixed-point is reached. Some loops can // only be unrolled if other loops are unrolled first and the constants // iteratively propagated. do { + // Remove overridden writes. + auto analysis = SimplifyWritesAnalysis(domInfo, op); + analysis.removeOverriddenStores(); + // Clean up dead code. 
+ { + auto builder = OpBuilder(op); + IRRewriter rewriter(builder); + [[maybe_unused]] auto unused = + simplifyRegions(rewriter, op->getRegions()); + } progress = 0; (void)applyPatternsAndFoldGreedily(op, frozen); } while (progress); @@ -343,12 +370,11 @@ static void createUnrollingPipeline(OpPassManager &pm, unsigned threshold, bool signalFailure, bool allowBreak, bool allowClosedInterval) { pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); pm.addNestedPass(cudaq::opt::createClassicalMemToReg()); - pm.addNestedPass(createCanonicalizerPass()); - cudaq::opt::LoopNormalizeOptions lno{allowClosedInterval, allowBreak}; - pm.addNestedPass(cudaq::opt::createLoopNormalize(lno)); - pm.addNestedPass(createCanonicalizerPass()); - cudaq::opt::LoopUnrollOptions luo{threshold, signalFailure, allowBreak}; + cudaq::opt::LoopUnrollOptions luo{threshold, signalFailure, + allowClosedInterval, allowBreak}; + // TODO: run cse as a part of cc-loop-unroll when we update the llvm version. 
pm.addNestedPass(cudaq::opt::createLoopUnroll(luo)); pm.addNestedPass(cudaq::opt::createUpdateRegisterNames()); } diff --git a/lib/Optimizer/Transforms/LowerToCFG.cpp b/lib/Optimizer/Transforms/LowerToCFG.cpp index fb050fabffb..6431153542a 100644 --- a/lib/Optimizer/Transforms/LowerToCFG.cpp +++ b/lib/Optimizer/Transforms/LowerToCFG.cpp @@ -20,6 +20,8 @@ using namespace mlir; +#include "LowerToCFGPatterns.inc" + namespace { class RewriteScope : public OpRewritePattern { public: @@ -278,80 +280,6 @@ class RewriteLoop : public OpRewritePattern { } }; -class RewriteIf : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - /// Rewrites an if construct like - /// ```mlir - /// (0) - /// quake.if %cond { - /// (1) - /// } else { - /// (2) - /// } - /// (3) - /// ``` - /// to a CFG like - /// ```mlir - /// (0) - /// cf.cond_br %cond, ^bb1, ^bb2 - /// ^bb1: - /// (1) - /// cf.br ^bb3 - /// ^bb2: - /// (2) - /// cf.br ^bb3 - /// ^bb3: - /// (3) - /// ``` - LogicalResult matchAndRewrite(cudaq::cc::IfOp ifOp, - PatternRewriter &rewriter) const override { - auto loc = ifOp.getLoc(); - auto *initBlock = rewriter.getInsertionBlock(); - auto initPos = rewriter.getInsertionPoint(); - auto *endBlock = rewriter.splitBlock(initBlock, initPos); - if (ifOp.getNumResults() != 0) { - Block *continueBlock = rewriter.createBlock( - endBlock, ifOp.getResultTypes(), - SmallVector(ifOp.getNumResults(), loc)); - rewriter.create(loc, endBlock); - endBlock = continueBlock; - } - auto *thenBlock = &ifOp.getThenRegion().front(); - bool hasElse = !ifOp.getElseRegion().empty(); - auto *elseBlock = hasElse ? 
&ifOp.getElseRegion().front() : endBlock; - updateBodyBranches(&ifOp.getThenRegion(), rewriter, endBlock); - updateBodyBranches(&ifOp.getElseRegion(), rewriter, endBlock); - rewriter.inlineRegionBefore(ifOp.getThenRegion(), endBlock); - if (hasElse) - rewriter.inlineRegionBefore(ifOp.getElseRegion(), endBlock); - rewriter.setInsertionPointToEnd(initBlock); - rewriter.create(loc, ifOp.getCondition(), thenBlock, - ifOp.getLinearArgs(), elseBlock, - ifOp.getLinearArgs()); - rewriter.replaceOp(ifOp, endBlock->getArguments()); - return success(); - } - - // Replace all the ContinueOp in the body region with branches to the correct - // basic blocks. - void updateBodyBranches(Region *bodyRegion, PatternRewriter &rewriter, - Block *continueBlock) const { - // Walk body region and replace all continue and break ops. - for (Block &block : *bodyRegion) { - auto *terminator = block.getTerminator(); - if (auto cont = dyn_cast(terminator)) { - rewriter.setInsertionPointToEnd(&block); - LLVM_DEBUG(llvm::dbgs() << "replacing " << *terminator << '\n'); - rewriter.replaceOpWithNewOp(cont, continueBlock, - cont.getOperands()); - } - // Other ad-hoc control flow in the region need not be rewritten. - } - } -}; - class RewriteReturn : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; diff --git a/lib/Optimizer/Transforms/LowerToCFGPatterns.inc b/lib/Optimizer/Transforms/LowerToCFGPatterns.inc new file mode 100644 index 00000000000..ae11f455346 --- /dev/null +++ b/lib/Optimizer/Transforms/LowerToCFGPatterns.inc @@ -0,0 +1,104 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ +// These patterns are used by the lower-to-cfg pass and cc-loop-unroll pass. + +// This file must be included after a `using namespace mlir;` as it uses bare +// identifiers from that namespace. + +namespace { +class RewriteIf : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit RewriteIf(MLIRContext *ctx) + : OpRewritePattern(ctx), rewriteOnlyIfConst(false) {} + + RewriteIf(MLIRContext *ctx, bool rewriteOnlyIfConst) + : OpRewritePattern(ctx), rewriteOnlyIfConst(rewriteOnlyIfConst) {} + + /// Rewrites an if construct like + /// ```mlir + /// (0) + /// quake.if %cond { + /// (1) + /// } else { + /// (2) + /// } + /// (3) + /// ``` + /// to a CFG like + /// ```mlir + /// (0) + /// cf.cond_br %cond, ^bb1, ^bb2 + /// ^bb1: + /// (1) + /// cf.br ^bb3 + /// ^bb2: + /// (2) + /// cf.br ^bb3 + /// ^bb3: + /// (3) + /// ``` + LogicalResult matchAndRewrite(cudaq::cc::IfOp ifOp, + PatternRewriter &rewriter) const override { + // Bail out on non-constant conditions if we just need to + // const-prop if($const). + if (rewriteOnlyIfConst) { + auto cond = ifOp.getCondition(); + if (!isa_and_present(cond.getDefiningOp())) + return failure(); + } + + auto loc = ifOp.getLoc(); + auto *initBlock = rewriter.getInsertionBlock(); + auto initPos = rewriter.getInsertionPoint(); + auto *endBlock = rewriter.splitBlock(initBlock, initPos); + if (ifOp.getNumResults() != 0) { + Block *continueBlock = rewriter.createBlock( + endBlock, ifOp.getResultTypes(), + SmallVector(ifOp.getNumResults(), loc)); + rewriter.create(loc, endBlock); + endBlock = continueBlock; + } + auto *thenBlock = &ifOp.getThenRegion().front(); + bool hasElse = !ifOp.getElseRegion().empty(); + auto *elseBlock = hasElse ? 
&ifOp.getElseRegion().front() : endBlock; + updateBodyBranches(&ifOp.getThenRegion(), rewriter, endBlock); + updateBodyBranches(&ifOp.getElseRegion(), rewriter, endBlock); + rewriter.inlineRegionBefore(ifOp.getThenRegion(), endBlock); + if (hasElse) + rewriter.inlineRegionBefore(ifOp.getElseRegion(), endBlock); + rewriter.setInsertionPointToEnd(initBlock); + rewriter.create(loc, ifOp.getCondition(), thenBlock, + ifOp.getLinearArgs(), elseBlock, + ifOp.getLinearArgs()); + rewriter.replaceOp(ifOp, endBlock->getArguments()); + return success(); + } + + // Replace all the ContinueOp in the body region with branches to the correct + // basic blocks. + void updateBodyBranches(Region *bodyRegion, PatternRewriter &rewriter, + Block *continueBlock) const { + // Walk body region and replace all continue and break ops. + for (Block &block : *bodyRegion) { + auto *terminator = block.getTerminator(); + if (auto cont = dyn_cast(terminator)) { + rewriter.setInsertionPointToEnd(&block); + LLVM_DEBUG(llvm::dbgs() << "replacing " << *terminator << '\n'); + rewriter.replaceOpWithNewOp(cont, continueBlock, + cont.getOperands()); + } + // Other ad-hoc control flow in the region need not be rewritten. + } + } + +private: + bool rewriteOnlyIfConst; +}; +} // namespace diff --git a/lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp b/lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp index 6ab5f2272ab..9f9b1d571d9 100644 --- a/lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp +++ b/lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp @@ -27,121 +27,9 @@ namespace cudaq::opt { using namespace mlir; -namespace { -/// Remove stores followed by a store to the same pointer -/// if the pointer is not used in between. 
-/// ``` -/// cc.store %c0_i64, %1 : !cc.ptr -/// // no use of %1 until next line -/// cc.store %0, %1 : !cc.ptr -/// ─────────────────────────────────────────── -/// cc.store %0, %1 : !cc.ptr -/// ``` -class SimplifyWritesAnalysis { -public: - SimplifyWritesAnalysis(DominanceInfo &di, Operation *op) : dom(di) { - for (auto &region : op->getRegions()) - for (auto &b : region) - collectBlockInfo(&b); - } - - /// Remove stores followed by a store to the same pointer if the pointer is - /// not used in between, using collected block info. - void removeOverriddenStores() { - SmallVector toErase; - - for (const auto &[block, ptrToStores] : blockInfo) { - for (const auto &[ptr, stores] : ptrToStores) { - if (stores.size() > 1) { - auto replacement = stores.back(); - for (auto *store : stores) { - if (isReplacement(ptr, store, replacement)) { - LLVM_DEBUG(llvm::dbgs() << "replacing store " << *store - << " by: " << *replacement << '\n'); - toErase.push_back(store); - } - } - } - } - } - - for (auto *op : toErase) - op->erase(); - } - -private: - /// Detect if value is used in the op or its nested blocks. - bool isReplacement(Operation *ptr, Operation *store, - Operation *replacement) const { - if (store == replacement) - return false; - - // Check that there are no non-store uses dominated by the store and - // not dominated by the replacement, i.e. only uses between the two - // stores are other stores to the same pointer. - for (auto *user : ptr->getUsers()) { - if (user != store && user != replacement) { - if (!isStoreToPtr(user, ptr) && dom.dominates(store, user) && - !dom.dominates(replacement, user)) { - LLVM_DEBUG(llvm::dbgs() << "store " << replacement - << " is used before: " << store << '\n'); - return false; - } - } - } - return true; - } - - /// Detects a store to the pointer. 
- static bool isStoreToPtr(Operation *op, Operation *ptr) { - return isa_and_present(op) && - (dyn_cast(op).getPtrvalue().getDefiningOp() == - ptr); - } - - /// Collect all stores to a pointer for a block. - void collectBlockInfo(Block *block) { - for (auto &op : *block) { - for (auto &region : op.getRegions()) - for (auto &b : region) - collectBlockInfo(&b); - - if (auto store = dyn_cast(&op)) { - auto ptr = store.getPtrvalue().getDefiningOp(); - if (isStoreToStack(store)) { - auto &[b, ptrToStores] = blockInfo.FindAndConstruct(block); - auto &[p, stores] = ptrToStores.FindAndConstruct(ptr); - stores.push_back(&op); - } - } - } - } - - /// Detect stores to stack locations, for example: - /// ``` - /// %1 = cc.alloca !cc.array - /// - /// %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr - /// cc.store %c0_i64, %2 : !cc.ptr - /// - /// %3 = cc.compute_ptr %1[1] : (!cc.ptr>) -> !cc.ptr - /// cc.store %c0_i64, %3 : !cc.ptr - /// ``` - static bool isStoreToStack(cudaq::cc::StoreOp store) { - auto ptrOp = store.getPtrvalue(); - if (auto cast = ptrOp.getDefiningOp()) - ptrOp = cast.getOperand(); - - if (auto computePtr = ptrOp.getDefiningOp()) - ptrOp = computePtr.getBase(); - - return isa_and_present(ptrOp.getDefiningOp()); - } - - DominanceInfo &dom; - DenseMap>> blockInfo; -}; +#include "WriteAfterWriteEliminationPatterns.inc" +namespace { class WriteAfterWriteEliminationPass : public cudaq::opt::impl::WriteAfterWriteEliminationBase< WriteAfterWriteEliminationPass> { diff --git a/lib/Optimizer/Transforms/WriteAfterWriteEliminationPatterns.inc b/lib/Optimizer/Transforms/WriteAfterWriteEliminationPatterns.inc new file mode 100644 index 00000000000..642a7a76636 --- /dev/null +++ b/lib/Optimizer/Transforms/WriteAfterWriteEliminationPatterns.inc @@ -0,0 +1,129 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. 
* + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// These patterns are used by the write-after-write-elimination and +// cc-loop-unroll passes. + +// This file must be included after a `using namespace mlir;` as it uses bare +// identifiers from that namespace. + +namespace { +/// Remove stores followed by a store to the same pointer +/// if the pointer is not used in between. +/// ``` +/// cc.store %c0_i64, %1 : !cc.ptr +/// // no use of %1 until next line +/// cc.store %0, %1 : !cc.ptr +/// ─────────────────────────────────────────── +/// cc.store %0, %1 : !cc.ptr +/// ``` +class SimplifyWritesAnalysis { +public: + SimplifyWritesAnalysis(DominanceInfo &di, Operation *op) : dom(di) { + for (auto &region : op->getRegions()) + for (auto &b : region) + collectBlockInfo(&b); + } + + /// Remove stores followed by a store to the same pointer if the pointer is + /// not used in between, using collected block info. + void removeOverriddenStores() { + SmallVector toErase; + + for (const auto &[block, ptrToStores] : blockInfo) { + for (const auto &[ptr, stores] : ptrToStores) { + if (stores.size() > 1) { + auto replacement = stores.back(); + for (auto *store : stores) { + if (isReplacement(ptr, store, replacement)) { + LLVM_DEBUG(llvm::dbgs() << "replacing store " << *store + << " by: " << *replacement << '\n'); + toErase.push_back(store); + } + } + } + } + } + + for (auto *op : toErase) + op->erase(); + } + +private: + /// Detect if value is used in the op or its nested blocks. + bool isReplacement(Operation *ptr, Operation *store, + Operation *replacement) const { + if (store == replacement) + return false; + + // Check that there are no non-store uses dominated by the store and + // not dominated by the replacement, i.e. 
only uses between the two + // stores are other stores to the same pointer. + for (auto *user : ptr->getUsers()) { + if (user != store && user != replacement) { + if (!isStoreToPtr(user, ptr) && dom.dominates(store, user) && + !dom.dominates(replacement, user)) { + LLVM_DEBUG(llvm::dbgs() << "store " << replacement + << " is used before: " << store << '\n'); + return false; + } + } + } + return true; + } + + /// Detects a store to the pointer. + static bool isStoreToPtr(Operation *op, Operation *ptr) { + return isa_and_present(op) && + (dyn_cast(op).getPtrvalue().getDefiningOp() == + ptr); + } + + /// Collect all stores to a pointer for a block. + void collectBlockInfo(Block *block) { + for (auto &op : *block) { + for (auto &region : op.getRegions()) + for (auto &b : region) + collectBlockInfo(&b); + + if (auto store = dyn_cast(&op)) { + auto ptr = store.getPtrvalue().getDefiningOp(); + if (isStoreToStack(store)) { + auto &[b, ptrToStores] = blockInfo.FindAndConstruct(block); + auto &[p, stores] = ptrToStores.FindAndConstruct(ptr); + stores.push_back(&op); + } + } + } + } + + /// Detect stores to stack locations, for example: + /// ``` + /// %1 = cc.alloca !cc.array + /// + /// %2 = cc.cast %1 : (!cc.ptr>) -> !cc.ptr + /// cc.store %c0_i64, %2 : !cc.ptr + /// + /// %3 = cc.compute_ptr %1[1] : (!cc.ptr>) -> !cc.ptr + /// cc.store %c0_i64, %3 : !cc.ptr + /// ``` + static bool isStoreToStack(cudaq::cc::StoreOp store) { + auto ptrOp = store.getPtrvalue(); + if (auto cast = ptrOp.getDefiningOp()) + ptrOp = cast.getOperand(); + + if (auto computePtr = ptrOp.getDefiningOp()) + ptrOp = computePtr.getBase(); + + return isa_and_present(ptrOp.getDefiningOp()); + } + + DominanceInfo &dom; + DenseMap>> blockInfo; +}; +} // namespace diff --git a/test/AST-Quake/loop_unroll-2.cpp b/test/AST-Quake/loop_unroll-2.cpp index 3a64741df1a..9a6e8221d26 100644 --- a/test/AST-Quake/loop_unroll-2.cpp +++ b/test/AST-Quake/loop_unroll-2.cpp @@ -48,7 +48,6 @@ struct test2 { struct test3 { void 
operator()(cudaq::qvector<> &q) __qpu__ { - // Do not expect this to unroll. Loop must be normalized. for (unsigned i = 7; i < 14; i += 3) { h(q[i]); x(q[i]); @@ -58,10 +57,15 @@ struct test3 { }; // CHECK-LABEL: func.func @__nvqpp__mlirgen__test3( -// CHECK: cc.loop while -// CHECK: } do { +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () // CHECK: quake.x %{{.*}} : (!quake.ref) -> () -// CHECK-NOT: quake.x %{{.*}} : (!quake.ref) -> () +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () +// CHECK: quake.h %{{.*}} : (!quake.ref) -> () // CHECK: return struct test4 { @@ -85,7 +89,6 @@ struct test4 { // CHECK: return struct test5 { - // Loop that decrements. Loop is not unrolled. It needs to be normalized. void operator()() __qpu__ { cudaq::qvector reg(1); for (size_t i = 3; i > 0; --i) @@ -95,14 +98,13 @@ struct test5 { }; // CHECK-LABEL: func.func @__nvqpp__mlirgen__test5( -// CHECK: cc.loop while -// CHECK: } do { +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () // CHECK: quake.x %{{.*}} : (!quake.ref) -> () // CHECK-NOT: quake.x %{{.*}} : (!quake.ref) -> () // CHECK: return struct test6 { - // Loop that decrements. Loop is not unrolled. It needs to be normalized. 
void operator()() __qpu__ { cudaq::qvector reg(1); for (size_t i = 3; i-- > 0;) @@ -112,8 +114,8 @@ struct test6 { }; // CHECK-LABEL: func.func @__nvqpp__mlirgen__test6( -// CHECK: cc.loop while -// CHECK: } do { +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () +// CHECK: quake.x %{{.*}} : (!quake.ref) -> () // CHECK: quake.x %{{.*}} : (!quake.ref) -> () // CHECK-NOT: quake.x %{{.*}} : (!quake.ref) -> () // CHECK: return diff --git a/test/AST-Quake/qalloc_initialization.cpp b/test/AST-Quake/qalloc_initialization.cpp index 5cc30afcc8b..887c4664a8c 100644 --- a/test/AST-Quake/qalloc_initialization.cpp +++ b/test/AST-Quake/qalloc_initialization.cpp @@ -307,87 +307,57 @@ __qpu__ bool Peppermint() { // QIR: } // QIR-LABEL: define { i1*, i64 } @__nvqpp__mlirgen__Cherry() -// QIR: %[[VAL_0:.*]] = alloca [4 x { double, double }] -// QIR: %[[VAL_1:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0 -// QIR: %[[VAL_2:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0, i32 0 -// QIR: store double 0.000000e+00, double* %[[VAL_2]] -// QIR: %[[VAL_3:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0, i32 1 -// QIR: store double 1.000000e+00, double* %[[VAL_3]] -// QIR: %[[VAL_4:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 1, i32 0 -// QIR: store double 6.000000e-01, double* %[[VAL_4]] -// QIR: %[[VAL_5:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 1, i32 1 -// QIR: store double 4.000000e-01, double* %[[VAL_5]] -// QIR: %[[VAL_6:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 2, i32 0 -// QIR: store double 1.000000e+00, double* %[[VAL_6]] -// QIR: %[[VAL_7:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* 
%[[VAL_0]], i64 0, i64 2, i32 1 -// QIR: %[[VAL_8:.*]] = bitcast double* %[[VAL_7]] to i8* -// QIR: call void @llvm.memset.p0i8.i64(i8* {{.*}}%[[VAL_8]], i8 0, i64 24, i1 false) -// QIR: %[[VAL_9:.*]] = call %[[VAL_10:.*]]* @__quantum__rt__qubit_allocate_array_with_state_complex64(i64 2, { double, double }* nonnull %[[VAL_1]]) -// QIR: %[[VAL_11:.*]] = call %[[VAL_12:.*]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_10]]* %[[VAL_9]], i64 0) -// QIR: %[[VAL_13:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_11]] -// QIR: call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_13]]) -// QIR: %[[VAL_14:.*]] = call %[[VAL_12]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_10]]* %[[VAL_9]], i64 1) -// QIR: %[[VAL_15:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_14]] -// QIR: call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_15]]) -// QIR: %[[VAL_16:.*]] = call %[[VAL_17:.*]]* @__quantum__qis__mz(%[[VAL_12]]* %[[VAL_13]]) -// QIR: %[[VAL_18:.*]] = bitcast %[[VAL_17]]* %[[VAL_16]] to i1* -// QIR: %[[VAL_19:.*]] = load i1, i1* %[[VAL_18]] -// QIR: %[[VAL_20:.*]] = zext i1 %[[VAL_19]] to i8 -// QIR: %[[VAL_21:.*]] = call %[[VAL_17]]* @__quantum__qis__mz(%[[VAL_12]]* %[[VAL_15]]) -// QIR: %[[VAL_22:.*]] = bitcast %[[VAL_17]]* %[[VAL_21]] to i1* -// QIR: %[[VAL_23:.*]] = load i1, i1* %[[VAL_22]] -// QIR: %[[VAL_24:.*]] = zext i1 %[[VAL_23]] to i8 -// QIR: %[[VAL_25:.*]] = call dereferenceable_or_null(2) i8* @malloc(i64 2) -// QIR: store i8 %[[VAL_20]], i8* %[[VAL_25]] -// QIR: %[[VAL_26:.*]] = getelementptr inbounds i8, i8* %[[VAL_25]], i64 1 -// QIR: store i8 %[[VAL_24]], i8* %[[VAL_26]] -// QIR: %[[VAL_27:.*]] = bitcast i8* %[[VAL_25]] to i1* -// QIR: %[[VAL_28:.*]] = insertvalue { i1*, i64 } undef, i1* %[[VAL_27]], 0 -// QIR: %[[VAL_29:.*]] = insertvalue { i1*, i64 } %[[VAL_28]], i64 2, 1 -// QIR: call void @__quantum__rt__qubit_release_array(%[[VAL_10]]* %[[VAL_9]]) -// QIR: ret { i1*, i64 } %[[VAL_29]] +// QIR: %[[VAL_0:.*]] = tail call %[[VAL_1:.*]]* 
@__quantum__rt__qubit_allocate_array_with_state_complex64(i64 2, { double, double }* nonnull getelementptr inbounds ([4 x { double, double }], [4 x { double, double }]* @__nvqpp__mlirgen__Cherry.rodata_1, i64 0, i64 0)) +// QIR: %[[VAL_2:.*]] = tail call %[[VAL_3:.*]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 0) +// QIR: %[[VAL_4:.*]] = load %[[VAL_3]]*, %[[VAL_3]]** %[[VAL_2]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_3]]* %[[VAL_4]]) +// QIR: %[[VAL_5:.*]] = tail call %[[VAL_3]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 1) +// QIR: %[[VAL_6:.*]] = load %[[VAL_3]]*, %[[VAL_3]]** %[[VAL_5]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_3]]* %[[VAL_6]]) +// QIR: %[[VAL_7:.*]] = tail call %[[VAL_8:.*]]* @__quantum__qis__mz(%[[VAL_3]]* %[[VAL_4]]) +// QIR: %[[VAL_9:.*]] = bitcast %[[VAL_8]]* %[[VAL_7]] to i1* +// QIR: %[[VAL_10:.*]] = load i1, i1* %[[VAL_9]], align 1 +// QIR: %[[VAL_11:.*]] = zext i1 %[[VAL_10]] to i8 +// QIR: %[[VAL_12:.*]] = tail call %[[VAL_8]]* @__quantum__qis__mz(%[[VAL_3]]* %[[VAL_6]]) +// QIR: %[[VAL_13:.*]] = bitcast %[[VAL_8]]* %[[VAL_12]] to i1* +// QIR: %[[VAL_14:.*]] = load i1, i1* %[[VAL_13]], align 1 +// QIR: %[[VAL_15:.*]] = zext i1 %[[VAL_14]] to i8 +// QIR: %[[VAL_16:.*]] = tail call dereferenceable_or_null(2) i8* @malloc(i64 2) +// QIR: store i8 %[[VAL_11]], i8* %[[VAL_16]], align 1 +// QIR: %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_16]], i64 1 +// QIR: store i8 %[[VAL_15]], i8* %[[VAL_17]], align 1 +// QIR: %[[VAL_18:.*]] = bitcast i8* %[[VAL_16]] to i1* +// QIR: %[[VAL_19:.*]] = insertvalue { i1*, i64 } undef, i1* %[[VAL_18]], 0 +// QIR: %[[VAL_20:.*]] = insertvalue { i1*, i64 } %[[VAL_19]], i64 2, 1 +// QIR: tail call void @__quantum__rt__qubit_release_array(%[[VAL_1]]* %[[VAL_0]]) +// QIR: ret { i1*, i64 } %[[VAL_20]] // QIR: } // QIR-LABEL: define { i1*, i64 } @__nvqpp__mlirgen__MooseTracks() -// QIR: %[[VAL_0:.*]] = alloca [4 x { double, double }] 
-// QIR: %[[VAL_1:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0 -// QIR: %[[VAL_2:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0, i32 0 -// QIR: store double 0.000000e+00, double* %[[VAL_2]] -// QIR: %[[VAL_3:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 0, i32 1 -// QIR: store double 1.000000e+00, double* %[[VAL_3]] -// QIR: %[[VAL_4:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 1, i32 0 -// QIR: store double 7.500000e-01, double* %[[VAL_4]] -// QIR: %[[VAL_5:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 1, i32 1 -// QIR: store double 2.500000e-01, double* %[[VAL_5]] -// QIR: %[[VAL_6:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 2, i32 0 -// QIR: store double 1.000000e+00, double* %[[VAL_6]] -// QIR: %[[VAL_7:.*]] = getelementptr inbounds [4 x { double, double }], [4 x { double, double }]* %[[VAL_0]], i64 0, i64 2, i32 1 -// QIR: %[[VAL_8:.*]] = bitcast double* %[[VAL_7]] to i8* -// QIR: call void @llvm.memset.p0i8.i64(i8* {{.*}}%[[VAL_8]], i8 0, i64 24, i1 false) -// QIR: %[[VAL_9:.*]] = call %[[VAL_10:.*]]* @__quantum__rt__qubit_allocate_array_with_state_complex64(i64 2, { double, double }* nonnull %[[VAL_1]]) -// QIR: %[[VAL_11:.*]] = call %[[VAL_12:.*]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_10]]* %[[VAL_9]], i64 0) -// QIR: %[[VAL_13:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_11]] -// QIR: call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_13]]) -// QIR: %[[VAL_14:.*]] = call %[[VAL_12]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_10]]* %[[VAL_9]], i64 1) -// QIR: %[[VAL_15:.*]] = load %[[VAL_12]]*, %[[VAL_12]]** %[[VAL_14]] -// QIR: call void @__quantum__qis__h(%[[VAL_12]]* %[[VAL_15]]) 
-// QIR: %[[VAL_16:.*]] = call %[[VAL_17:.*]]* @__quantum__qis__mz(%[[VAL_12]]* %[[VAL_13]]) -// QIR: %[[VAL_18:.*]] = bitcast %[[VAL_17]]* %[[VAL_16]] to i1* -// QIR: %[[VAL_19:.*]] = load i1, i1* %[[VAL_18]] -// QIR: %[[VAL_20:.*]] = zext i1 %[[VAL_19]] to i8 -// QIR: %[[VAL_21:.*]] = call %[[VAL_17]]* @__quantum__qis__mz(%[[VAL_12]]* %[[VAL_15]]) -// QIR: %[[VAL_22:.*]] = bitcast %[[VAL_17]]* %[[VAL_21]] to i1* -// QIR: %[[VAL_23:.*]] = load i1, i1* %[[VAL_22]] -// QIR: %[[VAL_24:.*]] = zext i1 %[[VAL_23]] to i8 -// QIR: %[[VAL_25:.*]] = call dereferenceable_or_null(2) i8* @malloc(i64 2) -// QIR: store i8 %[[VAL_20]], i8* %[[VAL_25]] -// QIR: %[[VAL_26:.*]] = getelementptr inbounds i8, i8* %[[VAL_25]], i64 1 -// QIR: store i8 %[[VAL_24]], i8* %[[VAL_26]] -// QIR: %[[VAL_27:.*]] = bitcast i8* %[[VAL_25]] to i1* -// QIR: %[[VAL_28:.*]] = insertvalue { i1*, i64 } undef, i1* %[[VAL_27]], 0 -// QIR: %[[VAL_29:.*]] = insertvalue { i1*, i64 } %[[VAL_28]], i64 2, 1 -// QIR: call void @__quantum__rt__qubit_release_array(%[[VAL_10]]* %[[VAL_9]]) -// QIR: ret { i1*, i64 } %[[VAL_29]] +// QIR: %[[VAL_0:.*]] = tail call %[[VAL_1:.*]]* @__quantum__rt__qubit_allocate_array_with_state_complex64(i64 2, { double, double }* nonnull getelementptr inbounds ([4 x { double, double }], [4 x { double, double }]* @__nvqpp__mlirgen__MooseTracks.rodata_0, i64 0, i64 0)) +// QIR: %[[VAL_2:.*]] = tail call %[[VAL_3:.*]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 0) +// QIR: %[[VAL_4:.*]] = load %[[VAL_3]]*, %[[VAL_3]]** %[[VAL_2]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_3]]* %[[VAL_4]]) +// QIR: %[[VAL_5:.*]] = tail call %[[VAL_3]]** @__quantum__rt__array_get_element_ptr_1d(%[[VAL_1]]* %[[VAL_0]], i64 1) +// QIR: %[[VAL_6:.*]] = load %[[VAL_3]]*, %[[VAL_3]]** %[[VAL_5]], align 8 +// QIR: tail call void @__quantum__qis__h(%[[VAL_3]]* %[[VAL_6]]) +// QIR: %[[VAL_8:.*]] = tail call %[[VAL_7:.*]]* @__quantum__qis__mz(%[[VAL_3]]* %[[VAL_4]]) +// QIR: 
%[[VAL_9:.*]] = bitcast %[[VAL_7]]* %[[VAL_8]] to i1* +// QIR: %[[VAL_10:.*]] = load i1, i1* %[[VAL_9]], align 1 +// QIR: %[[VAL_11:.*]] = zext i1 %[[VAL_10]] to i8 +// QIR: %[[VAL_12:.*]] = tail call %[[VAL_7]]* @__quantum__qis__mz(%[[VAL_3]]* %[[VAL_6]]) +// QIR: %[[VAL_13:.*]] = bitcast %[[VAL_7]]* %[[VAL_12]] to i1* +// QIR: %[[VAL_14:.*]] = load i1, i1* %[[VAL_13]], align 1 +// QIR: %[[VAL_15:.*]] = zext i1 %[[VAL_14]] to i8 +// QIR: %[[VAL_16:.*]] = tail call dereferenceable_or_null(2) i8* @malloc(i64 2) +// QIR: store i8 %[[VAL_11]], i8* %[[VAL_16]], align 1 +// QIR: %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_16]], i64 1 +// QIR: store i8 %[[VAL_15]], i8* %[[VAL_17]], align 1 +// QIR: %[[VAL_18:.*]] = bitcast i8* %[[VAL_16]] to i1* +// QIR: %[[VAL_19:.*]] = insertvalue { i1*, i64 } undef, i1* %[[VAL_18]], 0 +// QIR: %[[VAL_20:.*]] = insertvalue { i1*, i64 } %[[VAL_19]], i64 2, 1 +// QIR: tail call void @__quantum__rt__qubit_release_array(%[[VAL_1]]* %[[VAL_0]]) +// QIR: ret { i1*, i64 } %[[VAL_20]] // QIR: } // QIR-LABEL: define { i1*, i64 } @__nvqpp__mlirgen__RockyRoad() diff --git a/test/Quake/loop_unroll.qke b/test/Quake/loop_unroll.qke new file mode 100644 index 00000000000..65d8e5cb696 --- /dev/null +++ b/test/Quake/loop_unroll.qke @@ -0,0 +1,233 @@ +// ========================================================================== // +// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. 
// +// ========================================================================== // + +// RUN: cudaq-opt -cc-loop-unroll -cse -cc-loop-unroll %s | FileCheck %s + +func.func @test_loop_unroll() { + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c0_i64 = arith.constant 0 : i64 + %0 = quake.alloca !quake.veq<6> + %1 = quake.extract_ref %0[0] : (!quake.veq<6>) -> !quake.ref + quake.x %1 : (!quake.ref) -> () + %2 = math.absi %c2_i64 : i64 + %3 = cc.alloca i64[%2 : i64] + %4:2 = cc.loop while ((%arg0 = %c0_i64, %arg1 = %c0_i64) -> (i64, i64)) { + %25 = arith.cmpi slt, %arg0, %c2_i64 : i64 + cc.condition %25(%arg0, %arg1 : i64, i64) + } do { + ^bb0(%arg0: i64, %arg1: i64): + %25 = cc.compute_ptr %3[%arg1] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %arg1, %25 : !cc.ptr + %26 = arith.addi %arg1, %c1_i64 : i64 + cc.continue %arg0, %26 : i64, i64 + } step { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25, %arg1 : i64, i64 + } {invariant} + %5 = cc.alloca i64[%c2_i64 : i64] + %6 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) { + %25 = arith.cmpi slt, %arg0, %c2_i64 : i64 + cc.condition %25(%arg0 : i64) + } do { + ^bb0(%arg0: i64): + %25 = cc.compute_ptr %3[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + %26 = cc.load %25 : !cc.ptr + %27 = arith.muli %26, %c2_i64 : i64 + %28 = cc.compute_ptr %5[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %27, %28 : !cc.ptr + cc.continue %arg0 : i64 + } step { + ^bb0(%arg0: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25 : i64 + } {invariant} + %7 = cc.stdvec_init %5, %c2_i64 : (!cc.ptr>, i64) -> !cc.stdvec + %8 = cc.stdvec_size %7 : (!cc.stdvec) -> i64 + %9 = arith.subi %8, %c1_i64 : i64 + %10:2 = cc.loop while ((%arg0 = %c0_i64, %arg1 = %c0_i64) -> (i64, i64)) { + %25 = arith.cmpi slt, %arg0, %9 : i64 + cc.condition %25(%arg0, %arg1 : i64, i64) + } do { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + %26:2 = cc.loop while ((%arg2 = %25, %arg3 = 
%arg1) -> (i64, i64)) { + %27 = arith.cmpi slt, %arg2, %8 : i64 + cc.condition %27(%arg2, %arg3 : i64, i64) + } do { + ^bb0(%arg2: i64, %arg3: i64): + %27 = arith.addi %arg3, %c1_i64 : i64 + cc.continue %arg2, %27 : i64, i64 + } step { + ^bb0(%arg2: i64, %arg3: i64): + %27 = arith.addi %arg2, %c1_i64 : i64 + cc.continue %27, %arg3 : i64, i64 + } {invariant} + cc.continue %arg0, %26#1 : i64, i64 + } step { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25, %arg1 : i64, i64 + } {invariant} + %11 = math.absi %10#1 : i64 + %12 = cc.alloca i64[%11 : i64] + %13:2 = cc.loop while ((%arg0 = %c0_i64, %arg1 = %c0_i64) -> (i64, i64)) { + %25 = arith.cmpi slt, %arg0, %10#1 : i64 + cc.condition %25(%arg0, %arg1 : i64, i64) + } do { + ^bb0(%arg0: i64, %arg1: i64): + %25 = cc.compute_ptr %12[%arg1] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %arg1, %25 : !cc.ptr + %26 = arith.addi %arg1, %c1_i64 : i64 + cc.continue %arg0, %26 : i64, i64 + } step { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25, %arg1 : i64, i64 + } {invariant} + %14 = cc.alloca i64[%10#1 : i64] + %15 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) { + %25 = arith.cmpi slt, %arg0, %10#1 : i64 + cc.condition %25(%arg0 : i64) + } do { + ^bb0(%arg0: i64): + %25 = cc.compute_ptr %14[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %c0_i64, %25 : !cc.ptr + cc.continue %arg0 : i64 + } step { + ^bb0(%arg0: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25 : i64 + } {invariant} + %16 = cc.stdvec_init %14, %10#1 : (!cc.ptr>, i64) -> !cc.stdvec + %17 = cc.alloca i64[%11 : i64] + %18:2 = cc.loop while ((%arg0 = %c0_i64, %arg1 = %c0_i64) -> (i64, i64)) { + %25 = arith.cmpi slt, %arg0, %10#1 : i64 + cc.condition %25(%arg0, %arg1 : i64, i64) + } do { + ^bb0(%arg0: i64, %arg1: i64): + %25 = cc.compute_ptr %17[%arg1] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %arg1, %25 : !cc.ptr + %26 = arith.addi %arg1, %c1_i64 : i64 + cc.continue %arg0, %26 : 
i64, i64 + } step { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25, %arg1 : i64, i64 + } {invariant} + %19 = cc.alloca i64[%10#1 : i64] + %20 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) { + %25 = arith.cmpi slt, %arg0, %10#1 : i64 + cc.condition %25(%arg0 : i64) + } do { + ^bb0(%arg0: i64): + %25 = cc.compute_ptr %19[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + cc.store %c0_i64, %25 : !cc.ptr + cc.continue %arg0 : i64 + } step { + ^bb0(%arg0: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25 : i64 + } {invariant} + %21 = cc.stdvec_init %19, %10#1 : (!cc.ptr>, i64) -> !cc.stdvec + %22:2 = cc.loop while ((%arg0 = %c0_i64, %arg1 = %c0_i64) -> (i64, i64)) { + %25 = arith.cmpi slt, %arg0, %9 : i64 + cc.condition %25(%arg0, %arg1 : i64, i64) + } do { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + %26:2 = cc.loop while ((%arg2 = %25, %arg3 = %arg1) -> (i64, i64)) { + %27 = arith.cmpi slt, %arg2, %8 : i64 + cc.condition %27(%arg2, %arg3 : i64, i64) + } do { + ^bb0(%arg2: i64, %arg3: i64): + %27 = cc.stdvec_data %16 : (!cc.stdvec) -> !cc.ptr> + %28 = cc.compute_ptr %27[%arg3] : (!cc.ptr>, i64) -> !cc.ptr + %29 = cc.stdvec_data %7 : (!cc.stdvec) -> !cc.ptr> + %30 = cc.compute_ptr %29[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + %31 = cc.load %30 : !cc.ptr + cc.store %31, %28 : !cc.ptr + %32 = cc.stdvec_data %21 : (!cc.stdvec) -> !cc.ptr> + %33 = cc.compute_ptr %32[%arg3] : (!cc.ptr>, i64) -> !cc.ptr + %34 = cc.compute_ptr %29[%arg2] : (!cc.ptr>, i64) -> !cc.ptr + %35 = cc.load %34 : !cc.ptr + cc.store %35, %33 : !cc.ptr + %36 = arith.addi %arg3, %c1_i64 : i64 + cc.continue %arg2, %36 : i64, i64 + } step { + ^bb0(%arg2: i64, %arg3: i64): + %27 = arith.addi %arg2, %c1_i64 : i64 + cc.continue %27, %arg3 : i64, i64 + } {invariant} + cc.continue %arg0, %26#1 : i64, i64 + } step { + ^bb0(%arg0: i64, %arg1: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25, %arg1 : i64, i64 + } {invariant} + %23 = 
cc.stdvec_size %16 : (!cc.stdvec) -> i64 + %24 = cc.loop while ((%arg0 = %c0_i64) -> (i64)) { + %25 = arith.cmpi slt, %arg0, %23 : i64 + cc.condition %25(%arg0 : i64) + } do { + ^bb0(%arg0: i64): + %25 = cc.stdvec_data %16 : (!cc.stdvec) -> !cc.ptr> + %26 = cc.compute_ptr %25[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + %27 = cc.load %26 : !cc.ptr + %28 = cc.stdvec_data %21 : (!cc.stdvec) -> !cc.ptr> + %29 = cc.compute_ptr %28[%arg0] : (!cc.ptr>, i64) -> !cc.ptr + %30 = cc.load %29 : !cc.ptr + %31 = arith.cmpi slt, %27, %30 : i64 + %32:2 = cc.if(%31) -> (i64, i64) { + cc.continue %27, %30 : i64, i64 + } else { + %34 = arith.cmpi sgt, %27, %30 : i64 + %35:2 = cc.if(%34) -> (i64, i64) { + cc.continue %30, %27 : i64, i64 + } else { + cc.continue %c0_i64, %c0_i64 : i64, i64 + } + cc.continue %35#0, %35#1 : i64, i64 + } + %33 = cc.loop while ((%arg1 = %32#0) -> (i64)) { + %34 = arith.cmpi slt, %arg1, %32#1 : i64 + cc.condition %34(%arg1 : i64) + } do { + ^bb0(%arg1: i64): + %34 = quake.extract_ref %0[%arg1] : (!quake.veq<6>, i64) -> !quake.ref + %35 = arith.addi %arg1, %c1_i64 : i64 + %36 = quake.extract_ref %0[%35] : (!quake.veq<6>, i64) -> !quake.ref + quake.x [%34] %36 : (!quake.ref, !quake.ref) -> () + cc.continue %arg1 : i64 + } step { + ^bb0(%arg1: i64): + %34 = arith.addi %arg1, %c1_i64 : i64 + cc.continue %34 : i64 + } {invariant} + cc.continue %arg0 : i64 + } step { + ^bb0(%arg0: i64): + %25 = arith.addi %arg0, %c1_i64 : i64 + cc.continue %25 : i64 + } {invariant} + return +} + +// CHECK-LABEL: func.func @test_loop_unroll() { +// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<6> +// CHECK: %[[VAL_1:.*]] = quake.extract_ref %[[VAL_0]][0] : (!quake.veq<6>) -> !quake.ref +// CHECK: quake.x %[[VAL_1]] : (!quake.ref) -> () +// CHECK: %[[VAL_2:.*]] = quake.extract_ref %[[VAL_0]][0] : (!quake.veq<6>) -> !quake.ref +// CHECK: %[[VAL_3:.*]] = quake.extract_ref %[[VAL_0]][1] : (!quake.veq<6>) -> !quake.ref +// CHECK: quake.x [%[[VAL_2]]] %[[VAL_3]] : (!quake.ref, !quake.ref) -> 
() +// CHECK: %[[VAL_4:.*]] = quake.extract_ref %[[VAL_0]][1] : (!quake.veq<6>) -> !quake.ref +// CHECK: %[[VAL_5:.*]] = quake.extract_ref %[[VAL_0]][2] : (!quake.veq<6>) -> !quake.ref +// CHECK: quake.x [%[[VAL_4]]] %[[VAL_5]] : (!quake.ref, !quake.ref) -> () +// CHECK: return +// CHECK: }