Skip to content

Commit 15d33f1

Browse files
philnik777 authored and
lanza committed
[CIR][LibOpt] Extend std::find optimization to all calls with raw pointers (#400)
This also adds a missing check whether the pointer returned from `memchr` is null and changes the result to `last` in that case.
1 parent e387207 commit 15d33f1

File tree

4 files changed

+112
-46
lines changed

4 files changed

+112
-46
lines changed

clang/asf

Whitespace-only changes.

clang/lib/CIR/Dialect/Transforms/LibOpt.cpp

Lines changed: 66 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -120,47 +120,36 @@ static bool containerHasStaticSize(StructType t, unsigned &size) {
120120
}
121121

122122
void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {
123-
// First and second operands need to be iterators begin() and end().
124-
// TODO: look over cir.loads until we have a mem2reg + other passes
125-
// to help out here.
126-
auto iterBegin = dyn_cast<IterBeginOp>(findOp.getOperand(0).getDefiningOp());
127-
if (!iterBegin)
128-
return;
129-
if (!isa<IterEndOp>(findOp.getOperand(1).getDefiningOp()))
130-
return;
131-
132-
// Both operands have the same type, use iterBegin.
133-
134-
// Look at this pointer to retrieve container information.
135-
auto thisPtr =
136-
iterBegin.getOperand().getType().cast<PointerType>().getPointee();
137-
auto containerTy = dyn_cast<StructType>(thisPtr);
138-
if (!containerTy)
139-
return;
140-
141-
if (!isSequentialContainer(containerTy))
142-
return;
143-
144-
unsigned staticSize = 0;
145-
if (!containerHasStaticSize(containerTy, staticSize))
123+
// template <class T>
124+
// requires (sizeof(T) == 1 && is_integral_v<T>)
125+
// T* find(T* first, T* last, T value) {
126+
// if (auto result = __builtin_memchr(first, value, last - first))
127+
// return result;
128+
// return last;
129+
// }
130+
131+
auto first = findOp.getOperand(0);
132+
auto last = findOp.getOperand(1);
133+
auto value = findOp->getOperand(2);
134+
if (!first.getType().isa<PointerType>() || !last.getType().isa<PointerType>())
146135
return;
147136

148137
// Transformation:
149138
// - 1st arg: the data pointer
150139
// - Assert the Iterator is a pointer to primitive type.
151140
// - Check IterBeginOp is char sized. TODO: add other types that map to
152141
// char size.
153-
auto iterResTy = iterBegin.getResult().getType().dyn_cast<PointerType>();
142+
auto iterResTy = findOp.getType().dyn_cast<PointerType>();
154143
assert(iterResTy && "expected pointer type for iterator");
155-
auto underlyingDataTy = iterResTy.getPointee().dyn_cast<mlir::cir::IntType>();
144+
auto underlyingDataTy = iterResTy.getPointee().dyn_cast<IntType>();
156145
if (!underlyingDataTy || underlyingDataTy.getWidth() != 8)
157146
return;
158147

159148
// - 2nd arg: the pattern
160149
// - Check it's a pointer type.
161150
// - Load the pattern from memory
162151
// - cast it to `int`.
163-
auto patternAddrTy = findOp.getOperand(2).getType().dyn_cast<PointerType>();
152+
auto patternAddrTy = value.getType().dyn_cast<PointerType>();
164153
if (!patternAddrTy || patternAddrTy.getPointee() != underlyingDataTy)
165154
return;
166155

@@ -169,27 +158,65 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {
169158

170159
CIRBaseBuilderTy builder(getContext());
171160
builder.setInsertionPointAfter(findOp.getOperation());
172-
auto memchrOp0 = builder.createBitcast(
173-
iterBegin.getLoc(), iterBegin.getResult(), builder.getVoidPtrTy());
161+
auto memchrOp0 =
162+
builder.createBitcast(first.getLoc(), first, builder.getVoidPtrTy());
174163

175164
// FIXME: get datalayout based "int" instead of fixed size 4.
176-
auto loadPattern = builder.create<LoadOp>(
177-
findOp.getOperand(2).getLoc(), underlyingDataTy, findOp.getOperand(2));
165+
auto loadPattern =
166+
builder.create<LoadOp>(value.getLoc(), underlyingDataTy, value);
178167
auto memchrOp1 = builder.createIntCast(
179168
loadPattern, IntType::get(builder.getContext(), 32, true));
180169

181-
// FIXME: get datalayout based "size_t" instead of fixed size 64.
182-
auto uInt64Ty = IntType::get(builder.getContext(), 64, false);
183-
auto memchrOp2 = builder.create<ConstantOp>(
184-
findOp.getLoc(), uInt64Ty, mlir::cir::IntAttr::get(uInt64Ty, staticSize));
170+
const auto uInt64Ty = IntType::get(builder.getContext(), 64, false);
185171

186172
// Build memchr op:
187173
// void *memchr(const void *s, int c, size_t n);
188-
auto memChr = builder.create<MemChrOp>(findOp.getLoc(), memchrOp0, memchrOp1,
189-
memchrOp2);
190-
mlir::Operation *result =
191-
builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy)
192-
.getDefiningOp();
174+
auto memChr = [&] {
175+
if (auto iterBegin = dyn_cast<IterBeginOp>(first.getDefiningOp());
176+
iterBegin && isa<IterEndOp>(last.getDefiningOp())) {
177+
// Both operands have the same type, use iterBegin.
178+
179+
// Look at this pointer to retrieve container information.
180+
auto thisPtr =
181+
iterBegin.getOperand().getType().cast<PointerType>().getPointee();
182+
auto containerTy = dyn_cast<StructType>(thisPtr);
183+
184+
unsigned staticSize = 0;
185+
if (containerTy && isSequentialContainer(containerTy) &&
186+
containerHasStaticSize(containerTy, staticSize)) {
187+
return builder.create<MemChrOp>(
188+
findOp.getLoc(), memchrOp0, memchrOp1,
189+
builder.create<ConstantOp>(
190+
findOp.getLoc(), uInt64Ty,
191+
mlir::cir::IntAttr::get(uInt64Ty, staticSize)));
192+
}
193+
}
194+
return builder.create<MemChrOp>(
195+
findOp.getLoc(), memchrOp0, memchrOp1,
196+
builder.create<PtrDiffOp>(findOp.getLoc(), uInt64Ty, last, first));
197+
}();
198+
199+
auto MemChrResult =
200+
builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy);
201+
202+
// if (result)
203+
// return result;
204+
// else
205+
// return last;
206+
auto NullPtr = builder.create<ConstantOp>(
207+
findOp.getLoc(), first.getType(), ConstPtrAttr::get(first.getType(), 0));
208+
auto CmpResult = builder.create<CmpOp>(
209+
findOp.getLoc(), BoolType::get(builder.getContext()), CmpOpKind::eq,
210+
NullPtr.getRes(), MemChrResult);
211+
212+
auto result = builder.create<TernaryOp>(
213+
findOp.getLoc(), CmpResult.getResult(),
214+
[&](mlir::OpBuilder &ob, mlir::Location Loc) {
215+
ob.create<YieldOp>(Loc, last);
216+
},
217+
[&](mlir::OpBuilder &ob, mlir::Location Loc) {
218+
ob.create<YieldOp>(Loc, MemChrResult);
219+
});
193220

194221
findOp.replaceAllUsesWith(result);
195222
findOp.erase();

clang/test/CIR/Transforms/lib-opt-find.cpp

Lines changed: 44 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,26 +3,64 @@
33

44
#include "std-cxx.h"
55

6-
int test_find(unsigned char n = 3)
6+
int test1(unsigned char n = 3)
77
{
8+
// CHECK: test1
89
unsigned num_found = 0;
910
// CHECK: %[[pattern_addr:.*]] = cir.alloca !u8i, cir.ptr <!u8i>, ["n"
1011
std::array<unsigned char, 9> v = {1, 2, 3, 4, 5, 6, 7, 8, 9};
1112

1213
auto f = std::find(v.begin(), v.end(), n);
13-
// CHECK: %[[begin:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv
14-
// CHECK: cir.call @_ZNSt5arrayIhLj9EE3endEv
15-
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[begin]] : !cir.ptr<!u8i>), !cir.ptr<!void>
14+
15+
// CHECK: %[[first:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv
16+
// CHECK: %[[last:.*]] = cir.call @_ZNSt5arrayIhLj9EE3endEv
17+
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr<!u8i>), !cir.ptr<!void>
1618
// CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_addr]] : cir.ptr <!u8i>, !u8i
1719
// CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i
1820

