Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen] Support OpenCL Vector Types #613

Merged
merged 10 commits into from
Jun 6, 2024
28 changes: 28 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,34 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return CIRBaseBuilderTy::createStore(loc, flag, dst);
}

mlir::cir::VecShuffleOp
createVecShuffle(mlir::Location loc, mlir::Value vec1, mlir::Value vec2,
llvm::ArrayRef<mlir::Attribute> maskAttrs) {
auto vecType = mlir::cast<mlir::cir::VectorType>(vec1.getType());
auto resultTy = mlir::cir::VectorType::get(
getContext(), vecType.getEltType(), maskAttrs.size());
return CIRBaseBuilderTy::create<mlir::cir::VecShuffleOp>(
loc, resultTy, vec1, vec2, getArrayAttr(maskAttrs));
}

mlir::cir::VecShuffleOp createVecShuffle(mlir::Location loc, mlir::Value vec1,
mlir::Value vec2,
llvm::ArrayRef<int64_t> mask) {
llvm::SmallVector<mlir::Attribute, 4> maskAttrs;
for (int32_t idx : mask) {
maskAttrs.push_back(mlir::cir::IntAttr::get(getSInt32Ty(), idx));
}

return createVecShuffle(loc, vec1, vec2, maskAttrs);
}

mlir::cir::VecShuffleOp createVecShuffle(mlir::Location loc, mlir::Value vec1,
llvm::ArrayRef<int64_t> mask) {
// FIXME(cir): Support use cir.vec.shuffle with single vec
// Workaround: pass Vec as both vec1 and vec2
return createVecShuffle(loc, vec1, vec1, mask);
}

mlir::cir::StoreOp
createAlignedStore(mlir::Location loc, mlir::Value val, mlir::Value dst,
clang::CharUnits align = clang::CharUnits::One(),
Expand Down
191 changes: 191 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,55 @@ RValue CIRGenFunction::buildLoadOfLValue(LValue LV, SourceLocation Loc) {
getLoc(Loc), load, LV.getVectorIdx()));
}

if (LV.isExtVectorElt()) {
return buildLoadOfExtVectorElementLValue(LV);
}

llvm_unreachable("NYI");
}

int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
const mlir::ArrayAttr elts) {
auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
assert(elt && "The indices should be integer attributes");
return elt.getInt();
}

// If this is a reference to a subset of the elements of a vector, create an
// appropriate shufflevector.
RValue CIRGenFunction::buildLoadOfExtVectorElementLValue(LValue LV) {
mlir::Location loc = LV.getExtVectorPointer().getLoc();
mlir::Value Vec = builder.createLoad(loc, LV.getExtVectorAddress());

// HLSL allows treating scalars as one-element vectors. Converting the scalar
// IR value to a vector here allows the rest of codegen to behave as normal.
if (getLangOpts().HLSL && !mlir::isa<mlir::cir::VectorType>(Vec.getType())) {
llvm_unreachable("HLSL NYI");
}

const mlir::ArrayAttr Elts = LV.getExtVectorElts();

// If the result of the expression is a non-vector type, we must be extracting
// a single element. Just codegen as an extractelement.
const auto *ExprVT = LV.getType()->getAs<clang::VectorType>();
if (!ExprVT) {
int64_t InIdx = getAccessedFieldNo(0, Elts);
mlir::cir::ConstantOp Elt =
builder.getConstInt(loc, builder.getSInt64Ty(), InIdx);
return RValue::get(builder.create<mlir::cir::VecExtractOp>(loc, Vec, Elt));
}

// Always use shuffle vector to try to retain the original program structure
unsigned NumResultElts = ExprVT->getNumElements();

SmallVector<int64_t, 4> Mask;
for (unsigned i = 0; i != NumResultElts; ++i)
Mask.push_back(getAccessedFieldNo(i, Elts));

Vec = builder.createVecShuffle(loc, Vec, Mask);
return RValue::get(Vec);
}

RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
SourceLocation Loc) {
const CIRGenBitFieldInfo &info = LV.getBitFieldInfo();
Expand All @@ -674,6 +720,80 @@ RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
return RValue::get(field);
}

