Skip to content

Commit cc1c7d1

Browse files
seven-mile and lanza
authored and committed
[CIR][CIRGen] Support OpenCL Vector Types (llvm#613)
Resolves llvm#532. Supports CIRGen of `ExtVectorElementExpr`, covering swizzles such as `v.xyx` and component subscripts such as `v.s0`.
1 parent 4f7f17e commit cc1c7d1

File tree

6 files changed

+714
-2
lines changed

6 files changed

+714
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+28
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,34 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
802802
return CIRBaseBuilderTy::createStore(loc, flag, dst);
803803
}
804804

805+
/// Create a `cir.vec.shuffle` of `vec1` and `vec2` whose lanes are selected by
/// `maskAttrs`. The result vector keeps the element type of `vec1` and has
/// one element per mask entry.
mlir::cir::VecShuffleOp
createVecShuffle(mlir::Location loc, mlir::Value vec1, mlir::Value vec2,
                 llvm::ArrayRef<mlir::Attribute> maskAttrs) {
  auto inputVecTy = mlir::cast<mlir::cir::VectorType>(vec1.getType());
  // The result length follows the mask, not the inputs.
  auto shuffledTy = mlir::cir::VectorType::get(
      getContext(), inputVecTy.getEltType(), maskAttrs.size());
  auto maskArray = getArrayAttr(maskAttrs);
  return CIRBaseBuilderTy::create<mlir::cir::VecShuffleOp>(
      loc, shuffledTy, vec1, vec2, maskArray);
}
814+
815+
/// Convenience overload that takes the shuffle mask as plain integers. Every
/// entry is turned into a signed 32-bit CIR integer attribute and the call is
/// forwarded to the attribute-based overload.
mlir::cir::VecShuffleOp createVecShuffle(mlir::Location loc, mlir::Value vec1,
                                         mlir::Value vec2,
                                         llvm::ArrayRef<int64_t> mask) {
  llvm::SmallVector<mlir::Attribute, 4> maskAttrs;
  maskAttrs.reserve(mask.size());

  auto maskEltTy = getSInt32Ty();
  for (int64_t laneIdx : mask) {
    // Mask entries are narrowed to 32 bits to match the attribute type.
    maskAttrs.push_back(
        mlir::cir::IntAttr::get(maskEltTy, static_cast<int32_t>(laneIdx)));
  }

  return createVecShuffle(loc, vec1, vec2, maskAttrs);
}
825+
826+
/// Shuffle the lanes of a single vector according to `mask`.
mlir::cir::VecShuffleOp createVecShuffle(mlir::Location loc, mlir::Value vec1,
                                         llvm::ArrayRef<int64_t> mask) {
  // FIXME(cir): Support using cir.vec.shuffle with a single vector operand.
  // Workaround until then: the same vector is supplied as both operands.
  mlir::Value secondOperand = vec1;
  return createVecShuffle(loc, vec1, secondOperand, mask);
}
832+
805833
mlir::cir::StoreOp
806834
createAlignedStore(mlir::Location loc, mlir::Value val, mlir::Value dst,
807835
clang::CharUnits align = clang::CharUnits::One(),

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

+191
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,55 @@ RValue CIRGenFunction::buildLoadOfLValue(LValue LV, SourceLocation Loc) {
653653
getLoc(Loc), load, LV.getVectorIdx()));
654654
}
655655

656+
if (LV.isExtVectorElt()) {
657+
return buildLoadOfExtVectorElementLValue(LV);
658+
}
659+
656660
llvm_unreachable("NYI");
657661
}
658662

