Skip to content

Commit

Permalink
[CIR] Vector types, part 2 (#387)
Browse files Browse the repository at this point in the history
This is part 2 of implementing vector types and vector operations in
ClangIR, issue #284.

Create new operation `cir.vec.insert`, which changes one element of an
existing vector object and returns the modified vector object. The input
and output vectors are prvalues; this operation does not touch memory.
The assembly format and the order of the arguments match that of
llvm.insertelement in the LLVM dialect, since the operations have
identical semantics.

Implement vector element lvalues in class `LValue`, adding member
functions `getVectorAddress()`, `getVectorPointer()`, `getVectorIdx()`,
and `MakeVectorElt(...)`.

The assembly format for operation `cir.vec.extract` was changed to match
that of llvm.extractelement in the LLVM dialect, since the operations
have identical semantics.

These two features, `cir.vec.insert` and vector element lvalues, are
used to implement `v[n] = e`, where `v` is a vector. This is a little
tricky, because `v[n]` isn't really an lvalue, as its address cannot be
taken. The only place it can be used as an lvalue is on the left-hand
side of an assignment.

Implement unary operators on vector objects (except for logical not on a
vector mask, which will be covered in a future commit for boolean
vectors). The code for lowering cir.unary for all types, in
`CIRUnaryOpLowering::matchAndRewrite`, was largely rewritten. Support
for unary `+` on non-vector pointer types was added. (It was already
supported and tested in AST->ClangIR CodeGen, but was missing from
ClangIR->LLVM Dialect lowering.)

Add tests for all binary vector arithmetic operations other than
relational operators and shift operators. There were all working after
the previous vector types commit, but only addition had beet tested at
the time.

Co-authored-by: Bruno Cardoso Lopes <[email protected]>
  • Loading branch information
2 people authored and lanza committed Apr 29, 2024
1 parent dda9575 commit b2b60d6
Show file tree
Hide file tree
Showing 9 changed files with 519 additions and 93 deletions.
33 changes: 29 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1672,14 +1672,39 @@ def GetMemberOp : CIR_Op<"get_member"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecInsertOp
//===----------------------------------------------------------------------===//

def VecInsertOp : CIR_Op<"vec.insert", [Pure,
TypesMatchWith<"argument type matches vector element type", "vec", "value",
"$_self.cast<VectorType>().getEltType()">,
AllTypesMatch<["result", "vec"]>]> {

let summary = "Insert one element into a vector object";
let description = [{
The `cir.vec.insert` operation replaces the element of the given vector at
the given index with the given value. The new vector with the inserted
element is returned.
}];

let arguments = (ins CIR_VectorType:$vec, AnyType:$value, CIR_IntType:$index);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
$value `,` $vec `[` $index `:` type($index) `]` attr-dict `:` type($vec)
}];

let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// VecExtractOp
//===----------------------------------------------------------------------===//

