Skip to content

Commit

Permalink
[CIR][CodeGen][Lowering] Supports arrays with trailing zeros (llvm#393)
Browse files Browse the repository at this point in the history
This PR adds support for constant arrays with trailing zeros.

The original `CodeGen` does the following: once a constant array contain
trailing zeros, a struct with two members is generated: initialized
elements and `zeroinitializer` for the remaining part. And depending on
some conditions, `memset` or `memcpy` are emitted. In the latter case a
global const array is created.
Well, we may go this way, but it requires us to implement
[features](https://github.com/llvm/clangir/blob/main/clang/lib/CIR/CodeGen/CIRGenDecl.cpp#L182)
that are not implemented yet.

Another option is to add one more parameter to the `constArrayAttr` and
utilize it during the lowering. So far I chose this way, but if you have
any doubts, we can discuss here. So we just emit constant array as
usually and once there are trailing zeros, lower this arrray (i.e. an
attribute) as a value.

I added a couple of tests and will add more, once we agree on the
approach. So far I marked the PR as a draft one.
  • Loading branch information
gitoleg authored and lanza committed Oct 12, 2024
1 parent e98cdfb commit a2d87e1
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 10 deletions.
16 changes: 14 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,22 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"Attribute":$elts);
"Attribute":$elts,
"int":$trailingZerosNum);

// Define a custom builder for the type; that removes the need to pass
// in an MLIRContext instance, as it can be infered from the `type`.
let builders = [
AttrBuilderWithInferredContext<(ins "mlir::cir::ArrayType":$type,
"Attribute":$elts), [{
return $_get(type.getContext(), type, elts);
int zeros = 0;
auto typeSize = type.cast<mlir::cir::ArrayType>().getSize();
if (auto str = elts.dyn_cast<mlir::StringAttr>())
zeros = typeSize - str.size();
else
zeros = typeSize - elts.cast<mlir::ArrayAttr>().size();

return $_get(type.getContext(), type, elts, zeros);
}]>
];

Expand All @@ -132,6 +140,10 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>

// Enable verifier.
let genVerifyDecl = 1;

let extraClassDeclaration = [{
bool hasTrailingZeros() const { return getTrailingZerosNum() != 0; };
}];
}

