Skip to content

Commit e44dbd7

Browse files
MrSidims0x12CC
authored andcommitted
[Backport to 18] Implement SPV_KHR_bfloat16 extension (KhronosGroup#3099)
The extension add translation from LLVM's bfloat type to OpTypeFloat %width% 16 %fp encoding% BFloat16KHR Mangling follows LLVM's rules for the type. Spec PR: KhronosGroup/SPIRV-Registry#323 --------- Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Aziz, Michael <[email protected]>
1 parent 1e493e3 commit e44dbd7

File tree

17 files changed

+216
-18
lines changed

17 files changed

+216
-18
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ EXT(SPV_INTEL_maximum_registers)
7575
EXT(SPV_INTEL_bindless_images)
7676
EXT(SPV_INTEL_2d_block_io)
7777
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
78+
EXT(SPV_KHR_bfloat16)

lib/SPIRV/Mangler/ManglingUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
2828
"half",
2929
"float",
3030
"double",
31+
"__bf16",
3132
"void",
3233
"...",
3334
"image1d_ro_t",
@@ -105,6 +106,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
105106
"Dh", // HALF
106107
"f", // FLOAT
107108
"d", // DOUBLE
109+
"u6__bf16", // __BF16
108110
"v", // VOID
109111
"z", // VarArg
110112
"14ocl_image1d_ro", // PRIMITIVE_IMAGE1D_RO_T
@@ -197,6 +199,7 @@ static const SPIRversion PrimitiveSupportedVersions[PRIMITIVE_NUM] = {
197199
SPIR12, // HALF
198200
SPIR12, // FLOAT
199201
SPIR12, // DOUBLE
202+
SPIR12, // __BF16
200203
SPIR12, // VOID
201204
SPIR12, // VarArg
202205
SPIR12, // PRIMITIVE_IMAGE1D_RO_T

lib/SPIRV/Mangler/ParameterType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum TypePrimitiveEnum {
4545
PRIMITIVE_HALF,
4646
PRIMITIVE_FLOAT,
4747
PRIMITIVE_DOUBLE,
48+
PRIMITIVE_BFLOAT,
4849
PRIMITIVE_VOID,
4950
PRIMITIVE_VAR_ARG,
5051
PRIMITIVE_STRUCT_FIRST,

lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
313313
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
314314
switch (T->getFloatBitWidth()) {
315315
case 16:
316+
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
317+
return Type::getBFloatTy(*Context);
316318
return Type::getHalfTy(*Context);
317319
case 32:
318320
return Type::getFloatTy(*Context);
@@ -1521,7 +1523,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
15211523
const llvm::fltSemantics *FS = nullptr;
15221524
switch (BT->getFloatBitWidth()) {
15231525
case 16:
1524-
FS = &APFloat::IEEEhalf();
1526+
FS =
1527+
(BT->isTypeFloat(16, FPEncodingBFloat16KHR) ? &APFloat::BFloat()
1528+
: &APFloat::IEEEhalf());
15251529
break;
15261530
case 32:
15271531
FS = &APFloat::IEEEsingle();

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,8 @@ static SPIR::RefParamType transTypeDesc(Type *Ty,
13301330
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
13311331
if (Ty->isDoubleTy())
13321332
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1333+
if (Ty->isBFloatTy())
1334+
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BFLOAT));
13331335
if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
13341336
return SPIR::RefParamType(new SPIR::VectorType(
13351337
transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,16 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
397397
}
398398
}
399399

400+
if (T->isBFloatTy()) {
401+
BM->getErrorLog().checkError(
402+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
403+
SPIRVEC_RequiresExtension,
404+
"SPV_KHR_bfloat16\n"
405+
"NOTE: LLVM module contains bfloat type, translation of which "
406+
"requires this extension");
407+
return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
408+
}
409+
400410
if (T->isFloatingPointTy())
401411
return mapType(T, BM->addFloatType(T->getPrimitiveSizeInBits()));
402412

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
221221
{CapabilityCooperativeMatrixKHR});
222222
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
223223
{CapabilityCooperativeMatrixKHR});
224+
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
225+
ADD_VEC_INIT(CapabilityBFloat16CooperativeMatrixKHR,
226+
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
224227
}
225228

226229
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,18 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
693693
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);
694694
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_0);
695695
}
696+
SPIRVCapVec getRequiredCapability() const override {
697+
if (OpCode == OpDot) {
698+
const SPIRVType *OpTy = getValueType(Ops[0]);
699+
if (OpTy && OpTy->isTypeVector()) {
700+
OpTy = OpTy->getVectorComponentType();
701+
if (OpTy && OpTy->isTypeFloat(16, FPEncodingBFloat16KHR)) {
702+
return getVec(CapabilityBFloat16DotProductKHR);
703+
}
704+
}
705+
}
706+
return SPIRVInstruction::getRequiredCapability();
707+
}
696708
};
697709