663+
/// Return the integer stored at position `idx` of the ext-vector element
/// access list `elts`. Every entry of the list must be an integer attribute.
int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
                                           const mlir::ArrayAttr elts) {
  mlir::Attribute entry = elts[idx];
  auto indexAttr = mlir::dyn_cast<mlir::IntegerAttr>(entry);
  assert(indexAttr && "The indices should be integer attributes");
  return indexAttr.getInt();
}
669+
670+
// If this is a reference to a subset of the elements of a vector, create an
671+
// appropriate shufflevector.
672+
RValue CIRGenFunction::buildLoadOfExtVectorElementLValue(LValue LV) {
673+
mlir::Location loc = LV.getExtVectorPointer().getLoc();
674+
mlir::Value Vec = builder.createLoad(loc, LV.getExtVectorAddress());
675+
676+
// HLSL allows treating scalars as one-element vectors. Converting the scalar
677+
// IR value to a vector here allows the rest of codegen to behave as normal.
678+
if (getLangOpts().HLSL && !mlir::isa<mlir::cir::VectorType>(Vec.getType())) {
679+
llvm_unreachable("HLSL NYI");
680+
}
681+
682+
const mlir::ArrayAttr Elts = LV.getExtVectorElts();
683+
684+
// If the result of the expression is a non-vector type, we must be extracting
685+
// a single element. Just codegen as an extractelement.
686+
const auto *ExprVT = LV.getType()->getAs<clang::VectorType>();
687+
if (!ExprVT) {
688+
int64_t InIdx = getAccessedFieldNo(0, Elts);
689+
mlir::cir::ConstantOp Elt =
690+
builder.getConstInt(loc, builder.getSInt64Ty(), InIdx);
691+
return RValue::get(builder.create<mlir::cir::VecExtractOp>(loc, Vec, Elt));
692+
}
693+
694+
// Always use shuffle vector to try to retain the original program structure
695+
unsigned NumResultElts = ExprVT->getNumElements();
696+
697+
SmallVector<int64_t, 4> Mask;
698+
for (unsigned i = 0; i != NumResultElts; ++i)
699+
Mask.push_back(getAccessedFieldNo(i, Elts));
700+
701+
Vec = builder.createVecShuffle(loc, Vec, Mask);
702+
return RValue::get(Vec);
703+
}
704+
659705
RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
660706
SourceLocation Loc) {
661707
const CIRGenBitFieldInfo &info = LV.getBitFieldInfo();
@@ -674,6 +720,80 @@ RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
674720
return RValue::get(field);
675721
}
676722