void CIRGenFunction::buildStoreThroughExtVectorComponentLValue(RValue Src,
LValue Dst) {
mlir::Location loc = Dst.getExtVectorPointer().getLoc();

// HLSL allows storing to scalar values through ExtVector component LValues.
// To support this we need to handle the case where the destination address is
// a scalar.
Address DstAddr = Dst.getExtVectorAddress();
if (!mlir::isa<mlir::cir::VectorType>(DstAddr.getElementType())) {
llvm_unreachable("HLSL NYI");
}

// This access turns into a read/modify/write of the vector. Load the input
// value now.
mlir::Value Vec = builder.createLoad(loc, DstAddr);
const mlir::ArrayAttr Elts = Dst.getExtVectorElts();

mlir::Value SrcVal = Src.getScalarVal();

if (const clang::VectorType *VTy =
Dst.getType()->getAs<clang::VectorType>()) {
unsigned NumSrcElts = VTy->getNumElements();
unsigned NumDstElts = cast<mlir::cir::VectorType>(Vec.getType()).getSize();
if (NumDstElts == NumSrcElts) {
// Use shuffle vector is the src and destination are the same number of
// elements and restore the vector mask since it is on the side it will be
// stored.
SmallVector<int64_t, 4> Mask(NumDstElts);
for (unsigned i = 0; i != NumSrcElts; ++i)
Mask[getAccessedFieldNo(i, Elts)] = i;

Vec = builder.createVecShuffle(loc, SrcVal, Mask);
} else if (NumDstElts > NumSrcElts) {
// Extended the source vector to the same length and then shuffle it
// into the destination.
// FIXME: since we're shuffling with undef, can we just use the indices
// into that? This could be simpler.
SmallVector<int64_t, 4> ExtMask;
for (unsigned i = 0; i != NumSrcElts; ++i)
ExtMask.push_back(i);
ExtMask.resize(NumDstElts, -1);
mlir::Value ExtSrcVal = builder.createVecShuffle(loc, SrcVal, ExtMask);
// build identity
SmallVector<int64_t, 4> Mask;
for (unsigned i = 0; i != NumDstElts; ++i)
Mask.push_back(i);

// When the vector size is odd and .odd or .hi is used, the last element
// of the Elts constant array will be one past the size of the vector.
// Ignore the last element here, if it is greater than the mask size.
if (getAccessedFieldNo(NumSrcElts - 1, Elts) == Mask.size())
llvm_unreachable("NYI");

// modify when what gets shuffled in
for (unsigned i = 0; i != NumSrcElts; ++i)
Mask[getAccessedFieldNo(i, Elts)] = i + NumDstElts;
Vec = builder.createVecShuffle(loc, Vec, ExtSrcVal, Mask);
} else {
// We should never shorten the vector
llvm_unreachable("unexpected shorten vector length");
}
} else {
// If the Src is a scalar (not a vector), and the target is a vector it must
// be updating one element.
unsigned InIdx = getAccessedFieldNo(0, Elts);
auto Elt = builder.getSInt64(InIdx, loc);

Vec = builder.create<mlir::cir::VecInsertOp>(loc, Vec, SrcVal, Elt);
}

builder.createStore(loc, Vec, Dst.getExtVectorAddress(),
Dst.isVolatileQualified());
}

