Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CodeGen][Lowering] Supports arrays with trailing zeros #393

Merged
merged 5 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,16 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"Attribute":$elts);
"Attribute":$elts,
"bool":$hasTrailingZeros);

// Define a custom builder for the type; that removes the need to pass
// in an MLIRContext instance, as it can be infered from the `type`.
let builders = [
AttrBuilderWithInferredContext<(ins "mlir::cir::ArrayType":$type,
"Attribute":$elts), [{
return $_get(type.getContext(), type, elts);
"Attribute":$elts,
CArg<"bool", "false">:$hasTrailingZeros), [{
return $_get(type.getContext(), type, elts, hasTrailingZeros);
}]>
];

Expand All @@ -132,6 +134,19 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>

// Enable verifier.
let genVerifyDecl = 1;

let extraClassDeclaration = [{
bool hasTrailingZeros() const { return getTrailingZerosNum() != 0; };

int getTrailingZerosNum() const {
auto typeSize = getType().cast<mlir::cir::ArrayType>().getSize();
auto elts = getElts();
if (auto str = elts.dyn_cast<mlir::StringAttr>())
return typeSize - str.size();
else
return typeSize - getElts().cast<mlir::ArrayAttr>().size();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the trailing number of zeros is a constant, can we compute this during attr creation and only return a value here instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point!

}
}];
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

mlir::cir::ConstArrayAttr getConstArray(mlir::Attribute attrs,
mlir::cir::ArrayType arrayTy) {
return mlir::cir::ConstArrayAttr::get(arrayTy, attrs);
mlir::cir::ArrayType arrayTy,
bool hasTrailingZeros = false) {
return mlir::cir::ConstArrayAttr::get(arrayTy, attrs, hasTrailingZeros);
}

mlir::Attribute getConstStructOrZeroAttr(mlir::ArrayAttr arrayAttr,
Expand Down
11 changes: 10 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,19 @@ buildArrayConstant(CIRGenModule &CGM, mlir::Type DesiredType,
// Add a zeroinitializer array filler if we have lots of trailing zeroes.
unsigned TrailingZeroes = ArrayBound - NonzeroLength;
if (TrailingZeroes >= 8) {
assert(0 && "NYE");
assert(Elements.size() >= NonzeroLength &&
"missing initializer for non-zero element");

SmallVector<mlir::Attribute, 4> Eles;
Eles.reserve(Elements.size());
for (auto const &Element : Elements)
Eles.push_back(Element);

return builder.getConstArray(
mlir::ArrayAttr::get(builder.getContext(), Eles),
mlir::cir::ArrayType::get(builder.getContext(), CommonElementType,
ArrayBound),
TrailingZeroes);
// TODO(cir): If all the elements had the same type up to the trailing
// zeroes, emit a struct of two arrays (the nonzero data and the
// zeroinitializer). Use DesiredType to get the element type.
Expand Down
21 changes: 17 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,7 @@ mlir::OpTrait::impl::verifySameFirstSecondOperandAndResultType(Operation *op) {

LogicalResult mlir::cir::ConstArrayAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::mlir::Type type, Attribute attr) {
::mlir::Type type, Attribute attr, bool hasTrailingZeros) {

if (!(attr.isa<mlir::ArrayAttr>() || attr.isa<mlir::StringAttr>()))
return emitError() << "constant array expects ArrayAttr or StringAttr";
Expand All @@ -2305,7 +2305,9 @@ LogicalResult mlir::cir::ConstArrayAttr::verify(
auto at = type.cast<ArrayType>();

// Make sure both number of elements and subelement types match type.
if (at.getSize() != arrayAttr.size())
auto trailingZeros = at.getSize() - arrayAttr.size();
if ((!hasTrailingZeros && trailingZeros) ||
(hasTrailingZeros && !trailingZeros))
return emitError() << "constant array size should match type size";
LogicalResult eltTypeCheck = success();
arrayAttr.walkImmediateSubElements(
Expand Down Expand Up @@ -2370,16 +2372,27 @@ ::mlir::Attribute ConstArrayAttr::parse(::mlir::AsmParser &parser,
}
}

bool hasZeros = false;
if (parser.parseOptionalComma().succeeded()) {
if (parser.parseOptionalKeyword("trailingZeros").succeeded())
hasZeros = true;
else
return {};
}

// Parse literal '>'
if (parser.parseGreater())
return {};
return parser.getChecked<ConstArrayAttr>(loc, parser.getContext(),
resultTy.value(), resultVal.value());

return parser.getChecked<ConstArrayAttr>(
loc, parser.getContext(), resultTy.value(), resultVal.value(), hasZeros);
}

void ConstArrayAttr::print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printStrippedAttrOrType(getElts());
if (auto zeros = getTrailingZerosNum())
printer << ", trailingZeros";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For printed CIR I suggest we use trailing_zeros or trailingzeros.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

printer << ">";
}

Expand Down
28 changes: 25 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,15 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
const mlir::TypeConverter *converter) {
auto llvmTy = converter->convertType(constArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
mlir::Value result;

if (auto zeros = constArr.getTrailingZerosNum()) {
auto arrayTy = constArr.getType();
result = rewriter.create<mlir::cir::ZeroInitConstOp>(
loc, converter->convertType(arrayTy));
} else {
result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
}

// Iteratively lower each constant element of the array.
if (auto arrayAttr = constArr.getElts().dyn_cast<mlir::ArrayAttr>()) {
Expand Down Expand Up @@ -983,6 +991,15 @@ lowerConstArrayAttr(mlir::cir::ConstArrayAttr constArr,
return std::nullopt;
}

bool hasTrailingZeros(mlir::cir::ConstArrayAttr attr) {
auto array = attr.getElts().dyn_cast<mlir::ArrayAttr>();
return attr.hasTrailingZeros() ||
(array && std::count_if(array.begin(), array.end(), [](auto elt) {
auto ar = dyn_cast<mlir::cir::ConstArrayAttr>(elt);
return ar && hasTrailingZeros(ar);
}));
}

class CIRConstantLowering
: public mlir::OpConversionPattern<mlir::cir::ConstantOp> {
public:
Expand Down Expand Up @@ -1034,8 +1051,13 @@ class CIRConstantLowering
return op.emitError() << "array does not have a constant initializer";

std::optional<mlir::Attribute> denseAttr;
if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
if (constArr && hasTrailingZeros(constArr)) {
auto newOp =
lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
rewriter.replaceOp(op, newOp);
return mlir::success();
} else if (constArr &&
(denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
attr = denseAttr.value();
} else {
auto initVal =
Expand Down
10 changes: 10 additions & 0 deletions clang/test/CIR/CodeGen/const-array.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-cir %s -o - | FileCheck %s

void foo() {
int a[10] = {1};
}

// CHECK: cir.func {{.*@foo}}
// CHECK: %0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
// CHECK: %1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailingZeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
// CHECK: cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
13 changes: 13 additions & 0 deletions clang/test/CIR/Lowering/const.cir
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,17 @@ module {
// CHECK: llvm.store %8, %1 : !llvm.array<1 x struct<"struct.anon.1", (i32, i32)>>, !llvm.ptr
// CHECK: llvm.return

cir.func @testArrWithTrailingZeros() {
%0 = cir.alloca !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
%1 = cir.const(#cir.const_array<[#cir.int<1> : !s32i], trailingZeros> : !cir.array<!s32i x 10>) : !cir.array<!s32i x 10>
cir.store %1, %0 : !cir.array<!s32i x 10>, cir.ptr <!cir.array<!s32i x 10>>
cir.return
}
// CHECK: llvm.func @testArrWithTrailingZeros()
// CHECK: %0 = llvm.mlir.constant(1 : index) : i64
// CHECK: %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 16 : i64} : (i64) -> !llvm.ptr
// CHECK: %2 = cir.llvmir.zeroinit : !llvm.array<10 x i32>
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %4 = llvm.insertvalue %3, %2[0] : !llvm.array<10 x i32>

}
Loading