@@ -754,7 +754,7 @@ OpFoldResult MMAAttr::getDistributionWorkerCount(OpBuilder &, Location,
754754 return getAsIndexOpFoldResult (getContext (), getSubgroupSize ());
755755}
756756
757- // Get virtual intrinsics that is composed/based on queried op.
757+ // Returns virtual intrinsics that are composed from this concrete MMA op.
758758SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics () const {
759759 switch (getIntrinsic ()) {
760760 case MMAIntrinsic::MFMA_F32_16x16x16_F16:
@@ -1269,15 +1269,15 @@ getMNKShape(VirtualMMAIntrinsic type) {
12691269 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12701270 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12711271 return {32 , 32 , 16 };
1272- // Sparse trick VDMFMAs for skinny GEMMs.
1272+ // Sparse trick VDMFMAs for skinny GEMMs: semantically 8x16xK .
12731273 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
12741274 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
12751275 return {8 , 16 , 64 };
12761276 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1277- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1278- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1279- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1280- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 :
1277+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1278+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1279+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1280+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ :
12811281 return {8 , 16 , 128 };
12821282 }
12831283 assert (false && " unhandled virtual mma layout type." );
@@ -1312,13 +1312,13 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
13121312 return {bf16 , bf16 , f32 };
13131313 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
13141314 return {i8 , i8 , i32 };
1315- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1315+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
13161316 return {f8E5M2FNUZ, f8E5M2FNUZ, f32 };
1317- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1317+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
13181318 return {f8E5M2FNUZ, f8E4M3FNUZ, f32 };
1319- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1319+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
13201320 return {f8E4M3FNUZ, f8E5M2FNUZ, f32 };
1321- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 :
1321+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ :
13221322 return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
13231323 }
13241324 assert (false && " unhandled virtual mma layout type." );
@@ -1374,15 +1374,29 @@ void VirtualMMAAttr::getDistributedTileTypes(
13741374 SmallVectorImpl<VectorType> &result) const {
13751375 MLIRContext *context = getContext ();
13761376 VirtualMMAIntrinsic intrinsic = getIntrinsic ();
1377+ // VDMFMA layouts pair adjacent lanes to emulate a wider tile, so
1378+ // threadProduct < subgroupSize (broadcastFactor > 1). We need
1379+ // getPerLaneElements to divide out the broadcast and compute the correct
1380+ // per-lane element count.
1381+ auto lhsLayout = getSingleSubgroupLayout (intrinsic, kMMAOperandLhs );
13771382 int64_t subgroupSize = getSubgroupSize ();
1378- OpaqueMmaLayout o = getOpaqueMMALayout (context, intrinsic);
1379- auto lhs = getSingleSubgroupLayout (intrinsic, kMMAOperandLhs );
1380- auto rhs = getSingleSubgroupLayout (intrinsic, kMMAOperandRhs );
1381- auto acc = getSingleSubgroupLayout (intrinsic, kMMAOperandAcc );
1382- result.assign (
1383- {VectorType::get ({getPerLaneElements (lhs, subgroupSize)}, o.aType ),
1384- VectorType::get ({getPerLaneElements (rhs, subgroupSize)}, o.bType ),
1385- VectorType::get ({getPerLaneElements (acc, subgroupSize)}, o.cType )});
1383+ int64_t broadcastFactor = subgroupSize / llvm::product_of (lhsLayout.thread );
1384+ if (isVDMFMAIntrinsic (intrinsic) && broadcastFactor > 1 ) {
1385+ OpaqueMmaLayout o = getOpaqueMMALayout (context, intrinsic);
1386+ auto rhsLayout = getSingleSubgroupLayout (intrinsic, kMMAOperandRhs );
1387+ auto accLayout = getSingleSubgroupLayout (intrinsic, kMMAOperandAcc );
1388+ result.assign (
1389+ {VectorType::get ({getPerLaneElements (lhsLayout, subgroupSize)},
1390+ o.aType ),
1391+ VectorType::get ({getPerLaneElements (rhsLayout, subgroupSize)},
1392+ o.bType ),
1393+ VectorType::get ({getPerLaneElements (accLayout, subgroupSize)},
1394+ o.cType )});
1395+ } else {
1396+ result.assign ({getThreadVectorType (context, intrinsic, kMMAOperandLhs ),
1397+ getThreadVectorType (context, intrinsic, kMMAOperandRhs ),
1398+ getThreadVectorType (context, intrinsic, kMMAOperandAcc )});
1399+ }
13861400}
13871401
13881402int64_t VirtualMMAAttr::getSubgroupSize () const {
@@ -1394,10 +1408,10 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
13941408 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
13951409 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
13961410 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1397- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1398- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1399- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1400- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 : {
1411+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1412+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1413+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1414+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ : {
14011415 return 64 ;
14021416 }
14031417 }
@@ -1485,10 +1499,10 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
14851499 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
14861500 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
14871501 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1488- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1489- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1490- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1491- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 : {
1502+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1503+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1504+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1505+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ : {
14921506 return 2 ;
14931507 }
14941508 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1500,26 +1514,23 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
15001514 return 0 ;
15011515}
15021516
1503- // Expand collapsed ACC [c0, c1] -> [c0, 0, c1, 0].
1504- static Value expandAccumulator (OpBuilder &builder, Location loc, Value acc) {
1517+ // Expands a collapsed 2-element ACC into the 4-element native SMFMAC form
1518+ // by interleaving with zeros: [c0, c1] -> [c0, 0, c1, 0].
1519+ Value expandAccumulator (OpBuilder &builder, Location loc, Value acc) {
15051520 auto accType = cast<VectorType>(acc.getType ());
15061521 Value zero =
15071522 arith::ConstantOp::create (builder, loc, builder.getZeroAttr (accType));
1508-
1509- return vector::ShuffleOp::create (builder, loc, acc, zero,
1510- ArrayRef<int64_t >{0 , 2 , 1 , 3 });
1523+ return vector::InterleaveOp::create (builder, loc, acc, zero);
15111524}
15121525
1513- // Collapse expanded ACC [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1514- static Value collapseAccumulator (OpBuilder &builder, Location loc, Value acc) {
1515- auto accType = cast<VectorType>(acc.getType ());
1516- Type elementType = accType.getElementType ();
1517-
1518- Value evens = vector::ShuffleOp::create (builder, loc, acc, acc,
1519- ArrayRef<int64_t >{0 , 2 });
1520- Value odds = vector::ShuffleOp::create (builder, loc, acc, acc,
1521- ArrayRef<int64_t >{1 , 3 });
1522-
1526+ // Collapses a 4-element native SMFMAC ACC back to the 2-element semantic form.
1527+ // Deinterleaves into evens [d0, d2] and odds [d1, d3], then sums pairwise:
1528+ // [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1529+ Value collapseAccumulator (OpBuilder &builder, Location loc, Value acc) {
1530+ Type elementType = cast<VectorType>(acc.getType ()).getElementType ();
1531+ auto deinterleave = vector::DeinterleaveOp::create (builder, loc, acc);
1532+ Value evens = deinterleave.getRes1 ();
1533+ Value odds = deinterleave.getRes2 ();
15231534 if (isa<FloatType>(elementType)) {
15241535 return arith::AddFOp::create (builder, loc, evens, odds);
15251536 }
@@ -1574,6 +1585,27 @@ struct VDMFMAConfig {
15741585// (e.g., vector<2xf32>). buildVDMFMAOps handles the translation: it expands
15751586// a collapsed accumulator into the 4-element physical form before the smfmac
15761587// chain, then collapses the result back afterward.
1588+
1589+ bool isVDMFMAIntrinsic (VirtualMMAIntrinsic intrinsic) {
1590+ switch (intrinsic) {
1591+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1592+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1593+ case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1594+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1595+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1596+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1597+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
1598+ return true ;
1599+ case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1600+ case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1601+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1602+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1603+ return false ;
1604+ }
1605+ assert (false && " unhandled virtual mma intrinsic type" );
1606+ return false ;
1607+ }
1608+
15771609static LogicalResult buildVDMFMAOps (OpBuilder &builder, Location loc,
15781610 const VDMFMAConfig &config,
15791611 ValueRange inputs, Value acc,
@@ -1695,10 +1727,10 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16951727 return buildVDMFMAOps (builder, loc, config, inputs, outputs[0 ], results);
16961728 }
16971729 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1698- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1699- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1700- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1701- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 : {
1730+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1731+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1732+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1733+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ : {
17021734 if (getColMajor ()) {
17031735 return failure ();
17041736 }
@@ -1730,10 +1762,10 @@ int64_t VirtualMMAAttr::getBlockSize() const {
17301762 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
17311763 case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
17321764 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1733- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1734- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1735- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1736- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 : {
1765+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1766+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1767+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1768+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ : {
17371769 return 1 ;
17381770 }
17391771 }
@@ -1813,10 +1845,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
18131845 /* element=*/ {2 , 1 }};
18141846 }
18151847 case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1816- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8 :
1817- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8 :
1818- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8 :
1819- case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8 :
1848+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ :
1849+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ :
1850+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ :
1851+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ :
18201852 switch (operandIndex) {
18211853 case kMMAOperandLhs :
18221854 return {/* outer=*/ {1 , 1 }, /* thread=*/ {8 , 4 }, /* tstrides=*/ {2 , 16 },
0 commit comments