1921
// CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
2022
// CHECK: %[[array_size:.*]] = cir.const(#cir.int<9> : !u64i) : !u64i
2123

2224
// CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]])
23-
// CHECK: cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
25+
// CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
26+
// CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr<null> : !cir.ptr<!u8i>) : !cir.ptr<!u8i>
27+
// CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr<!u8i>, !cir.bool
28+
// CHECK: cir.ternary(%[[cmp_res]], true {
29+
// CHECK: cir.yield %[[last]] : !cir.ptr<!u8i>
30+
// CHECK: }, false {
31+
// CHECK: cir.yield %[[memchr_res]] : !cir.ptr<!u8i>
32+
// CHECK: }) : (!cir.bool) -> !cir.ptr<!u8i>
33+
2434
if (f != v.end())
2535
num_found++;
2636

2737
return num_found;
28-
}
38+
}
39+
40+
unsigned char* test2(unsigned char* first, unsigned char* last, unsigned char v)
41+
{
42+
return std::find(first, last, v);
43+
// CHECK: test2
44+
45+
// CHECK: %[[first_storage:.*]] = cir.alloca !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>, ["first", init]
46+
// CHECK: %[[last_storage:.*]] = cir.alloca !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>, ["last", init]
47+
// CHECK: %[[pattern_storage:.*]] = cir.alloca !u8i, cir.ptr <!u8i>, ["v", init]
48+
// CHECK: %[[first:.*]] = cir.load %[[first_storage]]
49+
// CHECK: %[[last:.*]] = cir.load %[[last_storage]]
50+
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr<!u8i>), !cir.ptr<!void>
51+
// CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_storage]] : cir.ptr <!u8i>, !u8i
52+
// CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i
53+
54+
// CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
55+
// CHECK: %[[array_size:.*]] = cir.ptr_diff(%[[last]], %[[first]]) : !cir.ptr<!u8i> -> !u64i
56+
57+
// CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]])
58+
// CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
59+
// CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr<null> : !cir.ptr<!u8i>) : !cir.ptr<!u8i>
60+
// CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr<!u8i>, !cir.bool
61+
// CHECK: cir.ternary(%[[cmp_res]], true {
62+
// CHECK: cir.yield %[[last]] : !cir.ptr<!u8i>
63+
// CHECK: }, false {
64+
// CHECK: cir.yield %[[memchr_res]] : !cir.ptr<!u8i>
65+
// CHECK: }) : (!cir.bool) -> !cir.ptr<!u8i>
66+
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -908,8 +908,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
908908

909909
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
910910
FlatSymbolRefAttr callee, ValueRange args) {
911+
auto fargs = callee ? args : args.drop_front();
911912
build(builder, state, results,
912-
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
913+
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, fargs)),
913914
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
914915
/*CConv=*/nullptr,
915916
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,

0 commit comments

Comments
 (0)