Skip to content

Commit 6cd551e

Browse files
committed
Extend VDMFMA virtual MMA intrinsics to all cdna3 types (adds bf16 and f8E5M2FNUZ/f8E4M3FNUZ variants)
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
1 parent 1818058 commit 6cd551e

4 files changed

Lines changed: 257 additions & 6 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,8 +1271,13 @@ getMNKShape(VirtualMMAIntrinsic type) {
12711271
return {32, 32, 16};
12721272
// Sparse trick VDMFMAs for skinny GEMMs: semantically 8x16xK.
12731273
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1274+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
12741275
return {8, 16, 64};
12751276
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1277+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1278+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1279+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1280+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
12761281
return {8, 16, 128};
12771282
}
12781283
assert(false && "unhandled virtual mma layout type.");
@@ -1282,6 +1287,8 @@ getMNKShape(VirtualMMAIntrinsic type) {
12821287
static std::tuple<Type, Type, Type>
12831288
getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
12841289
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
1290+
Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
1291+
Type bf16 = BFloat16Type::get(context);
12851292
Type f16 = Float16Type::get(context);
12861293
Type f32 = Float32Type::get(context);
12871294
Type i8 = IntegerType::get(context, 8);
@@ -1301,8 +1308,18 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
13011308
// Sparse trick VDMFMAs for skinny GEMMs.
13021309
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
13031310
return {f16, f16, f32};
1311+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1312+
return {bf16, bf16, f32};
13041313
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
13051314
return {i8, i8, i32};
1315+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1316+
return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
1317+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1318+
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
1319+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1320+
return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
1321+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
1322+
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
13061323
}
13071324
assert(false && "unhandled virtual mma layout type.");
13081325
return {};
@@ -1375,7 +1392,12 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
13751392
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13761393
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
13771394
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1378-
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1395+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1396+
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1397+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1398+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1399+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1400+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
13791401
return 64;
13801402
}
13811403
}
@@ -1441,7 +1463,7 @@ static Value createLaneParityPredicate(OpBuilder &builder, Location loc) {
14411463
// 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
14421464
//
14431465
// For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1444-
// For 8-bit source data (i8): vector<2xi16>, 4 groups per i16.
1466+
// For 8-bit source data (i8/f8*): vector<2xi16>, 4 groups per i16.
14451467
//
14461468
// Only the first element carries active selector bits; remaining
14471469
// elements are padding zeros.
@@ -1462,7 +1484,12 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
14621484
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
14631485
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
14641486
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1465-
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1487+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1488+
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1489+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1490+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1491+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1492+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
14661493
return 2;
14671494
}
14681495
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1649,7 +1676,8 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16491676
results.push_back(acc);
16501677
return success();
16511678
}
1652-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
1679+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1680+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16: {
16531681
if (getColMajor()) {
16541682
return failure();
16551683
}
@@ -1667,7 +1695,11 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16671695
{{0, 1, 8, 9, 2, 3, 10, 11}, {4, 5, 12, 13, 6, 7, 14, 15}}};
16681696
return buildVDMFMAOps(builder, loc, config, inputs, outputs[0], results);
16691697
}
1670-
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1698+
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1699+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1700+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1701+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1702+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
16711703
if (getColMajor()) {
16721704
return failure();
16731705
}
@@ -1697,7 +1729,12 @@ int64_t VirtualMMAAttr::getBlockSize() const {
16971729
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
16981730
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
16991731
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1700-
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1732+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1733+
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1734+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1735+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1736+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1737+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
17011738
return 1;
17021739
}
17031740
}
@@ -1764,6 +1801,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
17641801
/*element=*/{4, 1}};
17651802
}
17661803
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1804+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
17671805
switch (operandIndex) {
17681806
case kMMAOperandLhs:
17691807
return {/*outer=*/{1, 1}, /*thread=*/{8, 4}, /*tstrides=*/{2, 16},
@@ -1776,6 +1814,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
17761814
/*element=*/{2, 1}};
17771815
}
17781816
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1817+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1818+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1819+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1820+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
17791821
switch (operandIndex) {
17801822
case kMMAOperandLhs:
17811823
return {/*outer=*/{1, 1}, /*thread=*/{8, 4}, /*tstrides=*/{2, 16},

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,27 @@ def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3F
375375
def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3FNUZ", 3>;
376376
def VDMFMA_F32_8x16x64_F16 : I32EnumAttrCase<"VDMFMA_F32_8x16x64_F16", 4>;
377377
def VDMFMA_I32_8x16x128_I8 : I32EnumAttrCase<"VDMFMA_I32_8x16x128_I8", 5>;
378+
def VDMFMA_F32_8x16x64_BF16 : I32EnumAttrCase<"VDMFMA_F32_8x16x64_BF16", 6>;
379+
def VDMFMA_F32_8x16x128_F8E5M2FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E5M2FNUZ", 7>;
380+
def VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ", 8>;
381+
def VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ", 9>;
382+
def VDMFMA_F32_8x16x128_F8E4M3FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E4M3FNUZ", 10>;
378383

379384
def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32EnumAttr<"VirtualMMAIntrinsic",
380385
"Descriptor for different Virtual MMA intrinsics", [
381386
VMFMA_F32_16x16x32_F16,
382387
VMFMA_F32_32x32x16_F16,
383388
VMFMA_F32_16x16x32_F8E4M3FNUZ,
384389
VMFMA_F32_32x32x16_F8E4M3FNUZ,
390+
// 16-bit VDMFMA variants.
385391
VDMFMA_F32_8x16x64_F16,
392+
VDMFMA_F32_8x16x64_BF16,
393+
// 8-bit VDMFMA variants.
386394
VDMFMA_I32_8x16x128_I8,
395+
VDMFMA_F32_8x16x128_F8E5M2FNUZ,
396+
VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ,
397+
VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ,
398+
VDMFMA_F32_8x16x128_F8E4M3FNUZ,
387399
]>;
388400

389401
// Enum for scaled mma intrinsic, loosely matching the MMAIntrinsic enum above

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,51 @@ module {
5454
// CHECK-LABEL: func @test_vdmfma_i8_8x16x128
5555
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_I32_8x16x128_I8>
5656

57+
module {
58+
func.func @test_vdmfma_bf16_8x16x64() attributes {
59+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x64_BF16>} {
60+
return
61+
}
62+
}
63+
// CHECK-LABEL: func @test_vdmfma_bf16_8x16x64
64+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x64_BF16>
65+
66+
module {
67+
func.func @test_vdmfma_f8E5M2FNUZ_8x16x128() attributes {
68+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ>} {
69+
return
70+
}
71+
}
72+
// CHECK-LABEL: func @test_vdmfma_f8E5M2FNUZ_8x16x128
73+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ>
74+
75+
module {
76+
func.func @test_vdmfma_f8E5M2FNUZ_f8E4M3FNUZ_8x16x128() attributes {
77+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ>} {
78+
return
79+
}
80+
}
81+
// CHECK-LABEL: func @test_vdmfma_f8E5M2FNUZ_f8E4M3FNUZ_8x16x128
82+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ>
83+
84+
module {
85+
func.func @test_vdmfma_f8E4M3FNUZ_f8E5M2FNUZ_8x16x128() attributes {
86+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ>} {
87+
return
88+
}
89+
}
90+
// CHECK-LABEL: func @test_vdmfma_f8E4M3FNUZ_f8E5M2FNUZ_8x16x128
91+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ>
92+
93+
module {
94+
func.func @test_vdmfma_f8E4M3FNUZ_8x16x128() attributes {
95+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ>} {
96+
return
97+
}
98+
}
99+
// CHECK-LABEL: func @test_vdmfma_f8E4M3FNUZ_8x16x128
100+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ>
101+
57102
module {
58103
func.func @test_WMMAR3_f16_16x16x16_f32() attributes {
59104
mma_types = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>} {

0 commit comments

Comments
 (0)