Skip to content

Commit 983f7b2

Browse files
authored
Merge branch 'main' into main
2 parents db6311c + 3602887 commit 983f7b2

File tree

26 files changed

+549
-501
lines changed

26 files changed

+549
-501
lines changed

include/cudaq/Optimizer/CodeGen/Pipelines.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,18 @@ void commonPipelineConvertToQIR(mlir::PassManager &pm,
3030
mlir::StringRef codeGenFor = "qir",
3131
mlir::StringRef passConfigAs = "qir");
3232

33-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
34-
void commonPipelineConvertToQIR_PythonWorkaround(
35-
mlir::PassManager &pm, const std::optional<mlir::StringRef> &convertTo);
36-
3733
/// \brief Pipeline builder to convert Quake to QIR.
3834
/// Does not specify a particular QIR profile.
3935
inline void addPipelineConvertToQIR(mlir::PassManager &pm) {
4036
commonPipelineConvertToQIR(pm);
4137
}
4238

43-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
44-
inline void addPipelineConvertToQIR_PythonWorkaround(mlir::PassManager &pm) {
45-
commonPipelineConvertToQIR_PythonWorkaround(pm, std::nullopt);
46-
}
47-
4839
/// \brief Pipeline builder to convert Quake to QIR.
4940
/// Specifies a particular QIR profile in \p convertTo.
5041
/// \p pm Pass manager to append passes to
5142
/// \p convertTo name of QIR profile (e.g., `qir-base`, `qir-adaptive`, ...)
5243
void addPipelineConvertToQIR(mlir::PassManager &pm, mlir::StringRef convertTo);
5344

54-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
55-
inline void
56-
addPipelineConvertToQIR_PythonWorkaround(mlir::PassManager &pm,
57-
mlir::StringRef convertTo) {
58-
commonPipelineConvertToQIR_PythonWorkaround(pm, convertTo);
59-
addQIRProfilePipeline(pm, convertTo);
60-
}
61-
6245
void addLowerToCCPipeline(mlir::OpPassManager &pm);
6346

