Skip to content

Commit aa68966

Browse files
MrSidims0x12CC
authored andcommitted
[Backport to 17] Implement SPV_KHR_bfloat16 extension (KhronosGroup#3099)
The extension add translation from LLVM's bfloat type to OpTypeFloat %width% 16 %fp encoding% BFloat16KHR Mangling follows LLVM's rules for the type. Spec PR: KhronosGroup/SPIRV-Registry#323 --------- Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Aziz, Michael <[email protected]>
1 parent 218ebb5 commit aa68966

File tree

17 files changed

+219
-16
lines changed

17 files changed

+219
-16
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ EXT(SPV_INTEL_maximum_registers)
7575
EXT(SPV_INTEL_bindless_images)
7676
EXT(SPV_INTEL_2d_block_io)
7777
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
78+
EXT(SPV_KHR_bfloat16)

lib/SPIRV/Mangler/ManglingUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
2828
"half",
2929
"float",
3030
"double",
31+
"__bf16",
3132
"void",
3233
"...",
3334
"image1d_ro_t",
@@ -105,6 +106,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
105106
"Dh", // HALF
106107
"f", // FLOAT
107108
"d", // DOUBLE
109+
"u6__bf16", // __BF16
108110
"v", // VOID
109111
"z", // VarArg
110112
"14ocl_image1d_ro", // PRIMITIVE_IMAGE1D_RO_T
@@ -197,6 +199,7 @@ static const SPIRversion PrimitiveSupportedVersions[PRIMITIVE_NUM] = {
197199
SPIR12, // HALF
198200
SPIR12, // FLOAT
199201
SPIR12, // DOUBLE
202+
SPIR12, // __BF16
200203
SPIR12, // VOID
201204
SPIR12, // VarArg
202205
SPIR12, // PRIMITIVE_IMAGE1D_RO_T

lib/SPIRV/Mangler/ParameterType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum TypePrimitiveEnum {
4545
PRIMITIVE_HALF,
4646
PRIMITIVE_FLOAT,
4747
PRIMITIVE_DOUBLE,
48+
PRIMITIVE_BFLOAT,
4849
PRIMITIVE_VOID,
4950
PRIMITIVE_VAR_ARG,
5051
PRIMITIVE_STRUCT_FIRST,

lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
313313
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
314314
switch (T->getFloatBitWidth()) {
315315
case 16:
316+
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
317+
return Type::getBFloatTy(*Context);
316318
return Type::getHalfTy(*Context);
317319
case 32:
318320
return Type::getFloatTy(*Context);
@@ -1518,7 +1520,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
15181520
const llvm::fltSemantics *FS = nullptr;
15191521
switch (BT->getFloatBitWidth()) {
15201522
case 16:
1521-
FS = &APFloat::IEEEhalf();
1523+
FS =
1524+
(BT->isTypeFloat(16, FPEncodingBFloat16KHR) ? &APFloat::BFloat()
1525+
: &APFloat::IEEEhalf());
15221526
break;
15231527
case 32:
15241528
FS = &APFloat::IEEEsingle();

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,8 @@ static SPIR::RefParamType transTypeDesc(Type *Ty,
13431343
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
13441344
if (Ty->isDoubleTy())
13451345
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1346+
if (Ty->isBFloatTy())
1347+
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BFLOAT));
13461348
if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
13471349
return SPIR::RefParamType(new SPIR::VectorType(
13481350
transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,16 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
322322
}
323323
}
324324

325+
if (T->isBFloatTy()) {
326+
BM->getErrorLog().checkError(
327+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
328+
SPIRVEC_RequiresExtension,
329+
"SPV_KHR_bfloat16\n"
330+
"NOTE: LLVM module contains bfloat type, translation of which "
331+
"requires this extension");
332+
return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
333+
}
334+
325335
if (T->isFloatingPointTy())
326336
return mapType(T, BM->addFloatType(T->getPrimitiveSizeInBits()));
327337

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
219219
{CapabilityCooperativeMatrixKHR});
220220
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
221221
{CapabilityCooperativeMatrixKHR});
222+
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
223+
ADD_VEC_INIT(CapabilityBFloat16CooperativeMatrixKHR,
224+
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
222225
}
223226

224227
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,18 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
689689
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);
690690
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_0);
691691
}
692+
SPIRVCapVec getRequiredCapability() const override {
693+
if (OpCode == OpDot) {
694+
const SPIRVType *OpTy = getValueType(Ops[0]);
695+
if (OpTy && OpTy->isTypeVector()) {
696+
OpTy = OpTy->getVectorComponentType();
697+
if (OpTy && OpTy->isTypeFloat(16, FPEncodingBFloat16KHR)) {
698+
return getVec(CapabilityBFloat16DotProductKHR);
699+
}
700+
}
701+
}
702+
return SPIRVInstruction::getRequiredCapability();
703+
}
692704
};
693705

694706
template <Op OC>

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ class SPIRVModuleImpl : public SPIRVModule {
279279
template <class T> T *addType(T *Ty);
280280
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) override;
281281
SPIRVTypeBool *addBoolType() override;
282-
SPIRVTypeFloat *addFloatType(unsigned BitWidth) override;
282+
SPIRVTypeFloat *addFloatType(unsigned BitWidth,
283+
unsigned FloatingPointEncoding) override;
283284
SPIRVTypeFunction *addFunctionType(SPIRVType *,
284285
const std::vector<SPIRVType *> &) override;
285286
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
@@ -591,6 +592,8 @@ class SPIRVModuleImpl : public SPIRVModule {
591592
SPIRVCapMap CapMap;
592593
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
593594
std::map<unsigned, SPIRVTypeInt *> IntTypeMap;
595+
SmallDenseMap<std::pair<unsigned, unsigned>, SPIRVTypeFloat *, 4>
596+
FloatTypeMap;
594597
std::map<unsigned, SPIRVConstant *> LiteralMap;
595598
std::vector<SPIRVExtInst *> DebugInstVec;
596599
std::vector<SPIRVExtInst *> AuxDataInstVec;
@@ -1004,9 +1007,15 @@ SPIRVTypeInt *SPIRVModuleImpl::addIntegerType(unsigned BitWidth) {
10041007
return addType(Ty);
10051008
}
10061009

1007-
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth) {
1008-
SPIRVTypeFloat *T = addType(new SPIRVTypeFloat(this, getId(), BitWidth));
1009-
return T;
1010+
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
1011+
unsigned FloatingPointEncoding) {
1012+
auto Desc = std::make_pair(BitWidth, FloatingPointEncoding);
1013+
auto Loc = FloatTypeMap.find(Desc);
1014+
if (Loc != FloatTypeMap.end())
1015+
return Loc->second;
1016+
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth, FloatingPointEncoding);
1017+
FloatTypeMap[Desc] = Ty;
1018+
return addType(Ty);
10101019
}
10111020

10121021
SPIRVTypePointer *

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class SPIRVModule {
236236
// Type creation functions
237237
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) = 0;
238238
virtual SPIRVTypeBool *addBoolType() = 0;
239-
virtual SPIRVTypeFloat *addFloatType(unsigned) = 0;
239+
virtual SPIRVTypeFloat *addFloatType(unsigned, unsigned = FPEncodingMax) = 0;
240240
virtual SPIRVTypeFunction *
241241
addFunctionType(SPIRVType *, const std::vector<SPIRVType *> &) = 0;
242242
virtual SPIRVTypeImage *addImageType(SPIRVType *,

0 commit comments

Comments
 (0)