Skip to content

Commit

Permalink
[CIR][LibOpt] Extend std::find optimization to all calls with raw poi…
Browse files Browse the repository at this point in the history
…nters (#400)

This also adds a missing check whether the pointer returned from
`memchr` is null and changes the result to `last` in that case.
  • Loading branch information
philnik777 authored and lanza committed Mar 23, 2024
1 parent e387207 commit 15d33f1
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 46 deletions.
Empty file added clang/asf
Empty file.
105 changes: 66 additions & 39 deletions clang/lib/CIR/Dialect/Transforms/LibOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,47 +120,36 @@ static bool containerHasStaticSize(StructType t, unsigned &size) {
}

void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {
// First and second operands need to be iterators begin() and end().
// TODO: look over cir.loads until we have a mem2reg + other passes
// to help out here.
auto iterBegin = dyn_cast<IterBeginOp>(findOp.getOperand(0).getDefiningOp());
if (!iterBegin)
return;
if (!isa<IterEndOp>(findOp.getOperand(1).getDefiningOp()))
return;

// Both operands have the same type, use iterBegin.

// Look at this pointer to retrieve container information.
auto thisPtr =
iterBegin.getOperand().getType().cast<PointerType>().getPointee();
auto containerTy = dyn_cast<StructType>(thisPtr);
if (!containerTy)
return;

if (!isSequentialContainer(containerTy))
return;

unsigned staticSize = 0;
if (!containerHasStaticSize(containerTy, staticSize))
// template <class T>
// requires (sizeof(T) == 1 && is_integral_v<T>)
// T* find(T* first, T* last, T value) {
// if (auto result = __builtin_memchr(first, value, last - first))
// return result;
// return last;
// }

auto first = findOp.getOperand(0);
auto last = findOp.getOperand(1);
auto value = findOp->getOperand(2);
if (!first.getType().isa<PointerType>() || !last.getType().isa<PointerType>())
return;

// Transformation:
// - 1st arg: the data pointer
// - Assert the Iterator is a pointer to primitive type.
// - Check IterBeginOp is char sized. TODO: add other types that map to
// char size.
auto iterResTy = iterBegin.getResult().getType().dyn_cast<PointerType>();
auto iterResTy = findOp.getType().dyn_cast<PointerType>();
assert(iterResTy && "expected pointer type for iterator");
auto underlyingDataTy = iterResTy.getPointee().dyn_cast<mlir::cir::IntType>();
auto underlyingDataTy = iterResTy.getPointee().dyn_cast<IntType>();
if (!underlyingDataTy || underlyingDataTy.getWidth() != 8)
return;

// - 2nd arg: the pattern
// - Check it's a pointer type.
// - Load the pattern from memory
// - cast it to `int`.
auto patternAddrTy = findOp.getOperand(2).getType().dyn_cast<PointerType>();
auto patternAddrTy = value.getType().dyn_cast<PointerType>();
if (!patternAddrTy || patternAddrTy.getPointee() != underlyingDataTy)
return;

Expand All @@ -169,27 +158,65 @@ void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(findOp.getOperation());
auto memchrOp0 = builder.createBitcast(
iterBegin.getLoc(), iterBegin.getResult(), builder.getVoidPtrTy());
auto memchrOp0 =
builder.createBitcast(first.getLoc(), first, builder.getVoidPtrTy());

// FIXME: get datalayout based "int" instead of fixed size 4.
auto loadPattern = builder.create<LoadOp>(
findOp.getOperand(2).getLoc(), underlyingDataTy, findOp.getOperand(2));
auto loadPattern =
builder.create<LoadOp>(value.getLoc(), underlyingDataTy, value);
auto memchrOp1 = builder.createIntCast(
loadPattern, IntType::get(builder.getContext(), 32, true));

// FIXME: get datalayout based "size_t" instead of fixed size 64.
auto uInt64Ty = IntType::get(builder.getContext(), 64, false);
auto memchrOp2 = builder.create<ConstantOp>(
findOp.getLoc(), uInt64Ty, mlir::cir::IntAttr::get(uInt64Ty, staticSize));
const auto uInt64Ty = IntType::get(builder.getContext(), 64, false);

// Build memchr op:
// void *memchr(const void *s, int c, size_t n);
auto memChr = builder.create<MemChrOp>(findOp.getLoc(), memchrOp0, memchrOp1,
memchrOp2);
mlir::Operation *result =
builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy)
.getDefiningOp();
auto memChr = [&] {
if (auto iterBegin = dyn_cast<IterBeginOp>(first.getDefiningOp());
iterBegin && isa<IterEndOp>(last.getDefiningOp())) {
// Both operands have the same type, use iterBegin.

// Look at this pointer to retrieve container information.
auto thisPtr =
iterBegin.getOperand().getType().cast<PointerType>().getPointee();
auto containerTy = dyn_cast<StructType>(thisPtr);

unsigned staticSize = 0;
if (containerTy && isSequentialContainer(containerTy) &&
containerHasStaticSize(containerTy, staticSize)) {
return builder.create<MemChrOp>(
findOp.getLoc(), memchrOp0, memchrOp1,
builder.create<ConstantOp>(
findOp.getLoc(), uInt64Ty,
mlir::cir::IntAttr::get(uInt64Ty, staticSize)));
}
}
return builder.create<MemChrOp>(
findOp.getLoc(), memchrOp0, memchrOp1,
builder.create<PtrDiffOp>(findOp.getLoc(), uInt64Ty, last, first));
}();

auto MemChrResult =
builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy);