void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst,
bool isInit) {
if (!Dst.isSimple()) {
Expand All @@ -686,6 +806,10 @@ void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst,
builder.createStore(loc, Vector, Dst.getVectorAddress());
return;
}

if (Dst.isExtVectorElt())
return buildStoreThroughExtVectorComponentLValue(Src, Dst);

assert(Dst.isBitField() && "NIY LValue type");
mlir::Value result;
return buildStoreThroughBitfieldLValue(Src, Dst, result);
Expand Down Expand Up @@ -979,6 +1103,71 @@ CIRGenFunction::buildPointerToDataMemberBinaryExpr(const BinaryOperator *E) {
return makeAddrLValue(memberAddr, memberPtrTy->getPointeeType(), baseInfo);
}

LValue
CIRGenFunction::buildExtVectorElementExpr(const ExtVectorElementExpr *E) {
// Emit the base vector as an l-value.
LValue Base;

// ExtVectorElementExpr's base can either be a vector or pointer to vector.
if (E->isArrow()) {
// If it is a pointer to a vector, emit the address and form an lvalue with
// it.
LValueBaseInfo BaseInfo;
// TODO(cir): Support TBAA
assert(!MissingFeatures::tbaa());
Address Ptr = buildPointerWithAlignment(E->getBase(), &BaseInfo);
const auto *PT = E->getBase()->getType()->castAs<clang::PointerType>();
Base = makeAddrLValue(Ptr, PT->getPointeeType(), BaseInfo);
Base.getQuals().removeObjCGCAttr();
} else if (E->getBase()->isGLValue()) {
// Otherwise, if the base is an lvalue ( as in the case of foo.x.x),
// emit the base as an lvalue.
assert(E->getBase()->getType()->isVectorType());
Base = buildLValue(E->getBase());
} else {
// Otherwise, the base is a normal rvalue (as in (V+V).x), emit it as such.
assert(E->getBase()->getType()->isVectorType() &&
"Result must be a vector");
mlir::Value Vec = buildScalarExpr(E->getBase());

// Store the vector to memory (because LValue wants an address).
QualType BaseTy = E->getBase()->getType();
Address VecMem = CreateMemTemp(BaseTy, Vec.getLoc(), "tmp");
builder.createStore(Vec.getLoc(), Vec, VecMem);
Base = makeAddrLValue(VecMem, BaseTy, AlignmentSource::Decl);
}

QualType type =
E->getType().withCVRQualifiers(Base.getQuals().getCVRQualifiers());

// Encode the element access list into a vector of unsigned indices.
SmallVector<uint32_t, 4> indices;
E->getEncodedElementAccess(indices);

if (Base.isSimple()) {
SmallVector<int64_t, 4> attrElts;
for (uint32_t i : indices) {
attrElts.push_back(static_cast<int64_t>(i));
}
auto elts = builder.getI64ArrayAttr(attrElts);
return LValue::MakeExtVectorElt(Base.getAddress(), elts, type,
Base.getBaseInfo());
}
assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");

mlir::ArrayAttr baseElts = Base.getExtVectorElts();

// Composite the two indices
SmallVector<int64_t, 4> attrElts;
for (uint32_t i : indices) {
attrElts.push_back(getAccessedFieldNo(i, baseElts));
}
auto elts = builder.getI64ArrayAttr(attrElts);

return LValue::MakeExtVectorElt(Base.getExtVectorAddress(), elts, type,
Base.getBaseInfo());
}

LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
// Comma expressions just emit their LHS then their RHS as an l-value.
if (E->getOpcode() == BO_Comma) {
Expand Down Expand Up @@ -2263,6 +2452,8 @@ LValue CIRGenFunction::buildLValue(const Expr *E) {
return buildConditionalOperatorLValue(cast<ConditionalOperator>(E));
case Expr::ArraySubscriptExprClass:
return buildArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
case Expr::ExtVectorElementExprClass:
return buildExtVectorElementExpr(cast<ExtVectorElementExpr>(E));
case Expr::BinaryOperatorClass:
return buildBinaryOperatorLValue(cast<BinaryOperator>(E));
case Expr::CompoundAssignOperatorClass: {
Expand Down
6 changes: 5 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
E->getSrcExpr()->getType(), E->getType(),
E->getSourceRange().getBegin());
}

mlir::Value VisitExtVectorElementExpr(Expr *E) {
return buildLoadOfLValue(E);
}

mlir::Value VisitMemberExpr(MemberExpr *E);
mlir::Value VisitExtVectorelementExpr(Expr *E) { llvm_unreachable("NYI"); }
mlir::Value VisitCompoundLiteralEpxr(CompoundLiteralExpr *E) {
llvm_unreachable("NYI");
}
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,12 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Location Loc, LValueBaseInfo BaseInfo,
bool isNontemporal = false);

int64_t getAccessedFieldNo(unsigned idx, const mlir::ArrayAttr elts);

RValue buildLoadOfExtVectorElementLValue(LValue LV);

void buildStoreThroughExtVectorComponentLValue(RValue Src, LValue Dst);

RValue buildLoadOfBitfieldLValue(LValue LV, SourceLocation Loc);

/// Load a scalar value from an address, taking care to appropriately convert
Expand Down Expand Up @@ -1223,6 +1229,7 @@ class CIRGenFunction : public CIRGenTypeCache {
LValue lvalue, bool capturedByInit = false);

LValue buildDeclRefLValue(const clang::DeclRefExpr *E);
LValue buildExtVectorElementExpr(const ExtVectorElementExpr *E);
LValue buildBinaryOperatorLValue(const clang::BinaryOperator *E);
LValue buildCompoundAssignmentLValue(const clang::CompoundAssignOperator *E);
LValue buildUnaryOpLValue(const clang::UnaryOperator *E);
Expand Down
30 changes: 29 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ class LValue {
unsigned Alignment;
mlir::Value V;
mlir::Type ElementType;
mlir::Value VectorIdx; // Index for vector subscript
mlir::Value VectorIdx; // Index for vector subscript
mlir::Attribute VectorElts; // ExtVector element subset: V.xyx
LValueBaseInfo BaseInfo;
const CIRGenBitFieldInfo *BitFieldInfo{0};

Expand Down Expand Up @@ -316,6 +317,20 @@ class LValue {
return VectorIdx;
}

// extended vector elements.
Address getExtVectorAddress() const {
assert(isExtVectorElt());
return Address(getExtVectorPointer(), ElementType, getAlignment());
}
mlir::Value getExtVectorPointer() const {
assert(isExtVectorElt());
return V;
}
mlir::ArrayAttr getExtVectorElts() const {
assert(isExtVectorElt());
return mlir::cast<mlir::ArrayAttr>(VectorElts);
}

static LValue MakeVectorElt(Address vecAddress, mlir::Value Index,
clang::QualType type, LValueBaseInfo BaseInfo) {
LValue R;
Expand All @@ -328,6 +343,19 @@ class LValue {
return R;
}

static LValue MakeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
clang::QualType type,
LValueBaseInfo baseInfo) {
LValue R;
R.LVType = ExtVectorElt;
R.V = vecAddress.getPointer();
R.ElementType = vecAddress.getElementType();
R.VectorElts = elts;
R.Initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
baseInfo);
return R;
}

// bitfield lvalue
Address getBitFieldAddress() const {
return Address(getBitFieldPointer(), ElementType, getAlignment());
Expand Down
Loading
Loading