6447
void addPipelineTranslateToOpenQASM(mlir::PassManager &pm);

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ struct QuantumGatePattern : public OpConversionPattern<OP> {
10921092

10931093
// Process the controls, sorting them by type.
10941094
for (auto pr : llvm::zip(op.getControls(), adaptor.getControls())) {
1095-
if (isa<quake::VeqType>(std::get<0>(pr).getType())) {
1095+
if (isaVeqArgument(std::get<0>(pr).getType())) {
10961096
numArrayCtrls++;
10971097
auto sizeCall = rewriter.create<func::CallOp>(
10981098
loc, i64Ty, cudaq::opt::QIRArrayGetSize,
@@ -1155,6 +1155,18 @@ struct QuantumGatePattern : public OpConversionPattern<OP> {
11551155
return forwardOrEraseOp();
11561156
}
11571157

1158+
static bool isaVeqArgument(Type ty) {
1159+
// TODO: Need a way to identify arrays when using the opaque pointer
1160+
// variant. (In Python, the arguments may already be converted.)
1161+
auto alreadyConverted = [](Type ty) {
1162+
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(ty))
1163+
if (auto strTy = dyn_cast<LLVM::LLVMStructType>(ptrTy.getElementType()))
1164+
return strTy.isIdentified() && strTy.getName() == "Array";
1165+
return false;
1166+
};
1167+
return isa<quake::VeqType>(ty) || alreadyConverted(ty);
1168+
}
1169+
11581170
static bool conformsToIntendedCall(std::size_t numControls, Value ctrl, OP op,
11591171
StringRef qirFunctionName) {
11601172
if (numControls != 1)
@@ -1819,9 +1831,7 @@ struct QuakeToQIRAPIPrepPass
18191831
}
18201832

18211833
void guaranteeMzIsLabeled(quake::MzOp mz, int &counter, OpBuilder &builder) {
1822-
if (mz.getRegisterNameAttr() &&
1823-
/* FIXME: issue 2538: the name should never be empty. */
1824-
!mz.getRegisterNameAttr().getValue().empty()) {
1834+
if (mz.getRegisterNameAttr()) {
18251835
mz->setAttr(cudaq::opt::MzAssignedNameAttrName, builder.getUnitAttr());
18261836
return;
18271837
}

lib/Optimizer/CodeGen/Pipelines.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,37 +51,6 @@ void cudaq::opt::commonPipelineConvertToQIR(PassManager &pm,
5151
pm.addPass(createCCToLLVM());
5252
}
5353

54-
void cudaq::opt::commonPipelineConvertToQIR_PythonWorkaround(
55-
PassManager &pm, const std::optional<StringRef> &convertTo) {
56-
pm.addNestedPass<func::FuncOp>(createApplyControlNegations());
57-
addAggressiveEarlyInlining(pm);
58-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
59-
pm.addNestedPass<func::FuncOp>(createUnwindLoweringPass());
60-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
61-
pm.addPass(createApplyOpSpecializationPass());
62-
pm.addNestedPass<func::FuncOp>(createExpandMeasurementsPass());
63-
pm.addNestedPass<func::FuncOp>(createClassicalMemToReg());
64-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
65-
pm.addNestedPass<func::FuncOp>(createCSEPass());
66-
pm.addNestedPass<func::FuncOp>(createQuakeAddDeallocs());
67-
pm.addNestedPass<func::FuncOp>(createQuakeAddMetadata());
68-
pm.addNestedPass<func::FuncOp>(createLoopNormalize());
69-
LoopUnrollOptions luo;
70-
luo.allowBreak = convertTo && (*convertTo == "qir-adaptive");
71-
pm.addNestedPass<func::FuncOp>(createLoopUnroll(luo));
72-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
73-
pm.addNestedPass<func::FuncOp>(createCSEPass());
74-
pm.addNestedPass<func::FuncOp>(createLowerToCFGPass());
75-
pm.addNestedPass<func::FuncOp>(createCombineQuantumAllocations());
76-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
77-
pm.addNestedPass<func::FuncOp>(createCSEPass());
78-
if (convertTo && (*convertTo == "qir-base"))
79-
pm.addNestedPass<func::FuncOp>(createDelayMeasurementsPass());
80-
pm.addPass(createConvertMathToFuncs());
81-
pm.addPass(createSymbolDCEPass());
82-
pm.addPass(createConvertToQIR());
83-
}
84-
8554
void cudaq::opt::addPipelineTranslateToOpenQASM(PassManager &pm) {
8655
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8756
pm.addNestedPass<func::FuncOp>(createCSEPass());

lib/Optimizer/Dialect/Quake/QuakeOps.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -517,38 +517,41 @@ void quake::WrapOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
517517
//===----------------------------------------------------------------------===//
518518

519519
// Common verification for measurement operations.
520-
static LogicalResult verifyMeasurements(Operation *const op,
521-
TypeRange targetsType,
522-
const Type bitsType) {
520+
template <typename MEAS>
521+
LogicalResult verifyMeasurements(MEAS op, TypeRange targetsType,
522+
const Type bitsType) {
523523
if (failed(verifyWireResultsAreLinear(op)))
524524
return failure();
525525
bool mustBeStdvec =
526526
targetsType.size() > 1 ||
527527
(targetsType.size() == 1 && isa<quake::VeqType>(targetsType[0]));
528528
if (mustBeStdvec) {
529-
if (!isa<cudaq::cc::StdvecType>(op->getResult(0).getType()))
530-
return op->emitOpError("must return `!cc.stdvec<!quake.measure>`, when "
531-
"measuring a qreg, a series of qubits, or both");
529+
if (!isa<cudaq::cc::StdvecType>(op.getMeasOut().getType()))
530+
return op.emitOpError("must return `!cc.stdvec<!quake.measure>`, when "
531+
"measuring a qreg, a series of qubits, or both");
532532
} else {
533-
if (!isa<quake::MeasureType>(op->getResult(0).getType()))
533+
if (!isa<quake::MeasureType>(op.getMeasOut().getType()))
534534
return op->emitOpError(
535535
"must return `!quake.measure` when measuring exactly one qubit");
536536
}
537+
if (op.getRegisterName())
538+
if (op.getRegisterName()->empty())
539+
return op->emitError("quake measurement name cannot be empty.");
537540
return success();
538541
}
539542

540543
LogicalResult quake::MxOp::verify() {
541-
return verifyMeasurements(getOperation(), getTargets().getType(),
544+
return verifyMeasurements(*this, getTargets().getType(),
542545
getMeasOut().getType());
543546
}
544547

545548
LogicalResult quake::MyOp::verify() {
546-
return verifyMeasurements(getOperation(), getTargets().getType(),
549+
return verifyMeasurements(*this, getTargets().getType(),
547550
getMeasOut().getType());
548551
}
549552

550553
LogicalResult quake::MzOp::verify() {
551-
return verifyMeasurements(getOperation(), getTargets().getType(),
554+
return verifyMeasurements(*this, getTargets().getType(),
552555
getMeasOut().getType());
553556
}
554557

lib/Optimizer/Transforms/GlobalizeArrayValues.cpp

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,23 @@ convertArrayAttrToGlobalConstant(MLIRContext *ctx, Location loc,
8787
}
8888

8989
namespace {
90+
91+
// This pattern replaces a cc.const_array with a global constant. It can
92+
// recognize a couple of usage patterns and will generate efficient IR in those
93+
// cases.
94+
//
95+
// Pattern 1: The entire constant array is stored to a stack variable(s). Here
96+
// we can eliminate the stack allocation and use the global constant.
97+
//
98+
// Pattern 2: Individual elements at dynamic offsets are extracted from the
99+
// constant array and used. This can be replaced with a compute pointer
100+
// operation using the global constant and a load of the element at the computed
101+
// offset.
102+
//
103+
// Default: If the usage is not recognized, the constant array value is replaced
104+
// with a load of the entire global variable. In this case, LLVM's optimizations
105+
// are counted on to help demote the (large?) sequence value to primitive memory
106+
// address arithmetic.
90107
struct ConstantArrayPattern
91108
: public OpRewritePattern<cudaq::cc::ConstantArrayOp> {
92109
explicit ConstantArrayPattern(MLIRContext *ctx, ModuleOp module,
@@ -95,21 +112,30 @@ struct ConstantArrayPattern
95112

96113
LogicalResult matchAndRewrite(cudaq::cc::ConstantArrayOp conarr,
97114
PatternRewriter &rewriter) const override {
115+
auto func = conarr->getParentOfType<func::FuncOp>();
116+
if (!func)
117+
return failure();
118+
98119
SmallVector<cudaq::cc::AllocaOp> allocas;
99120
SmallVector<cudaq::cc::StoreOp> stores;
121+
SmallVector<cudaq::cc::ExtractValueOp> extracts;
122+
bool loadAsValue = false;
100123
for (auto *usr : conarr->getUsers()) {
101124
auto store = dyn_cast<cudaq::cc::StoreOp>(usr);
102-
if (!store)
103-
return failure();
104-
auto alloca = store.getPtrvalue().getDefiningOp<cudaq::cc::AllocaOp>();
105-
if (!alloca)
106-
return failure();
107-
stores.push_back(store);
108-
allocas.push_back(alloca);
125+
auto extract = dyn_cast<cudaq::cc::ExtractValueOp>(usr);
126+
if (store) {
127+
auto alloca = store.getPtrvalue().getDefiningOp<cudaq::cc::AllocaOp>();
128+
if (alloca) {
129+
stores.push_back(store);
130+
allocas.push_back(alloca);
131+
continue;
132+
}
133+
} else if (extract) {
134+
extracts.push_back(extract);
135+
continue;
136+
}
137+
loadAsValue = true;
109138
}
110-
auto func = conarr->getParentOfType<func::FuncOp>();
111-
if (!func)
112-
return failure();
113139
std::string globalName =
114140
func.getName().str() + ".rodata_" + std::to_string(counter++);
115141
auto *ctx = rewriter.getContext();
@@ -118,12 +144,39 @@ struct ConstantArrayPattern
118144
if (failed(convertArrayAttrToGlobalConstant(ctx, conarr.getLoc(), valueAttr,
119145
module, globalName, eleTy)))
120146
return failure();
121-
for (auto alloca : allocas)
122-
rewriter.replaceOpWithNewOp<cudaq::cc::AddressOfOp>(
123-
alloca, alloca.getType(), globalName);
124-
for (auto store : stores)
125-
rewriter.eraseOp(store);
126-
rewriter.eraseOp(conarr);
147+
auto loc = conarr.getLoc();
148+
if (!extracts.empty()) {
149+
auto base = rewriter.create<cudaq::cc::AddressOfOp>(
150+
loc, cudaq::cc::PointerType::get(conarr.getType()), globalName);
151+
auto elePtrTy = cudaq::cc::PointerType::get(eleTy);
152+
for (auto extract : extracts) {
153+
SmallVector<cudaq::cc::ComputePtrArg> args;
154+
unsigned i = 0;
155+
for (auto arg : extract.getRawConstantIndices()) {
156+
if (arg == cudaq::cc::ExtractValueOp::getDynamicIndexValue())
157+
args.push_back(extract.getDynamicIndices()[i++]);
158+
else
159+
args.push_back(arg);
160+
}
161+
OpBuilder::InsertionGuard guard(rewriter);
162+
rewriter.setInsertionPoint(extract);
163+
auto addrVal =
164+
rewriter.create<cudaq::cc::ComputePtrOp>(loc, elePtrTy, base, args);
165+
rewriter.replaceOpWithNewOp<cudaq::cc::LoadOp>(extract, addrVal);
166+
}
167+
}
168+
if (!stores.empty()) {
169+
for (auto alloca : allocas)
170+
rewriter.replaceOpWithNewOp<cudaq::cc::AddressOfOp>(
171+
alloca, alloca.getType(), globalName);
172+
for (auto store : stores)
173+
rewriter.eraseOp(store);
174+
}
175+
if (loadAsValue) {
176+
auto base = rewriter.create<cudaq::cc::AddressOfOp>(
177+
loc, cudaq::cc::PointerType::get(conarr.getType()), globalName);
178+
rewriter.replaceOpWithNewOp<cudaq::cc::LoadOp>(conarr, base);
179+
}
127180
return success();
128181
}
129182

python/cudaq/kernel/ast_bridge.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,9 +1749,11 @@ def bodyBuilder(iterVal):
17491749
self.ctx) if len(qubits) == 1 and quake.RefType.isinstance(
17501750
qubits[0].type) else cc.StdvecType.get(
17511751
self.ctx, quake.MeasureType.get(self.ctx))
1752-
measureResult = opCtor(measTy, [],
1753-
qubits,
1754-
registerName=registerName).result
1752+
label = registerName
1753+
if not label:
1754+
label = None
1755+
measureResult = opCtor(measTy, [], qubits,
1756+
registerName=label).result
17551757
if pushResultToStack:
17561758
self.pushValue(
17571759
quake.DiscriminateOp(resTy, measureResult).result)
@@ -3152,6 +3154,73 @@ def bodyBuilder(iterVar):
31523154
isDecrementing=isDecrementing)
31533155
return
31543156

3157+
# We can simplify `for i,j in enumerate(L)` MLIR code immensely
3158+
# by just building a for loop over the iterable object L and using
3159+
# the index into that iterable and the element.
3160+
if isinstance(node.iter, ast.Call):
3161+
if node.iter.func.id == 'enumerate':
3162+
[self.visit(arg) for arg in node.iter.args]
3163+
if len(self.valueStack) == 2:
3164+
iterable = self.popValue()
3165+
self.popValue()
3166+
else:
3167+
assert len(self.valueStack) == 1
3168+
iterable = self.popValue()
3169+
iterable = self.ifPointerThenLoad(iterable)
3170+
totalSize = None
3171+
extractFunctor = None
3172+
varNames = []
3173+
for elt in node.target.elts:
3174+
varNames.append(elt.id)
3175+
3176+
beEfficient = False
3177+
if quake.VeqType.isinstance(iterable.type):
3178+
totalSize = quake.VeqSizeOp(self.getIntegerType(),
3179+
iterable).result
3180+
3181+
def functor(seq, idx):
3182+
q = quake.ExtractRefOp(self.getRefType(),
3183+
seq,
3184+
-1,
3185+
index=idx).result
3186+
return [idx, q]
3187+
3188+
extractFunctor = functor
3189+
beEfficient = True
3190+
elif cc.StdvecType.isinstance(iterable.type):
3191+
totalSize = cc.StdvecSizeOp(self.getIntegerType(),
3192+
iterable).result
3193+
3194+
def functor(seq, idx):
3195+
vecTy = cc.StdvecType.getElementType(seq.type)
3196+
dataTy = cc.PointerType.get(self.ctx, vecTy)
3197+
arrTy = vecTy
3198+
if not cc.ArrayType.isinstance(arrTy):
3199+
arrTy = cc.ArrayType.get(self.ctx, vecTy)
3200+
dataArrTy = cc.PointerType.get(self.ctx, arrTy)
3201+
data = cc.StdvecDataOp(dataArrTy, seq).result
3202+
v = cc.ComputePtrOp(
3203+
dataTy, data, [idx],
3204+
DenseI32ArrayAttr.get([kDynamicPtrIndex],
3205+
context=self.ctx)).result
3206+
return [idx, v]
3207+
3208+
extractFunctor = functor
3209+
beEfficient = True
3210+
3211+
if beEfficient:
3212+
3213+
def bodyBuilder(iterVar):
3214+
self.symbolTable.pushScope()
3215+
values = extractFunctor(iterable, iterVar)
3216+
for i, v in enumerate(values):
3217+
self.symbolTable[varNames[i]] = v
3218+
[self.visit(b) for b in node.body]
3219+
self.symbolTable.popScope()
3220+
3221+
self.createInvariantForLoop(totalSize, bodyBuilder)
3222+
return
3223+
31553224
self.visit(node.iter)
31563225
assert len(self.valueStack) > 0 and len(self.valueStack) < 3
31573226

0 commit comments

Comments
 (0)