// if (result)
// return result;
// else
// return last;
auto NullPtr = builder.create<ConstantOp>(
findOp.getLoc(), first.getType(), ConstPtrAttr::get(first.getType(), 0));
auto CmpResult = builder.create<CmpOp>(
findOp.getLoc(), BoolType::get(builder.getContext()), CmpOpKind::eq,
NullPtr.getRes(), MemChrResult);

auto result = builder.create<TernaryOp>(
findOp.getLoc(), CmpResult.getResult(),
[&](mlir::OpBuilder &ob, mlir::Location Loc) {
ob.create<YieldOp>(Loc, last);
},
[&](mlir::OpBuilder &ob, mlir::Location Loc) {
ob.create<YieldOp>(Loc, MemChrResult);
});

findOp.replaceAllUsesWith(result);
findOp.erase();
Expand Down
50 changes: 44 additions & 6 deletions clang/test/CIR/Transforms/lib-opt-find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,64 @@

#include "std-cxx.h"

int test_find(unsigned char n = 3)
int test1(unsigned char n = 3)
{
// CHECK: test1
unsigned num_found = 0;
// CHECK: %[[pattern_addr:.*]] = cir.alloca !u8i, cir.ptr <!u8i>, ["n"
std::array<unsigned char, 9> v = {1, 2, 3, 4, 5, 6, 7, 8, 9};

auto f = std::find(v.begin(), v.end(), n);
// CHECK: %[[begin:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv
// CHECK: cir.call @_ZNSt5arrayIhLj9EE3endEv
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[begin]] : !cir.ptr<!u8i>), !cir.ptr<!void>

// CHECK: %[[first:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv
// CHECK: %[[last:.*]] = cir.call @_ZNSt5arrayIhLj9EE3endEv
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr<!u8i>), !cir.ptr<!void>
// CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_addr]] : cir.ptr <!u8i>, !u8i
// CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i

// CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
// CHECK: %[[array_size:.*]] = cir.const(#cir.int<9> : !u64i) : !u64i

// CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]])
// CHECK: cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
// CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
// CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr<null> : !cir.ptr<!u8i>) : !cir.ptr<!u8i>
// CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr<!u8i>, !cir.bool
// CHECK: cir.ternary(%[[cmp_res]], true {
// CHECK: cir.yield %[[last]] : !cir.ptr<!u8i>
// CHECK: }, false {
// CHECK: cir.yield %[[memchr_res]] : !cir.ptr<!u8i>
// CHECK: }) : (!cir.bool) -> !cir.ptr<!u8i>

if (f != v.end())
num_found++;

return num_found;
}
}

unsigned char* test2(unsigned char* first, unsigned char* last, unsigned char v)
{
return std::find(first, last, v);
// CHECK: test2

// CHECK: %[[first_storage:.*]] = cir.alloca !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>, ["first", init]
// CHECK: %[[last_storage:.*]] = cir.alloca !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>, ["last", init]
// CHECK: %[[pattern_storage:.*]] = cir.alloca !u8i, cir.ptr <!u8i>, ["v", init]
// CHECK: %[[first:.*]] = cir.load %[[first_storage]]
// CHECK: %[[last:.*]] = cir.load %[[last_storage]]
// CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[first]] : !cir.ptr<!u8i>), !cir.ptr<!void>
// CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_storage]] : cir.ptr <!u8i>, !u8i
// CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern:.*]] : !u8i), !s32i

// CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
// CHECK: %[[array_size:.*]] = cir.ptr_diff(%[[last]], %[[first]]) : !cir.ptr<!u8i> -> !u64i

// CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]])
// CHECK: %[[memchr_res:.*]] = cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
// CHECK: %[[nullptr:.*]] = cir.const(#cir.ptr<null> : !cir.ptr<!u8i>) : !cir.ptr<!u8i>
// CHECK: %[[cmp_res:.*]] = cir.cmp(eq, %[[nullptr]], %[[memchr_res]]) : !cir.ptr<!u8i>, !cir.bool
// CHECK: cir.ternary(%[[cmp_res]], true {
// CHECK: cir.yield %[[last]] : !cir.ptr<!u8i>
// CHECK: }, false {
// CHECK: cir.yield %[[memchr_res]] : !cir.ptr<!u8i>
// CHECK: }) : (!cir.bool) -> !cir.ptr<!u8i>
}
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,

void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
auto fargs = callee ? args : args.drop_front();
build(builder, state, results,
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, fargs)),
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
Expand Down

0 comments on commit 15d33f1

Please sign in to comment.