@@ -1271,8 +1271,13 @@ getMNKShape(VirtualMMAIntrinsic type) {
12711271 return {32 , 32 , 16 };
12721272 // Sparse trick VDMFMAs for skinny GEMMs: semantically 8x16xK.
12731273 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1274+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
12741275 return {8 , 16 , 64 };
12751276 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1277+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1278+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1279+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1280+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
12761281 return {8 , 16 , 128 };
12771282 }
12781283 assert (false && " unhandled virtual mma layout type." );
@@ -1282,6 +1287,8 @@ getMNKShape(VirtualMMAIntrinsic type) {
12821287static std::tuple<Type, Type, Type>
12831288getABCElementTypes (MLIRContext *context, VirtualMMAIntrinsic type) {
12841289 Type f8E4M3FNUZ = Float8E4M3FNUZType::get (context);
1290+ Type f8E5M2FNUZ = Float8E5M2FNUZType::get (context);
1291+ Type bf16 = BFloat16Type::get (context);
12851292 Type f16 = Float16Type::get (context);
12861293 Type f32 = Float32Type::get (context);
12871294 Type i8 = IntegerType::get (context, 8 );
@@ -1301,8 +1308,18 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
13011308 // Sparse trick VDMFMAs for skinny GEMMs.
13021309 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
13031310 return {f16 , f16 , f32 };
1311+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1312+ return {bf16 , bf16 , f32 };
13041313 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
13051314 return {i8 , i8 , i32 };
1315+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1316+ return {f8E5M2FNUZ, f8E5M2FNUZ, f32 };
1317+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1318+ return {f8E5M2FNUZ, f8E4M3FNUZ, f32 };
1319+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1320+ return {f8E4M3FNUZ, f8E5M2FNUZ, f32 };
1321+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
1322+ return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
13061323 }
13071324 assert (false && " unhandled virtual mma layout type." );
13081325 return {};
@@ -1375,7 +1392,12 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
13751392 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13761393 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
13771394 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1378- case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1395+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1396+ case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1397+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1398+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1399+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1400+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
13791401 return 64 ;
13801402 }
13811403 }
@@ -1441,7 +1463,7 @@ static Value createLaneParityPredicate(OpBuilder &builder, Location loc) {
14411463// 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
14421464//
14431465// For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1444- // For 8-bit source data (i8): vector<2xi16>, 4 groups per i16.
1466+ // For 8-bit source data (i8/f8*): vector<2xi16>, 4 groups per i16.
14451467//
14461468// Only the first element carries active selector bits; remaining
14471469// elements are padding zeros.
@@ -1462,7 +1484,12 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
14621484 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
14631485 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
14641486 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1465- case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1487+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1488+ case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1489+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1490+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1491+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1492+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
14661493 return 2 ;
14671494 }
14681495 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1649,7 +1676,8 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16491676 results.push_back (acc);
16501677 return success ();
16511678 }
1652- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
1679+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1680+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16: {
16531681 if (getColMajor ()) {
16541682 return failure ();
16551683 }
@@ -1667,7 +1695,11 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16671695 {{0 , 1 , 8 , 9 , 2 , 3 , 10 , 11 }, {4 , 5 , 12 , 13 , 6 , 7 , 14 , 15 }}};
16681696 return buildVDMFMAOps (builder, loc, config, inputs, outputs[0 ], results);
16691697 }
1670- case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1698+ case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1699+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1700+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1701+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1702+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
16711703 if (getColMajor ()) {
16721704 return failure ();
16731705 }
@@ -1697,7 +1729,12 @@ int64_t VirtualMMAAttr::getBlockSize() const {
16971729 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
16981730 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
16991731 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1700- case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8: {
1732+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1733+ case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1734+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1735+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1736+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1737+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
17011738 return 1 ;
17021739 }
17031740 }
@@ -1764,6 +1801,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
17641801 /* element=*/ {4 , 1 }};
17651802 }
17661803 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1804+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
17671805 switch (operandIndex) {
17681806 case kMMAOperandLhs :
17691807 return {/* outer=*/ {1 , 1 }, /* thread=*/ {8 , 4 }, /* tstrides=*/ {2 , 16 },
@@ -1776,6 +1814,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
17761814 /* element=*/ {2 , 1 }};
17771815 }
17781816 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1817+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1818+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1819+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1820+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
17791821 switch (operandIndex) {
17801822 case kMMAOperandLhs :
17811823 return {/* outer=*/ {1 , 1 }, /* thread=*/ {8 , 4 }, /* tstrides=*/ {2 , 16 },
0 commit comments