Skip to content

Commit 8e9729d

Browse files
yugrlanza
authored andcommitted
[CIR][Transforms][Bugfix] Do not use-after-free in MergeCleanups and IdiomRecognizer. (#389)
Some tests started failing under `-DLLVM_USE_SANITIZER=Address` due to trivial use-after-free errors.
1 parent ee93764 commit 8e9729d

File tree

2 files changed

+23
-17
lines changed

2 files changed

+23
-17
lines changed

clang/lib/CIR/Dialect/Transforms/IdiomRecognizer.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
3636
IdiomRecognizerPass() = default;
3737
void runOnOperation() override;
3838
void recognizeCall(CallOp call);
39-
void raiseStdFind(CallOp call);
40-
void raiseIteratorBeginEnd(CallOp call);
39+
bool raiseStdFind(CallOp call);
40+
bool raiseIteratorBeginEnd(CallOp call);
4141

4242
// Handle pass options
4343
struct Options {
@@ -88,14 +88,14 @@ struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
8888
};
8989
} // namespace
9090

91-
void IdiomRecognizerPass::raiseStdFind(CallOp call) {
91+
bool IdiomRecognizerPass::raiseStdFind(CallOp call) {
9292
// FIXME: tablegen all of this function.
9393
if (call.getNumOperands() != 3)
94-
return;
94+
return false;
9595

9696
auto callExprAttr = call.getAstAttr();
9797
if (!callExprAttr || !callExprAttr.isStdFunctionCall("find")) {
98-
return;
98+
return false;
9999
}
100100

101101
if (opts.emitRemarkFoundCalls())
@@ -109,6 +109,7 @@ void IdiomRecognizerPass::raiseStdFind(CallOp call) {
109109

110110
call.replaceAllUsesWith(findOp);
111111
call.erase();
112+
return true;
112113
}
113114

114115
static bool isIteratorLikeType(mlir::Type t) {
@@ -128,24 +129,24 @@ static bool isIteratorInStdContainter(mlir::Type t) {
128129
return isStdArrayType(t);
129130
}
130131

131-
void IdiomRecognizerPass::raiseIteratorBeginEnd(CallOp call) {
132+
bool IdiomRecognizerPass::raiseIteratorBeginEnd(CallOp call) {
132133
// FIXME: tablegen all of this function.
133134
CIRBaseBuilderTy builder(getContext());
134135

135136
if (call.getNumOperands() != 1 || call.getNumResults() != 1)
136-
return;
137+
return false;
137138

138139
auto callExprAttr = call.getAstAttr();
139140
if (!callExprAttr)
140-
return;
141+
return false;
141142

142143
if (!isIteratorLikeType(call.getResult(0).getType()))
143-
return;
144+
return false;
144145

145146
// First argument is the container "this" pointer.
146147
auto thisPtr = call.getOperand(0).getType().dyn_cast<PointerType>();
147148
if (!thisPtr || !isIteratorInStdContainter(thisPtr.getPointee()))
148-
return;
149+
return false;
149150

150151
builder.setInsertionPointAfter(call.getOperation());
151152
mlir::Operation *iterOp;
@@ -162,16 +163,20 @@ void IdiomRecognizerPass::raiseIteratorBeginEnd(CallOp call) {
162163
call.getLoc(), call.getResult(0).getType(), call.getCalleeAttr(),
163164
call.getOperand(0));
164165
} else {
165-
return;
166+
return false;
166167
}
167168

168169
call.replaceAllUsesWith(iterOp);
169170
call.erase();
171+
return true;
170172
}
171173

172174
void IdiomRecognizerPass::recognizeCall(CallOp call) {
173-
raiseIteratorBeginEnd(call);
174-
raiseStdFind(call);
175+
if (raiseIteratorBeginEnd(call))
176+
return;
177+
178+
if (raiseStdFind(call))
179+
return;
175180
}
176181

177182
void IdiomRecognizerPass::runOnOperation() {

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,16 @@ struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
8484
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
8585
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
8686
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
87+
auto *destTrue = op.getDestTrue(), *destFalse = op.getDestFalse();
8788
Block *block = op.getOperation()->getBlock();
8889

8990
rewriter.eraseOp(op);
9091
if (cond) {
91-
rewriter.mergeBlocks(op.getDestTrue(), block);
92-
rewriter.eraseBlock(op.getDestFalse());
92+
rewriter.mergeBlocks(destTrue, block);
93+
rewriter.eraseBlock(destFalse);
9394
} else {
94-
rewriter.mergeBlocks(op.getDestFalse(), block);
95-
rewriter.eraseBlock(op.getDestTrue());
95+
rewriter.mergeBlocks(destFalse, block);
96+
rewriter.eraseBlock(destTrue);
9697
}
9798
}
9899
};

0 commit comments

Comments
 (0)