Skip to content

Commit 130ba15

Browse files
Merge branch 'main' into test-publishing
2 parents db285ea + a883ace commit 130ba15

File tree

19 files changed

+880
-67
lines changed

19 files changed

+880
-67
lines changed

include/cudaq/Optimizer/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,4 +1398,15 @@ def WriteAfterWriteElimination : Pass<"write-after-write-elimination"> {
13981398
}];
13991399
}
14001400

1401+
def QubitResetBeforeReuse : Pass<"qubit-reset-before-reuse", "mlir::func::FuncOp"> {
1402+
let summary = "Add qubit reset and conditional initialization after measurement if qubit is to be reused.";
1403+
let description = [{
1404+
This pass adds qubit reset and conditionally applies an X gate if the measurement result is 1
1405+
to initialize qubit into the correct state after measurement. This is only activated when the measured qubit
1406+
is to be reused.
1407+
Note: if the measurement is already accompanied by a reset, we won't add any extra reset.
1408+
}];
1409+
1410+
}
1411+
14011412
#endif // CUDAQ_OPT_OPTIMIZER_TRANSFORMS_PASSES

lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ add_cudaq_library(OptTransforms
6363
RefToVeqAlloc.cpp
6464
RegToMem.cpp
6565
ReplaceStateWithKernel.cpp
66+
ResetBeforeReuse.cpp
6667
ResourceCountPreprocess.cpp
6768
SROA.cpp
6869
StackFramePrealloc.cpp
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "PassDetails.h"
10+
#include "cudaq/Optimizer/CodeGen/Emitter.h"
11+
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
12+
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
13+
#include "cudaq/Optimizer/Transforms/Passes.h"
14+
#include "cudaq/Todo.h"
15+
#include "llvm/Support/Debug.h"
16+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/IR/Dominance.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include "mlir/Transforms/Passes.h"
23+
24+
namespace cudaq::opt {
25+
#define GEN_PASS_DEF_QUBITRESETBEFOREREUSE
26+
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
27+
} // namespace cudaq::opt
28+
29+
#define DEBUG_TYPE "reset-before-reuse"
30+
31+
using namespace mlir;
32+
33+
namespace {
34+
35+
static SmallVector<Operation *, 8> sortUsers(const Value::user_range &users,
36+
const DominanceInfo &dom) {
37+
SmallVector<Operation *, 8> orderedUsers;
38+
39+
for (auto *user : users) {
40+
if ([&]() {
41+
for (auto iter = orderedUsers.begin(), iterEnd = orderedUsers.end();
42+
iter != iterEnd; ++iter) {
43+
assert(*iter);
44+
if (dom.dominates(user, *iter)) {
45+
orderedUsers.insert(iter, user);
46+
return false;
47+
}
48+
}
49+
return true;
50+
}())
51+
orderedUsers.push_back(user);
52+
}
53+
return orderedUsers;
54+
}
55+
56+
// Track qubit register use chains.
57+
// This is used to track if a qubit is reused after it has been measured across
58+
// different extract ops. We want to collect this beforehand so that we don't
59+
// need to repeat for each extract op.
60+
class RegUseTracker {
61+
mlir::DenseMap<mlir::Value, SmallVector<Operation *, 8>> regToOrderedUsers;
62+
DominanceInfo domInfo;
63+
64+
public:
65+
RegUseTracker(func::FuncOp func) : domInfo(func) {
66+
func->walk([&](quake::AllocaOp qalloc) {
67+
regToOrderedUsers[qalloc.getResult()] =
68+
sortUsers(qalloc.getResult().getUsers(), domInfo);
69+
});
70+
}
71+
72+
SmallVector<Operation *, 8> getUsers(mlir::Value qreg) const {
73+
if (!isa<quake::VeqType>(qreg.getType()))
74+
mlir::emitError(qreg.getLoc(),
75+
"Unexpected type used: expected a quake::VeqType.");
76+
77+
auto iter = regToOrderedUsers.find(qreg);
78+
if (iter != regToOrderedUsers.end())
79+
return iter->second;
80+
mlir::emitWarning(qreg.getLoc(), "Qubit vector is not tracked.");
81+
return {};
82+
}
83+
DominanceInfo &getDominanceInfo() { return domInfo; }
84+
RegUseTracker(const RegUseTracker &) = delete;
85+
RegUseTracker(RegUseTracker &&) = delete;
86+
RegUseTracker &operator=(const RegUseTracker &) = delete;
87+
};
88+
89+
class ResetAfterMeasurePattern : public OpRewritePattern<quake::MzOp> {
90+
public:
91+
using OpRewritePattern::OpRewritePattern;
92+
93+
explicit ResetAfterMeasurePattern(MLIRContext *ctx, RegUseTracker &tracker)
94+
: OpRewritePattern(ctx), tracker(tracker) {}
95+
96+
LogicalResult matchAndRewrite(quake::MzOp mz,
97+
PatternRewriter &rewriter) const override {
98+
SmallVector<Operation *> useOps;
99+
for (Value measuredQubit : mz.getTargets()) {
100+
auto *nextOp = getNextUse(measuredQubit, mz);
101+
if (nextOp) {
102+
// If the user is a reset/measure op, nothing to do.
103+
if (isa<quake::ResetOp>(nextOp) || isa<quake::MzOp>(nextOp)) {
104+
continue;
105+
}
106+
107+
// If this is a dealloc op, nothing to do.
108+
if (isa<quake::DeallocOp>(nextOp)) {
109+
continue;
110+
}
111+
112+
// Insert reset
113+
Location loc = mz->getLoc();
114+
rewriter.setInsertionPointAfter(mz);
115+
rewriter.create<quake::ResetOp>(loc, TypeRange{}, measuredQubit);
116+
// Insert a conditional X to initialize qubit after reset.
117+
auto measOut = mz.getMeasOut();
118+
mlir::Value measBit = [&]() {
119+
for (auto *out : measOut.getUsers()) {
120+
// A mz may be accompanied by a store op, find that op.
121+
if (auto disc = dyn_cast_if_present<quake::DiscriminateOp>(out)) {
122+
rewriter.setInsertionPointAfter(disc);
123+
return disc.getResult();
124+
}
125+
}
126+
// No discriminate exists - create the discriminate Op
127+
auto discOp = rewriter.create<quake::DiscriminateOp>(
128+
loc, rewriter.getI1Type(), measOut);
129+
return discOp.getResult();
130+
}();
131+
rewriter.create<cudaq::cc::IfOp>(
132+
loc, TypeRange{}, measBit,
133+
[&](OpBuilder &opBuilder, Location location, Region &region) {
134+
region.push_back(new Block{});
135+
auto &bodyBlock = region.front();
136+
OpBuilder::InsertionGuard guad(opBuilder);
137+
opBuilder.setInsertionPointToStart(&bodyBlock);
138+
opBuilder.create<quake::XOp>(location, measuredQubit);
139+
opBuilder.create<cudaq::cc::ContinueOp>(location);
140+
});
141+
} else {
142+
LLVM_DEBUG(llvm::dbgs() << "No next use\n");
143+
}
144+
}
145+
146+
return failure();
147+
}
148+
149+
private:
150+
Operation *getNextUse(Value qubit, Operation *op) const {
151+
auto &dom = tracker.getDominanceInfo();
152+
{
153+
// Check direct use
154+
const auto orderedUsers = sortUsers(qubit.getUsers(), dom);
155+
for (auto v : llvm::enumerate(orderedUsers))
156+
if (v.value() == op && v.index() < (orderedUsers.size() - 1) &&
157+
dom.dominates(op, orderedUsers[v.index() + 1]))
158+
return orderedUsers[v.index() + 1];
159+
}
160+
161+
// No next use is found, check if this is an extracted qubit.
162+
if (isa<quake::RefType>(qubit.getType())) {
163+
if (auto extractOp =
164+
dyn_cast_if_present<quake::ExtractRefOp>(qubit.getDefiningOp())) {
165+
LLVM_DEBUG(llvm::dbgs() << "Defining op: " << *extractOp << "\n");
166+
auto reg = extractOp.getVeq();
167+
std::optional<int64_t> index =
168+
extractOp.hasConstantIndex()
169+
? std::optional<int64_t>(extractOp.getConstantIndex())
170+
: cudaq::getIndexValueAsInt(extractOp.getIndex());
171+
LLVM_DEBUG(llvm::dbgs() << "Reg: " << reg
172+
<< "; index = " << index.value_or(-1) << "\n");
173+
if (isa<quake::AllocaOp>(reg.getDefiningOp())) {
174+
const auto orderedUsers = tracker.getUsers(reg);
175+
for (auto v : llvm::enumerate(orderedUsers)) {
176+
if (v.value() != extractOp) {
177+
// This is another extract.
178+
auto nextExtractOp =
179+
dyn_cast_or_null<quake::ExtractRefOp>(v.value());
180+
if (nextExtractOp) {
181+
std::optional<int64_t> nextIndex =
182+
nextExtractOp.hasConstantIndex()
183+
? nextExtractOp.getConstantIndex()
184+
: cudaq::getIndexValueAsInt(nextExtractOp.getIndex());
185+
if ((!index.has_value() || !nextIndex.has_value()) ||
186+
(index == nextIndex)) {
187+
// Either the previous index or this index is unknown, we
188+
// assume that they may be the same.
189+
const auto extractedQubit = nextExtractOp.getRef();
190+
const auto extractedQubitOrderedUsers =
191+
sortUsers(extractedQubit.getUsers(), dom);
192+
for (auto *user : extractedQubitOrderedUsers) {
193+
// If the use is dominated by the original mz op,
194+
// then this is the next use.
195+
if (dom.dominates(op, user)) {
196+
LLVM_DEBUG(llvm::dbgs() << "Next use: " << *user << "\n");
197+
return user;
198+
}
199+
}
200+
}
201+
}
202+
}
203+
}
204+
}
205+
}
206+
}
207+
return nullptr;
208+
}
209+
210+
RegUseTracker &tracker;
211+
};
212+
213+
class QubitResetBeforeReusePass
214+
: public cudaq::opt::impl::QubitResetBeforeReuseBase<
215+
QubitResetBeforeReusePass> {
216+
public:
217+
using QubitResetBeforeReuseBase::QubitResetBeforeReuseBase;
218+
QubitResetBeforeReusePass() = default;
219+
220+
void runOnOperation() override {
221+
func::FuncOp funcOp = getOperation();
222+
if (funcOp.empty())
223+
return;
224+
auto *ctx = &getContext();
225+
RegUseTracker tracker(funcOp);
226+
RewritePatternSet patterns(ctx);
227+
patterns.insert<ResetAfterMeasurePattern>(ctx, tracker);
228+
if (failed(applyPatternsAndFoldGreedily(funcOp.getOperation(),
229+
std::move(patterns)))) {
230+
funcOp.emitOpError("Adding qubit reset before reuse pass failed");
231+
signalPassFailure();
232+
}
233+
}
234+
};
235+
} // namespace

lib/Optimizer/Transforms/ResourceCountPreprocess.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,14 @@ struct ResourceCountPreprocessPass
9494
void runOnOperation() override {
9595
auto func = getOperation();
9696

97-
for (auto &b : func.getBody())
97+
for (auto &b : func.getBody()) {
98+
// We only pre-process the main block as the other blocks may be
99+
// conditional when the IR is lowered to CFG.
100+
if (&b != &func.getBody().front())
101+
continue;
98102
for (auto &op : b.getOperations())
99103
preprocessOp(&op);
100-
104+
}
101105
for (auto op : to_erase)
102106
op->erase();
103107

0 commit comments

Comments
 (0)