Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CodeGen][Lowering] Supports arrays with trailing zeros #393

Merged
merged 5 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,22 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"Attribute":$elts);
"Attribute":$elts,
"int":$trailingZerosNum);

// Define a custom builder for the type; that removes the need to pass
// in an MLIRContext instance, as it can be infered from the `type`.
let builders = [
AttrBuilderWithInferredContext<(ins "mlir::cir::ArrayType":$type,
"Attribute":$elts), [{
return $_get(type.getContext(), type, elts);
int zeros = 0;
auto typeSize = type.cast<mlir::cir::ArrayType>().getSize();
if (auto str = elts.dyn_cast<mlir::StringAttr>())
zeros = typeSize - str.size();
else
zeros = typeSize - elts.cast<mlir::ArrayAttr>().size();

return $_get(type.getContext(), type, elts, zeros);
}]>
];

Expand All @@ -132,6 +140,10 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>

// Enable verifier.
let genVerifyDecl = 1;

let extraClassDeclaration = [{
bool hasTrailingZeros() const { return getTrailingZerosNum() != 0; };
}];
}

//===----------------------------------------------------------------------===//
Expand Down
10 changes: 9 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,18 @@ buildArrayConstant(CIRGenModule &CGM, mlir::Type DesiredType,
// Add a zeroinitializer array filler if we have lots of trailing zeroes.
unsigned TrailingZeroes = ArrayBound - NonzeroLength;
if (TrailingZeroes >= 8) {
assert(0 && "NYE");
assert(Elements.size() >= NonzeroLength &&
"missing initializer for non-zero element");

SmallVector<mlir::Attribute, 4> Eles;
Eles.reserve(Elements.size());
for (auto const &Element : Elements)
Eles.push_back(Element);

return builder.getConstArray(
mlir::ArrayAttr::get(builder.getContext(), Eles),
mlir::cir::ArrayType::get(builder.getContext(), CommonElementType,
ArrayBound));
// TODO(cir): If all the elements had the same type up to the trailing
// zeroes, emit a struct of two arrays (the nonzero data and the
// zeroinitializer). Use DesiredType to get the element type.
Expand Down
25 changes: 21 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,7 @@ mlir::OpTrait::impl::verifySameFirstSecondOperandAndResultType(Operation *op) {

LogicalResult mlir::cir::ConstArrayAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::mlir::Type type, Attribute attr) {
::mlir::Type type, Attribute attr, int trailingZerosNum) {

if (!(attr.isa<mlir::ArrayAttr>() || attr.isa<mlir::StringAttr>()))
return emitError() << "constant array expects ArrayAttr or StringAttr";
Expand All @@ -2305,7 +2305,7 @@ LogicalResult mlir::cir::ConstArrayAttr::verify(
auto at = type.cast<ArrayType>();

// Make sure both number of elements and subelement types match type.
if (at.getSize() != arrayAttr.size())
if (at.getSize() != arrayAttr.size() + trailingZerosNum)
return emitError() << "constant array size should match type size";
LogicalResult eltTypeCheck = success();
arrayAttr.walkImmediateSubElements(
Expand Down Expand Up @@ -2370,16 +2370,33 @@ ::mlir::Attribute ConstArrayAttr::parse(::mlir::AsmParser &parser,
}
}

auto zeros = 0;
if (parser.parseOptionalComma().succeeded()) {
if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
auto typeSize = resultTy.value().cast<mlir::cir::ArrayType>().getSize();
auto elts = resultVal.value();
if (auto str = elts.dyn_cast<mlir::StringAttr>())
zeros = typeSize - str.size();
else
zeros = typeSize - elts.cast<mlir::ArrayAttr>().size();
} else {
return {};
}
}

// Parse literal '>'
if (parser.parseGreater())
return {};
return parser.getChecked<ConstArrayAttr>(loc, parser.getContext(),
resultTy.value(), resultVal.value());

return parser.getChecked<ConstArrayAttr>(
loc, parser.getContext(), resultTy.value(), resultVal.value(), zeros);
}

void ConstArrayAttr::print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printStrippedAttrOrType(getElts());
if (auto zeros = getTrailingZerosNum())
printer << ", trailing_zeros";
printer << ">";
}

Expand Down
28 changes: 25 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,15 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::TypeConverter *converter) {
auto llvmTy = converter->convertType(constArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
mlir::Value result;

if (auto zeros = constArr.getTrailingZerosNum()) {
auto arrayTy = constArr.getType();
result = rewriter.create<mlir::cir::ZeroInitConstOp>(
loc, converter->convertType(arrayTy));
} else {
result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
}

// Iteratively lower each constant element of the array.
if (auto arrayAttr = constArr.getElts().dyn_cast<mlir::ArrayAttr>()) {
Expand Down Expand Up @@ -983,6 +991,15 @@ lowerConstArrayAttr(mlir::cir::ConstArrayAttr constArr,
return std::nullopt;
}

bool hasTrailingZeros(mlir::cir::ConstArrayAttr attr) {
auto array = attr.getElts().dyn_cast<mlir::ArrayAttr>();
return attr.hasTrailingZeros() ||
(array && std::count_if(array.begin(), array.end(), [](auto elt) {
auto ar = dyn_cast<mlir::cir::ConstArrayAttr>(elt);
return ar && hasTrailingZeros(ar);
}));
}

class CIRConstantLowering
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
public:
Expand Down Expand Up @@ -1034,8 +1051,13 @@ class CIRConstantLowering
return op.emitError() << "array does not have a constant initializer";

std::optional<mlir::Attribute> denseAttr;
if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
if (constArr && hasTrailingZeros(constArr)) {
auto newOp =
lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
rewriter.replaceOp(op, newOp);
return mlir::success();
} else if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
attr = denseAttr.value();
} else {
auto initVal =
Expand Down
10 changes: 10 additions & 0 deletions clang/test/CIR/CodeGen/const-array.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-cir %s -o - | FileCheck %s

void foo() {
int a[10] = {1};
}

// CHECK: cir.func {{.*@foo}}
// CHECK: %0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
// CHECK: %1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
// CHECK: cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
13 changes: 13 additions & 0 deletions clang/test/CIR/Lowering/const.cir
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,17 @@ module {
// CHECK: llvm.store %8, %1 : !llvm.array<1 x struct<"struct.anon.1", (i32, i32)>>, !llvm.ptr
// CHECK: llvm.return

cir.func @testArrWithTrailingZeros() {
%0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
%1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
cir.return
}
// CHECK: llvm.func @testArrWithTrailingZeros()
// CHECK: %0 = llvm.mlir.constant(1 : index) : i64
// CHECK: %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 16 : i64} : (i64) -> !llvm.ptr
// CHECK: %2 = cir.llvmir.zeroinit : !llvm.array<10 x i32>
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %4 = llvm.insertvalue %3, %2[0] : !llvm.array<10 x i32>

}
Loading