698710
template <Op OC>

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ class SPIRVModuleImpl : public SPIRVModule {
254254
template <class T> T *addType(T *Ty);
255255
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) override;
256256
SPIRVTypeBool *addBoolType() override;
257-
SPIRVTypeFloat *addFloatType(unsigned BitWidth) override;
257+
SPIRVTypeFloat *addFloatType(unsigned BitWidth,
258+
unsigned FloatingPointEncoding) override;
258259
SPIRVTypeFunction *addFunctionType(SPIRVType *,
259260
const std::vector<SPIRVType *> &) override;
260261
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
@@ -568,7 +569,8 @@ class SPIRVModuleImpl : public SPIRVModule {
568569
SPIRVTypeBool *BoolTy;
569570
SPIRVTypeVoid *VoidTy;
570571
SmallDenseMap<unsigned, SPIRVTypeInt *, 4> IntTypeMap;
571-
SmallDenseMap<unsigned, SPIRVTypeFloat *, 4> FloatTypeMap;
572+
SmallDenseMap<std::pair<unsigned, unsigned>, SPIRVTypeFloat *, 4>
573+
FloatTypeMap;
572574
std::map<unsigned, SPIRVConstant *> LiteralMap;
573575
std::vector<SPIRVExtInst *> DebugInstVec;
574576
std::vector<SPIRVExtInst *> AuxDataInstVec;
@@ -993,12 +995,14 @@ SPIRVTypeInt *SPIRVModuleImpl::addIntegerType(unsigned BitWidth) {
993995
return addType(Ty);
994996
}
995997

996-
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth) {
997-
auto Loc = FloatTypeMap.find(BitWidth);
998+
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
999+
unsigned FloatingPointEncoding) {
1000+
auto Desc = std::make_pair(BitWidth, FloatingPointEncoding);
1001+
auto Loc = FloatTypeMap.find(Desc);
9981002
if (Loc != FloatTypeMap.end())
9991003
return Loc->second;
1000-
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth);
1001-
FloatTypeMap[BitWidth] = Ty;
1004+
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth, FloatingPointEncoding);
1005+
FloatTypeMap[Desc] = Ty;
10021006
return addType(Ty);
10031007
}
10041008

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class SPIRVModule {
242242
// Type creation functions
243243
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) = 0;
244244
virtual SPIRVTypeBool *addBoolType() = 0;
245-
virtual SPIRVTypeFloat *addFloatType(unsigned) = 0;
245+
virtual SPIRVTypeFloat *addFloatType(unsigned, unsigned = FPEncodingMax) = 0;
246246
virtual SPIRVTypeFunction *
247247
addFunctionType(SPIRVType *, const std::vector<SPIRVType *> &) = 0;
248248
virtual SPIRVTypeImage *addImageType(SPIRVType *,

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
640640
add(CapabilityFPMaxErrorINTEL, "FPMaxErrorINTEL");
641641
add(CapabilityCacheControlsINTEL, "CacheControlsINTEL");
642642
add(CapabilityRegisterLimitsINTEL, "RegisterLimitsINTEL");
643+
add(CapabilityBFloat16TypeKHR, "BFloat16TypeKHR");
644+
add(CapabilityBFloat16DotProductKHR, "BFloat16DotProductKHR");
645+
add(CapabilityBFloat16CooperativeMatrixKHR, "BFloat16CooperativeMatrixKHR");
643646
add(CapabilityArithmeticFenceEXT, "ArithmeticFenceEXT");
644647
add(CapabilitySubgroup2DBlockIOINTEL, "Subgroup2DBlockIOINTEL");
645648
add(CapabilitySubgroup2DBlockTransformINTEL, "Subgroup2DBlockTransformINTEL");

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,15 @@ bool SPIRVType::isTypeComposite() const {
162162
isTypeJointMatrixINTEL() || isTypeCooperativeMatrixKHR();
163163
}
164164

165-
bool SPIRVType::isTypeFloat(unsigned Bits) const {
166-
return isType<SPIRVTypeFloat>(this, Bits);
165+
bool SPIRVType::isTypeFloat(unsigned Bits,
166+
unsigned FloatingPointEncoding) const {
167+
if (!isType<SPIRVTypeFloat>(this))
168+
return false;
169+
if (Bits == 0)
170+
return true;
171+
const auto *ThisFloat = static_cast<const SPIRVTypeFloat *>(this);
172+
return ThisFloat->getBitWidth() == Bits &&
173+
ThisFloat->getFloatingPointEncoding() == FloatingPointEncoding;
167174
}
168175

169176
bool SPIRVType::isTypeOCLImage() const {

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class SPIRVType : public SPIRVEntry {
8484
bool isTypeEvent() const;
8585
bool isTypeDeviceEvent() const;
8686
bool isTypeReserveId() const;
87-
bool isTypeFloat(unsigned Bits = 0) const;
87+
bool isTypeFloat(unsigned Bits = 0,
88+
unsigned FloatingPointEncoding = FPEncodingMax) const;
8889
bool isTypeImage() const;
8990
bool isTypeOCLImage() const;
9091
bool isTypePipe() const;
@@ -201,16 +202,31 @@ class SPIRVTypeFloat : public SPIRVType {
201202
public:
202203
static const Op OC = OpTypeFloat;
203204
// Complete constructor
204-
SPIRVTypeFloat(SPIRVModule *M, SPIRVId TheId, unsigned TheBitWidth)
205-
: SPIRVType(M, 3, OC, TheId), BitWidth(TheBitWidth) {}
205+
SPIRVTypeFloat(SPIRVModule *M, SPIRVId TheId, unsigned TheBitWidth,
206+
unsigned TheFloatingPointEncoding)
207+
: SPIRVType(M, 3 + (TheFloatingPointEncoding != FPEncodingMax), OC,
208+
TheId),
209+
BitWidth(TheBitWidth), FloatingPointEncoding(TheFloatingPointEncoding) {
210+
}
206211
// Incomplete constructor
207-
SPIRVTypeFloat() : SPIRVType(OC), BitWidth(0) {}
212+
SPIRVTypeFloat()
213+
: SPIRVType(OC), BitWidth(0), FloatingPointEncoding(FPEncodingMax) {}
208214

209215
unsigned getBitWidth() const { return BitWidth; }
210216

217+
unsigned getFloatingPointEncoding() const { return FloatingPointEncoding; }
218+
219+
std::optional<ExtensionID> getRequiredExtension() const override {
220+
if (isTypeFloat(16, FPEncodingBFloat16KHR))
221+
return ExtensionID::SPV_KHR_bfloat16;
222+
return {};
223+
}
224+
211225
SPIRVCapVec getRequiredCapability() const override {
212226
SPIRVCapVec CV;
213-
if (isTypeFloat(16)) {
227+
if (isTypeFloat(16, FPEncodingBFloat16KHR)) {
228+
CV.push_back(CapabilityBFloat16TypeKHR);
229+
} else if (isTypeFloat(16)) {
214230
CV.push_back(CapabilityFloat16Buffer);
215231
auto Extensions = getModule()->getSourceExtension();
216232
if (std::any_of(Extensions.begin(), Extensions.end(),
@@ -222,14 +238,34 @@ class SPIRVTypeFloat : public SPIRVType {
222238
}
223239

224240
protected:
225-
_SPIRV_DEF_ENCDEC2(Id, BitWidth)
241+
void encode(spv_ostream &O) const override {
242+
assert(WordCount == 3 || WordCount == 4);
243+
auto Encoder = getEncoder(O);
244+
Encoder << Id << BitWidth;
245+
if (WordCount == 4)
246+
Encoder << FloatingPointEncoding;
247+
}
248+
249+
void decode(std::istream &I) override {
250+
assert(WordCount == 3 || WordCount == 4);
251+
auto Decoder = getDecoder(I);
252+
Decoder >> Id >> BitWidth;
253+
if (WordCount == 4)
254+
Decoder >> FloatingPointEncoding;
255+
}
256+
226257
void validate() const override {
227258
SPIRVEntry::validate();
228259
assert(BitWidth >= 16 && BitWidth <= 64 && "Invalid bit width");
260+
assert(
261+
(FloatingPointEncoding == FPEncodingMax ||
262+
(BitWidth == 16 && FloatingPointEncoding == FPEncodingBFloat16KHR)) &&
263+
"Invalid floating point encoding");
229264
}
230265

231266
private:
232267
unsigned BitWidth; // Bit width
268+
unsigned FloatingPointEncoding;
233269
};
234270

235271
class SPIRVTypePointer : public SPIRVType {
@@ -1142,7 +1178,10 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
11421178
return ExtensionID::SPV_KHR_cooperative_matrix;
11431179
}
11441180
SPIRVCapVec getRequiredCapability() const override {
1145-
return getVec(CapabilityCooperativeMatrixKHR);
1181+
auto CV = getVec(CapabilityCooperativeMatrixKHR);
1182+
if (CompType->isTypeFloat(16, FPEncodingBFloat16KHR))
1183+
CV.push_back(CapabilityBFloat16CooperativeMatrixKHR);
1184+
return CV;
11461185
}
11471186

11481187
SPIRVType *getCompType() const { return CompType; }
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
4+
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
7+
8+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
9+
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_KHR_bfloat16
12+
; CHECK-ERROR-NEXT: NOTE: LLVM module contains bfloat type, translation of which
13+
; CHECK-ERROR-SAME: requires this extension
14+
15+
source_filename = "bfloat16.cpp"
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
17+
target triple = "spirv64-unknown-unknown"
18+
19+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
20+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
21+
; CHECK-SPIRV: 4 TypeFloat [[BFLOAT:[0-9]+]] 16 0
22+
; CHECK-SPIRV: 4 TypeVector [[#]] [[BFLOAT]] 2
23+
24+
; CHECK-LLVM: [[ADDR1:]] = alloca bfloat
25+
; CHECK-LLVM: [[ADDR2:]] = alloca <2 x bfloat>
26+
; CHECK-LLVM: [[DATA1:]] = load bfloat, ptr [[ADDR1]]
27+
; CHECK-LLVM: [[DATA2:]] = load <2 x bfloat>, ptr [[ADDR2]]
28+
29+
define spir_kernel void @test() {
30+
entry:
31+
%addr1 = alloca bfloat
32+
%addr2 = alloca <2 x bfloat>
33+
%data1 = load bfloat, ptr %addr1
34+
%data2 = load <2 x bfloat>, ptr %addr2
35+
ret void
36+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
4+
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
7+
8+
source_filename = "bfloat16_dot.cpp"
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spirv64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
13+
; CHECK-SPIRV-DAG: Capability BFloat16DotProductKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
15+
; CHECK-SPIRV: 4 TypeFloat [[BFLOAT:[0-9]+]] 16 0
16+
; CHECK-SPIRV: 4 TypeVector [[#]] [[BFLOAT]] 2
17+
; CHECK-SPIRV: Dot
18+
19+
; CHECK-LLVM: %addrA = alloca <2 x bfloat>
20+
; CHECK-LLVM: %addrB = alloca <2 x bfloat>
21+
; CHECK-LLVM: %dataA = load <2 x bfloat>, ptr %addrA
22+
; CHECK-LLVM: %dataB = load <2 x bfloat>, ptr %addrB
23+
; CHECK-LLVM: %call = call spir_func bfloat @_Z3dotDv2_u6__bf16S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
24+
25+
declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
26+
27+
define spir_kernel void @test() {
28+
entry:
29+
%addrA = alloca <2 x bfloat>
30+
%addrB = alloca <2 x bfloat>
31+
%dataA = load <2 x bfloat>, ptr %addrA
32+
%dataB = load <2 x bfloat>, ptr %addrB
33+
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
34+
ret void
35+
}
36+
37+
!opencl.ocl.version = !{!7}
38+
39+
!7 = !{i32 2, i32 0}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_KHR_bfloat16 -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
10+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
11+
; CHECK-SPIRV-DAG: Capability BFloat16CooperativeMatrixKHR
12+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
13+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
14+
15+
; CHECK-SPIRV-DAG: 4 TypeFloat [[#BFloatTy:]] 16 0
16+
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
17+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12
18+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3
19+
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2
20+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy:]] [[#BFloatTy]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const2]]
21+
; CHECK-SPIRV-DAG: Constant [[#BFloatTy]] [[#]] 16256
22+
; CHECK-SPIRV: CompositeConstruct [[#MatTy]]
23+
24+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat 0xR3F80)
25+
26+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
27+
target triple = "spir64-unknown-unknown"
28+
29+
declare spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat)
30+
31+
define spir_kernel void @test() {
32+
%mat = call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat 1.0)
33+
ret void
34+
}

0 commit comments

Comments
 (0)