def VecExtractOp : CIR_Op<"vec.extract", [Pure,
TypesMatchWith<"type of 'result' matches element type of 'vec'",
"vec", "result",
"$_self.cast<VectorType>().getEltType()">]> {
TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
"result", "$_self.cast<VectorType>().getEltType()">]> {

let summary = "Extract one element from a vector object";
let description = [{
Expand All @@ -1691,7 +1716,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
$vec `[` $index `:` type($index) `]` type($vec) `->` type($result) attr-dict
$vec `[` $index `:` type($index) `]` attr-dict `:` type($vec)
}];

let hasVerifier = 0;
Expand Down
37 changes: 28 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,16 +545,18 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value Value, Address Addr,
bool Volatile, QualType Ty,
LValueBaseInfo BaseInfo, bool isInit,
bool isNontemporal) {
if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() &&
Ty->castAs<clang::VectorType>()->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vectors");

Value = buildToMemory(Value, Ty);

if (Ty->isAtomicType()) {
llvm_unreachable("NYI");
}

if (const auto *ClangVecTy = Ty->getAs<clang::VectorType>()) {
if (!CGM.getCodeGenOpts().PreserveVec3Type &&
ClangVecTy->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vector store");
}

// Update the alloca with more info on initialization.
assert(Addr.getPointer() && "expected pointer to exist");
auto SrcAlloca =
Expand Down Expand Up @@ -622,6 +624,18 @@ RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
}

void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst) {
if (!Dst.isSimple()) {
if (Dst.isVectorElt()) {
// Read/modify/write the vector, inserting the new element
mlir::Location loc = Dst.getVectorPointer().getLoc();
mlir::Value Vector = builder.createLoad(loc, Dst.getVectorAddress());
Vector = builder.create<mlir::cir::VecInsertOp>(
loc, Vector, Src.getScalarVal(), Dst.getVectorIdx());
builder.createStore(loc, Vector, Dst.getVectorAddress());
return;
}
llvm_unreachable("NYI: non-simple store through lvalue");
}
assert(Dst.isSimple() && "only implemented simple");

// There's special magic for assigning into an ARC-qualified l-value.
Expand Down Expand Up @@ -1387,7 +1401,10 @@ LValue CIRGenFunction::buildArraySubscriptExpr(const ArraySubscriptExpr *E,
// with this subscript.
if (E->getBase()->getType()->isVectorType() &&
!isa<ExtVectorElementExpr>(E->getBase())) {
llvm_unreachable("vector subscript is NYI");
LValue LHS = buildLValue(E->getBase());
auto Index = EmitIdxAfterBase(/*Promote=*/false);
return LValue::MakeVectorElt(LHS.getAddress(), Index,
E->getBase()->getType(), LHS.getBaseInfo());
}

// All the other cases basically behave like simple offsetting.
Expand Down Expand Up @@ -2371,16 +2388,18 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, mlir::Location Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() &&
Ty->castAs<clang::VectorType>()->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vectors");

// Atomic operations have to be done on integral types
LValue AtomicLValue = LValue::makeAddr(Addr, Ty, getContext(), BaseInfo);
if (Ty->isAtomicType() || LValueIsSuitableForInlineAtomic(AtomicLValue)) {
llvm_unreachable("NYI");
}

if (const auto *ClangVecTy = Ty->getAs<clang::VectorType>()) {
if (!CGM.getCodeGenOpts().PreserveVec3Type &&
ClangVecTy->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vector load");
}

mlir::cir::LoadOp Load = builder.create<mlir::cir::LoadOp>(
Loc, Addr.getElementType(), Addr.getPointer());

Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ mlir::Value ScalarExprEmitter::VisitUnaryLNot(const UnaryOperator *E) {
if (dstTy.isa<mlir::cir::BoolType>())
return boolVal;

llvm_unreachable("destination type for negation unary operator is NYI");
llvm_unreachable("destination type for logical-not unary operator is NYI");
}

// Conversion from bool, integral, or floating-point to integral or
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class LValue {
unsigned Alignment;
mlir::Value V;
mlir::Type ElementType;
mlir::Value VectorIdx; // Index for vector subscript
LValueBaseInfo BaseInfo;
const CIRGenBitFieldInfo *BitFieldInfo{0};

Expand Down Expand Up @@ -301,6 +302,31 @@ class LValue {
const clang::Qualifiers &getQuals() const { return Quals; }
clang::Qualifiers &getQuals() { return Quals; }

// vector element lvalue
Address getVectorAddress() const {
return Address(getVectorPointer(), ElementType, getAlignment());
}
mlir::Value getVectorPointer() const {
assert(isVectorElt());
return V;
}
mlir::Value getVectorIdx() const {
assert(isVectorElt());
return VectorIdx;
}

static LValue MakeVectorElt(Address vecAddress, mlir::Value Index,
clang::QualType type, LValueBaseInfo BaseInfo) {
LValue R;
R.LVType = VectorElt;
R.V = vecAddress.getPointer();
R.ElementType = vecAddress.getElementType();
R.VectorIdx = Index;
R.Initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
BaseInfo);
return R;
}

// bitfield lvalue
Address getBitFieldAddress() const {
return Address(getBitFieldPointer(), ElementType, getAlignment());
Expand Down
Loading

0 comments on commit b2b60d6

Please sign in to comment.