723+
/// Store `Src` through the ext-vector component l-value `Dst`.
///
/// The access is emitted as a read/modify/write of the whole destination
/// vector: the vector is loaded, the selected elements are replaced, and the
/// updated vector is stored back (honoring the volatile qualifier of `Dst`).
///
/// \param Src value to store; a vector when `Dst` has vector type, otherwise
///            a single scalar element.
/// \param Dst ext-vector element l-value naming the elements to overwrite.
void CIRGenFunction::buildStoreThroughExtVectorComponentLValue(RValue Src,
                                                               LValue Dst) {
  mlir::Location loc = Dst.getExtVectorPointer().getLoc();

  // HLSL allows storing to scalar values through ExtVector component LValues.
  // Supporting that requires handling a scalar destination address, which is
  // not implemented yet.
  Address DstAddr = Dst.getExtVectorAddress();
  if (!mlir::isa<mlir::cir::VectorType>(DstAddr.getElementType())) {
    llvm_unreachable("HLSL NYI");
  }

  // This access turns into a read/modify/write of the vector. Load the
  // current value now.
  mlir::Value Vec = builder.createLoad(loc, DstAddr);
  const mlir::ArrayAttr Elts = Dst.getExtVectorElts();

  mlir::Value SrcVal = Src.getScalarVal();

  if (const clang::VectorType *VTy =
          Dst.getType()->getAs<clang::VectorType>()) {
    unsigned NumSrcElts = VTy->getNumElements();
    unsigned NumDstElts = cast<mlir::cir::VectorType>(Vec.getType()).getSize();
    if (NumDstElts == NumSrcElts) {
      // Use a shuffle if the source and destination have the same number of
      // elements, and invert the access mask since it describes the
      // destination side of the store.
      SmallVector<int64_t, 4> Mask(NumDstElts);
      for (unsigned i = 0; i != NumSrcElts; ++i)
        Mask[getAccessedFieldNo(i, Elts)] = i;

      Vec = builder.createVecShuffle(loc, SrcVal, Mask);
    } else if (NumDstElts > NumSrcElts) {
      // Extend the source vector to the destination length and then shuffle
      // it into the destination.
      // FIXME: since we're shuffling with undef, can we just use the indices
      //        into that? This could be simpler.
      SmallVector<int64_t, 4> ExtMask;
      for (unsigned i = 0; i != NumSrcElts; ++i)
        ExtMask.push_back(i);
      ExtMask.resize(NumDstElts, -1);
      mlir::Value ExtSrcVal = builder.createVecShuffle(loc, SrcVal, ExtMask);

      // Start from the identity mask, which keeps every destination element.
      SmallVector<int64_t, 4> Mask;
      for (unsigned i = 0; i != NumDstElts; ++i)
        Mask.push_back(i);

      // When the vector size is odd and .odd or .hi is used, the last element
      // of the Elts constant array will be one past the size of the vector.
      // Dropping that last element is not implemented yet. The cast keeps the
      // comparison between two signed 64-bit values.
      if (getAccessedFieldNo(NumSrcElts - 1, Elts) ==
          static_cast<int64_t>(Mask.size()))
        llvm_unreachable("NYI");

      // Redirect the written positions to the extended source, whose elements
      // are numbered after those of the first shuffle operand.
      for (unsigned i = 0; i != NumSrcElts; ++i)
        Mask[getAccessedFieldNo(i, Elts)] = i + NumDstElts;
      Vec = builder.createVecShuffle(loc, Vec, ExtSrcVal, Mask);
    } else {
      // We should never shorten the vector.
      llvm_unreachable("unexpected shorten vector length");
    }
  } else {
    // If the Src is a scalar (not a vector), and the target is a vector, it
    // must be updating one element.
    unsigned InIdx = getAccessedFieldNo(0, Elts);
    auto Elt = builder.getSInt64(InIdx, loc);

    Vec = builder.create<mlir::cir::VecInsertOp>(loc, Vec, SrcVal, Elt);
  }

  // Write the updated vector back to the address it was loaded from.
  builder.createStore(loc, Vec, DstAddr, Dst.isVolatileQualified());
}
796+
677797
void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst,
678798
bool isInit) {
679799
if (!Dst.isSimple()) {
@@ -686,6 +806,10 @@ void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst,
686806
builder.createStore(loc, Vector, Dst.getVectorAddress());
687807
return;
688808
}
809+
810+
if (Dst.isExtVectorElt())
811+
return buildStoreThroughExtVectorComponentLValue(Src, Dst);
812+
689813
assert(Dst.isBitField() && "NIY LValue type");
690814
mlir::Value result;
691815
return buildStoreThroughBitfieldLValue(Src, Dst, result);
@@ -979,6 +1103,71 @@ CIRGenFunction::buildPointerToDataMemberBinaryExpr(const BinaryOperator *E) {
9791103
return makeAddrLValue(memberAddr, memberPtrTy->getPointeeType(), baseInfo);
9801104
}
9811105

1106+
/// Emit an l-value for an ext-vector element access such as `v.xyx` or `v.s0`.
/// The base vector is first materialized as an l-value; the accessed element
/// indices are then attached to it as an integer array attribute.
LValue
CIRGenFunction::buildExtVectorElementExpr(const ExtVectorElementExpr *E) {
  const Expr *baseExpr = E->getBase();

  // The base is either a pointer to a vector, a vector gl-value, or a vector
  // r-value; each form is turned into an l-value differently.
  LValue Base;
  if (E->isArrow()) {
    // Pointer to vector: emit the pointer and form an l-value for the pointee.
    LValueBaseInfo BaseInfo;
    // TODO(cir): Support TBAA
    assert(!MissingFeatures::tbaa());
    Address Ptr = buildPointerWithAlignment(baseExpr, &BaseInfo);
    const auto *PT = baseExpr->getType()->castAs<clang::PointerType>();
    Base = makeAddrLValue(Ptr, PT->getPointeeType(), BaseInfo);
    Base.getQuals().removeObjCGCAttr();
  } else if (baseExpr->isGLValue()) {
    // The base is itself an l-value (as in foo.x.x): emit it as one.
    assert(baseExpr->getType()->isVectorType());
    Base = buildLValue(baseExpr);
  } else {
    // The base is an ordinary r-value (as in (V+V).x).
    assert(baseExpr->getType()->isVectorType() &&
           "Result must be a vector");
    mlir::Value Vec = buildScalarExpr(baseExpr);

    // An LValue needs an address, so spill the vector to a temporary.
    QualType BaseTy = baseExpr->getType();
    Address VecMem = CreateMemTemp(BaseTy, Vec.getLoc(), "tmp");
    builder.createStore(Vec.getLoc(), Vec, VecMem);
    Base = makeAddrLValue(VecMem, BaseTy, AlignmentSource::Decl);
  }

  // The result carries the CVR qualifiers of the base.
  QualType type =
      E->getType().withCVRQualifiers(Base.getQuals().getCVRQualifiers());

  // Decode the element access list into unsigned indices.
  SmallVector<uint32_t, 4> accessIndices;
  E->getEncodedElementAccess(accessIndices);

  if (!Base.isSimple()) {
    assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");

    // The base is already an ext-vector element l-value: compose this access
    // with the base's access list so the indices refer to the underlying
    // vector.
    mlir::ArrayAttr baseElts = Base.getExtVectorElts();
    SmallVector<int64_t, 4> composedElts;
    for (uint32_t idx : accessIndices)
      composedElts.push_back(getAccessedFieldNo(idx, baseElts));

    return LValue::MakeExtVectorElt(Base.getExtVectorAddress(),
                                    builder.getI64ArrayAttr(composedElts), type,
                                    Base.getBaseInfo());
  }

  // Simple base: the decoded indices are used directly, widened to 64 bits.
  SmallVector<int64_t, 4> directElts(accessIndices.begin(),
                                     accessIndices.end());
  return LValue::MakeExtVectorElt(Base.getAddress(),
                                  builder.getI64ArrayAttr(directElts), type,
                                  Base.getBaseInfo());
}
1170+
9821171
LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
9831172
// Comma expressions just emit their LHS then their RHS as an l-value.
9841173
if (E->getOpcode() == BO_Comma) {
@@ -2263,6 +2452,8 @@ LValue CIRGenFunction::buildLValue(const Expr *E) {
22632452
return buildConditionalOperatorLValue(cast<ConditionalOperator>(E));
22642453
case Expr::ArraySubscriptExprClass:
22652454
return buildArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
2455+
case Expr::ExtVectorElementExprClass:
2456+
return buildExtVectorElementExpr(cast<ExtVectorElementExpr>(E));
22662457
case Expr::BinaryOperatorClass:
22672458
return buildBinaryOperatorLValue(cast<BinaryOperator>(E));
22682459
case Expr::CompoundAssignOperatorClass: {

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
328328
E->getSrcExpr()->getType(), E->getType(),
329329
E->getSourceRange().getBegin());
330330
}
331+
332+
/// Ext-vector element accesses are emitted by loading through the l-value of
/// the expression.
mlir::Value VisitExtVectorElementExpr(Expr *E) {
  return buildLoadOfLValue(E);
}
335+
331336
mlir::Value VisitMemberExpr(MemberExpr *E);
332-
mlir::Value VisitExtVectorelementExpr(Expr *E) { llvm_unreachable("NYI"); }
333337
mlir::Value VisitCompoundLiteralEpxr(CompoundLiteralExpr *E) {
334338
llvm_unreachable("NYI");
335339
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

+7
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,12 @@ class CIRGenFunction : public CIRGenTypeCache {
763763
mlir::Location Loc, LValueBaseInfo BaseInfo,
764764
bool isNontemporal = false);
765765

766+
int64_t getAccessedFieldNo(unsigned idx, const mlir::ArrayAttr elts);
767+
768+
RValue buildLoadOfExtVectorElementLValue(LValue LV);
769+
770+
void buildStoreThroughExtVectorComponentLValue(RValue Src, LValue Dst);
771+
766772
RValue buildLoadOfBitfieldLValue(LValue LV, SourceLocation Loc);
767773

768774
/// Load a scalar value from an address, taking care to appropriately convert
@@ -1223,6 +1229,7 @@ class CIRGenFunction : public CIRGenTypeCache {
12231229
LValue lvalue, bool capturedByInit = false);
12241230

12251231
LValue buildDeclRefLValue(const clang::DeclRefExpr *E);
1232+
LValue buildExtVectorElementExpr(const ExtVectorElementExpr *E);
12261233
LValue buildBinaryOperatorLValue(const clang::BinaryOperator *E);
12271234
LValue buildCompoundAssignmentLValue(const clang::CompoundAssignOperator *E);
12281235
LValue buildUnaryOpLValue(const clang::UnaryOperator *E);

clang/lib/CIR/CodeGen/CIRGenValue.h

+29-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ class LValue {
207207
unsigned Alignment;
208208
mlir::Value V;
209209
mlir::Type ElementType;
210-
mlir::Value VectorIdx; // Index for vector subscript
210+
mlir::Value VectorIdx; // Index for vector subscript
211+
mlir::Attribute VectorElts; // ExtVector element subset: V.xyx
211212
LValueBaseInfo BaseInfo;
212213
const CIRGenBitFieldInfo *BitFieldInfo{0};
213214

@@ -316,6 +317,20 @@ class LValue {
316317
return VectorIdx;
317318
}
318319

320+
// extended vector elements.
321+
Address getExtVectorAddress() const {
322+
assert(isExtVectorElt());
323+
return Address(getExtVectorPointer(), ElementType, getAlignment());
324+
}
325+
mlir::Value getExtVectorPointer() const {
326+
assert(isExtVectorElt());
327+
return V;
328+
}
329+
mlir::ArrayAttr getExtVectorElts() const {
330+
assert(isExtVectorElt());
331+
return mlir::cast<mlir::ArrayAttr>(VectorElts);
332+
}
333+
319334
static LValue MakeVectorElt(Address vecAddress, mlir::Value Index,
320335
clang::QualType type, LValueBaseInfo BaseInfo) {
321336
LValue R;
@@ -328,6 +343,19 @@ class LValue {
328343
return R;
329344
}
330345

346+
/// Build an l-value of the ExtVectorElt kind that designates the elements
/// `elts` of the vector stored at `vecAddress`.
static LValue MakeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
                               clang::QualType type,
                               LValueBaseInfo baseInfo) {
  LValue extVecLV;
  extVecLV.LVType = ExtVectorElt;
  // Record which elements are accessed, then where the vector lives.
  extVecLV.VectorElts = elts;
  extVecLV.ElementType = vecAddress.getElementType();
  extVecLV.V = vecAddress.getPointer();
  extVecLV.Initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
                      baseInfo);
  return extVecLV;
}
358+
331359
// bitfield lvalue
332360
Address getBitFieldAddress() const {
333361
return Address(getBitFieldPointer(), ElementType, getAlignment());

0 commit comments

Comments
 (0)