//===----------------------------------------------------------------------===//
Expand Down
10 changes: 9 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,10 +959,18 @@ buildArrayConstant(CIRGenModule &CGM, mlir::Type DesiredType,
// Add a zeroinitializer array filler if we have lots of trailing zeroes.
unsigned TrailingZeroes = ArrayBound - NonzeroLength;
if (TrailingZeroes >= 8) {
assert(0 && "NYE");
assert(Elements.size() >= NonzeroLength &&
"missing initializer for non-zero element");

SmallVector<mlir::Attribute, 4> Eles;
Eles.reserve(Elements.size());
for (auto const &Element : Elements)
Eles.push_back(Element);

return builder.getConstArray(
mlir::ArrayAttr::get(builder.getContext(), Eles),
mlir::cir::ArrayType::get(builder.getContext(), CommonElementType,
ArrayBound));
// TODO(cir): If all the elements had the same type up to the trailing
// zeroes, emit a struct of two arrays (the nonzero data and the
// zeroinitializer). Use DesiredType to get the element type.
Expand Down
25 changes: 21 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2295,7 +2295,7 @@ mlir::OpTrait::impl::verifySameFirstSecondOperandAndResultType(Operation *op) {

LogicalResult mlir::cir::ConstArrayAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::mlir::Type type, Attribute attr) {
::mlir::Type type, Attribute attr, int trailingZerosNum) {

if (!(attr.isa<mlir::ArrayAttr>() || attr.isa<mlir::StringAttr>()))
return emitError() << "constant array expects ArrayAttr or StringAttr";
Expand All @@ -2318,7 +2318,7 @@ LogicalResult mlir::cir::ConstArrayAttr::verify(
auto at = type.cast<ArrayType>();

// Make sure both number of elements and subelement types match type.
if (at.getSize() != arrayAttr.size())
if (at.getSize() != arrayAttr.size() + trailingZerosNum)
return emitError() << "constant array size should match type size";
LogicalResult eltTypeCheck = success();
arrayAttr.walkImmediateSubElements(
Expand Down Expand Up @@ -2383,16 +2383,33 @@ ::mlir::Attribute ConstArrayAttr::parse(::mlir::AsmParser &parser,
}
}

auto zeros = 0;
if (parser.parseOptionalComma().succeeded()) {
if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
auto typeSize = resultTy.value().cast<mlir::cir::ArrayType>().getSize();
auto elts = resultVal.value();
if (auto str = elts.dyn_cast<mlir::StringAttr>())
zeros = typeSize - str.size();
else
zeros = typeSize - elts.cast<mlir::ArrayAttr>().size();
} else {
return {};
}
}

// Parse literal '>'
if (parser.parseGreater())
return {};
return parser.getChecked<ConstArrayAttr>(loc, parser.getContext(),
resultTy.value(), resultVal.value());

return parser.getChecked<ConstArrayAttr>(
loc, parser.getContext(), resultTy.value(), resultVal.value(), zeros);
}

void ConstArrayAttr::print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printStrippedAttrOrType(getElts());
if (auto zeros = getTrailingZerosNum())
printer << ", trailing_zeros";
printer << ">";
}

Expand Down
28 changes: 25 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,15 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::TypeConverter *converter) {
auto llvmTy = converter->convertType(constArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
mlir::Value result;

if (auto zeros = constArr.getTrailingZerosNum()) {
auto arrayTy = constArr.getType();
result = rewriter.create<mlir::cir::ZeroInitConstOp>(
loc, converter->convertType(arrayTy));
} else {
result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
}

// Iteratively lower each constant element of the array.
if (auto arrayAttr = constArr.getElts().dyn_cast<mlir::ArrayAttr>()) {
Expand Down Expand Up @@ -1069,6 +1077,15 @@ lowerConstArrayAttr(mlir::cir::ConstArrayAttr constArr,
return std::nullopt;
}

bool hasTrailingZeros(mlir::cir::ConstArrayAttr attr) {
auto array = attr.getElts().dyn_cast<mlir::ArrayAttr>();
return attr.hasTrailingZeros() ||
(array && std::count_if(array.begin(), array.end(), [](auto elt) {
auto ar = dyn_cast<mlir::cir::ConstArrayAttr>(elt);
return ar && hasTrailingZeros(ar);
}));
}

class CIRConstantLowering
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
public:
Expand Down Expand Up @@ -1120,8 +1137,13 @@ class CIRConstantLowering
return op.emitError() << "array does not have a constant initializer";

std::optional<mlir::Attribute> denseAttr;
if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
if (constArr && hasTrailingZeros(constArr)) {
auto newOp =
lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
rewriter.replaceOp(op, newOp);
return mlir::success();
} else if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
attr = denseAttr.value();
} else {
auto initVal =
Expand Down
10 changes: 10 additions & 0 deletions clang/test/CIR/CodeGen/const-array.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-cir %s -o - | FileCheck %s

void foo() {
int a[10] = {1};
}

// CHECK: cir.func {{.*@foo}}
// CHECK: %0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
// CHECK: %1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
// CHECK: cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
13 changes: 13 additions & 0 deletions clang/test/CIR/Lowering/const.cir
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,17 @@ module {
// CHECK: llvm.store %8, %1 : !llvm.array<1 x struct<"struct.anon.1", (i32, i32)>>, !llvm.ptr
// CHECK: llvm.return

cir.func @testArrWithTrailingZeros() {
%0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
%1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
cir.return
}
// CHECK: llvm.func @testArrWithTrailingZeros()
// CHECK: %0 = llvm.mlir.constant(1 : index) : i64
// CHECK: %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 16 : i64} : (i64) -> !llvm.ptr
// CHECK: %2 = cir.llvmir.zeroinit : !llvm.array<10 x i32>
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %4 = llvm.insertvalue %3, %2[0] : !llvm.array<10 x i32>

}

0 comments on commit a2d87e1

Please sign in to comment.