@@ -120,47 +120,36 @@ static bool containerHasStaticSize(StructType t, unsigned &size) {
120
120
}
121
121
122
122
void LibOptPass::xformStdFindIntoMemchr (StdFindOp findOp) {
123
- // First and second operands need to be iterators begin() and end().
124
- // TODO: look over cir.loads until we have a mem2reg + other passes
125
- // to help out here.
126
- auto iterBegin = dyn_cast<IterBeginOp>(findOp.getOperand (0 ).getDefiningOp ());
127
- if (!iterBegin)
128
- return ;
129
- if (!isa<IterEndOp>(findOp.getOperand (1 ).getDefiningOp ()))
130
- return ;
131
-
132
- // Both operands have the same type, use iterBegin.
133
-
134
- // Look at this pointer to retrieve container information.
135
- auto thisPtr =
136
- iterBegin.getOperand ().getType ().cast <PointerType>().getPointee ();
137
- auto containerTy = dyn_cast<StructType>(thisPtr);
138
- if (!containerTy)
139
- return ;
140
-
141
- if (!isSequentialContainer (containerTy))
142
- return ;
143
-
144
- unsigned staticSize = 0 ;
145
- if (!containerHasStaticSize (containerTy, staticSize))
123
+ // template <class T>
124
+ // requires (sizeof(T) == 1 && is_integral_v<T>)
125
+ // T* find(T* first, T* last, T value) {
126
+ // if (auto result = __builtin_memchr(first, value, last - first))
127
+ // return result;
128
+ // return last;
129
+ // }
130
+
131
+ auto first = findOp.getOperand (0 );
132
+ auto last = findOp.getOperand (1 );
133
+ auto value = findOp->getOperand (2 );
134
+ if (!first.getType ().isa <PointerType>() || !last.getType ().isa <PointerType>())
146
135
return ;
147
136
148
137
// Transformation:
149
138
// - 1st arg: the data pointer
150
139
// - Assert the Iterator is a pointer to primitive type.
151
140
// - Check IterBeginOp is char sized. TODO: add other types that map to
152
141
// char size.
153
- auto iterResTy = iterBegin. getResult () .getType ().dyn_cast <PointerType>();
142
+ auto iterResTy = findOp .getType ().dyn_cast <PointerType>();
154
143
assert (iterResTy && " expected pointer type for iterator" );
155
- auto underlyingDataTy = iterResTy.getPointee ().dyn_cast <mlir::cir:: IntType>();
144
+ auto underlyingDataTy = iterResTy.getPointee ().dyn_cast <IntType>();
156
145
if (!underlyingDataTy || underlyingDataTy.getWidth () != 8 )
157
146
return ;
158
147
159
148
// - 2nd arg: the pattern
160
149
// - Check it's a pointer type.
161
150
// - Load the pattern from memory
162
151
// - cast it to `int`.
163
- auto patternAddrTy = findOp. getOperand ( 2 ) .getType ().dyn_cast <PointerType>();
152
+ auto patternAddrTy = value .getType ().dyn_cast <PointerType>();
164
153
if (!patternAddrTy || patternAddrTy.getPointee () != underlyingDataTy)
165
154
return ;
166
155
@@ -169,27 +158,65 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {
169
158
170
159
CIRBaseBuilderTy builder (getContext ());
171
160
builder.setInsertionPointAfter (findOp.getOperation ());
172
- auto memchrOp0 = builder. createBitcast (
173
- iterBegin. getLoc (), iterBegin. getResult () , builder.getVoidPtrTy ());
161
+ auto memchrOp0 =
162
+ builder. createBitcast (first. getLoc (), first , builder.getVoidPtrTy ());
174
163
175
164
// FIXME: get datalayout based "int" instead of fixed size 4.
176
- auto loadPattern = builder. create <LoadOp>(
177
- findOp. getOperand ( 2 ) .getLoc (), underlyingDataTy, findOp. getOperand ( 2 ) );
165
+ auto loadPattern =
166
+ builder. create <LoadOp>(value .getLoc (), underlyingDataTy, value );
178
167
auto memchrOp1 = builder.createIntCast (
179
168
loadPattern, IntType::get (builder.getContext (), 32 , true ));
180
169
181
- // FIXME: get datalayout based "size_t" instead of fixed size 64.
182
- auto uInt64Ty = IntType::get (builder.getContext (), 64 , false );
183
- auto memchrOp2 = builder.create <ConstantOp>(
184
- findOp.getLoc (), uInt64Ty, mlir::cir::IntAttr::get (uInt64Ty, staticSize));
170
+ const auto uInt64Ty = IntType::get (builder.getContext (), 64 , false );
185
171
186
172
// Build memchr op:
187
173
// void *memchr(const void *s, int c, size_t n);
188
- auto memChr = builder.create <MemChrOp>(findOp.getLoc (), memchrOp0, memchrOp1,
189
- memchrOp2);
190
- mlir::Operation *result =
191
- builder.createBitcast (findOp.getLoc (), memChr.getResult (), iterResTy)
192
- .getDefiningOp ();
174
+ auto memChr = [&] {
175
+ if (auto iterBegin = dyn_cast<IterBeginOp>(first.getDefiningOp ());
176
+ iterBegin && isa<IterEndOp>(last.getDefiningOp ())) {
177
+ // Both operands have the same type, use iterBegin.
178
+
179
+ // Look at this pointer to retrieve container information.
180
+ auto thisPtr =
181
+ iterBegin.getOperand ().getType ().cast <PointerType>().getPointee ();
182
+ auto containerTy = dyn_cast<StructType>(thisPtr);
183
+
184
+ unsigned staticSize = 0 ;
185
+ if (containerTy && isSequentialContainer (containerTy) &&
186
+ containerHasStaticSize (containerTy, staticSize)) {
187
+ return builder.create <MemChrOp>(
188
+ findOp.getLoc (), memchrOp0, memchrOp1,
189
+ builder.create <ConstantOp>(
190
+ findOp.getLoc (), uInt64Ty,
191
+ mlir::cir::IntAttr::get (uInt64Ty, staticSize)));
192
+ }
193
+ }
194
+ return builder.create <MemChrOp>(
195
+ findOp.getLoc (), memchrOp0, memchrOp1,
196
+ builder.create <PtrDiffOp>(findOp.getLoc (), uInt64Ty, last, first));
197
+ }();
198
+
199
+ auto MemChrResult =
200
+ builder.createBitcast (findOp.getLoc (), memChr.getResult (), iterResTy);
201
+
202
+ // if (result)
203
+ // return result;
204
+ // else
205
+ // return last;
206
+ auto NullPtr = builder.create <ConstantOp>(
207
+ findOp.getLoc (), first.getType (), ConstPtrAttr::get (first.getType (), 0 ));
208
+ auto CmpResult = builder.create <CmpOp>(
209
+ findOp.getLoc (), BoolType::get (builder.getContext ()), CmpOpKind::eq,
210
+ NullPtr.getRes (), MemChrResult);
211
+
212
+ auto result = builder.create <TernaryOp>(
213
+ findOp.getLoc (), CmpResult.getResult (),
214
+ [&](mlir::OpBuilder &ob, mlir::Location Loc) {
215
+ ob.create <YieldOp>(Loc, last);
216
+ },
217
+ [&](mlir::OpBuilder &ob, mlir::Location Loc) {
218
+ ob.create <YieldOp>(Loc, MemChrResult);
219
+ });
193
220
194
221
findOp.replaceAllUsesWith (result);
195
222
findOp.erase ();
0 commit comments