55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
77#include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
8+ #include < cstdint>
89
910#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1011#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
3435#include " mlir/Dialect/Utils/StaticValueUtils.h"
3536#include " mlir/Dialect/Vector/IR/VectorOps.h"
3637#include " mlir/IR/Attributes.h"
38+ #include " mlir/IR/Builders.h"
3739#include " mlir/IR/BuiltinAttributes.h"
3840#include " mlir/IR/BuiltinTypes.h"
3941#include " mlir/IR/OpDefinition.h"
@@ -1269,6 +1271,10 @@ getMNKShape(VirtualMMAIntrinsic type) {
12691271 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12701272 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12711273 return {32 , 32 , 16 };
1274+ // Sparse trick VDMFMAs for skinny GEMMs: semantically 8x16xK.
1275+ // TODO(#XXXX): Add I8 VDMFMA variant (VDMFMA_I32_8x16x128_I8).
1276+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1277+ return {8 , 16 , 64 };
12721278 }
12731279 assert (false && " unhandled virtual mma layout type." );
12741280 return {};
@@ -1285,12 +1291,13 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
12851291 return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
12861292 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12871293 return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
1288- // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
1289- // along the k dimension.
12901294 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
12911295 return {f16 , f16 , f32 };
12921296 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12931297 return {f16 , f16 , f32 };
1298+ // Sparse trick VDMFMAs for skinny GEMMs.
1299+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1300+ return {f16 , f16 , f32 };
12941301 }
12951302 assert (false && " unhandled virtual mma layout type." );
12961303 return {};
@@ -1326,21 +1333,43 @@ void VirtualMMAAttr::getUndistributedTileTypes(
13261333 VectorType::get ({o.mSize , o.nSize }, o.cType )});
13271334}
13281335
1336+ // Returns the number of elements held per lane for a given operand layout,
1337+ // accounting for broadcastFactor when threadProduct < subgroupSize.
1338+ static int64_t getPerLaneElements (MMASingleSubgroupLayout layout,
1339+ int64_t subgroupSize) {
1340+ int64_t threadProduct = llvm::product_of (layout.thread );
1341+ assert (subgroupSize % threadProduct == 0 &&
1342+ " subgroup size must be a multiple of thread product" );
1343+ int64_t broadcastFactor = subgroupSize / threadProduct;
1344+ int64_t totalElements =
1345+ llvm::product_of (layout.element ) * llvm::product_of (layout.outer );
1346+ assert (totalElements % broadcastFactor == 0 &&
1347+ " total elements must be divisible by broadcast factor" );
1348+ return totalElements / broadcastFactor;
1349+ }
1350+
13291351void VirtualMMAAttr::getDistributedTileTypes (
13301352 SmallVectorImpl<VectorType> &result) const {
13311353 MLIRContext *context = getContext ();
13321354 VirtualMMAIntrinsic intrinsic = getIntrinsic ();
1333- result.assign ({getThreadVectorType (context, intrinsic, kMMAOperandLhs ),
1334- getThreadVectorType (context, intrinsic, kMMAOperandRhs ),
1335- getThreadVectorType (context, intrinsic, kMMAOperandAcc )});
1355+ int64_t subgroupSize = getSubgroupSize ();
1356+ OpaqueMmaLayout o = getOpaqueMMALayout (context, intrinsic);
1357+ auto lhs = getSingleSubgroupLayout (intrinsic, kMMAOperandLhs );
1358+ auto rhs = getSingleSubgroupLayout (intrinsic, kMMAOperandRhs );
1359+ auto acc = getSingleSubgroupLayout (intrinsic, kMMAOperandAcc );
1360+ result.assign (
1361+ {VectorType::get ({getPerLaneElements (lhs, subgroupSize)}, o.aType ),
1362+ VectorType::get ({getPerLaneElements (rhs, subgroupSize)}, o.bType ),
1363+ VectorType::get ({getPerLaneElements (acc, subgroupSize)}, o.cType )});
13361364}
13371365
13381366int64_t VirtualMMAAttr::getSubgroupSize () const {
13391367 switch (getIntrinsic ()) {
13401368 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13411369 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
13421370 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1343- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1371+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1372+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
13441373 return 64 ;
13451374 }
13461375 }
@@ -1366,23 +1395,67 @@ LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides(
13661395 MMASingleSubgroupLayout subgroupLayout =
13671396 getSingleSubgroupLayout (getIntrinsic (), operandIndex,
13681397 operandIndex == kMMAOperandAcc && getColMajor ());
1398+
1399+ // Compute broadcast factor: when thread product < subgroup size, multiple
1400+ // physical lanes share a logical thread position. broadcastFactor tells
1401+ // populateCanonicalOffsetsSizesAndStrides to split the element dimension
1402+ // so each physical lane gets a unique slice.
1403+ int64_t threadProduct = llvm::product_of (subgroupLayout.thread );
1404+ assert (getSubgroupSize () % threadProduct == 0 &&
1405+ " subgroup size must be a multiple of thread product" );
1406+ int64_t broadcastFactor = getSubgroupSize () / threadProduct;
1407+
13691408 SmallVector<OpFoldResult> canonicalOffsets;
13701409 SmallVector<OpFoldResult> canonicalSizes;
13711410 if (failed (populateCanonicalOffsetsSizesAndStrides (
13721411 builder, loc, laneId, permutation, subgroupLayout, canonicalOffsets,
1373- canonicalSizes, strides))) {
1412+ canonicalSizes, strides, broadcastFactor ))) {
13741413 return failure ();
13751414 }
13761415 offsets.append (canonicalOffsets);
13771416 sizes.append (canonicalSizes);
1378-
13791417 return success ();
13801418}
13811419
1420+ // Returns true on odd lanes and false on even lanes.
1421+ static Value createLaneParityPredicate (OpBuilder &builder, Location loc) {
1422+ Value laneId = gpu::LaneIdOp::create (builder, loc, /* upper_bound=*/ nullptr );
1423+ Value one = arith::ConstantIndexOp::create (builder, loc, 1 );
1424+ Value zero = arith::ConstantIndexOp::create (builder, loc, 0 );
1425+ Value lowBit = arith::AndIOp::create (builder, loc, laneId, one);
1426+ return arith::CmpIOp::create (builder, loc, arith::CmpIPredicate::ne, lowBit,
1427+ zero);
1428+ }
1429+
1430+ // Creates a constant sparse index vector for SMFMAC operations.
1431+ //
1432+ // The sparse index encodes which 2 positions out of each group of 4
1433+ // K-elements are selected for 2:4 structured sparsity. Each 4-bit
1434+ // field within selectorBits selects positions for one K-group:
1435+ // 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
1436+ //
1437+ // For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1438+ // For 8-bit source data (i8): vector<2xi16>, 4 groups per i16.
1439+ //
1440+ // Only the first element carries active selector bits; remaining
1441+ // elements are padding zeros.
1442+ static Value createConstSparseIndex (OpBuilder &builder, Location loc,
1443+ VectorType sparseIndexVectorType,
1444+ int64_t selectorBits) {
1445+ Type elemTy = sparseIndexVectorType.getElementType ();
1446+ Value zero = arith::ConstantOp::create (
1447+ builder, loc, builder.getZeroAttr (sparseIndexVectorType));
1448+ Value selector = arith::ConstantOp::create (
1449+ builder, loc, builder.getIntegerAttr (elemTy, selectorBits));
1450+ return vector::InsertOp::create (builder, loc, selector, zero, 0 );
1451+ }
1452+
1453+ // Returns the K unroll factor: virtual_K / native_K.
13821454int64_t VirtualMMAAttr::getIntrinsicsK () const {
13831455 switch (getIntrinsic ()) {
13841456 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1385- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1457+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1458+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
13861459 return 2 ;
13871460 }
13881461 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1394,6 +1467,125 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
13941467 return 0 ;
13951468}
13961469
1470+ // Expand collapsed ACC [c0, c1] -> [c0, 0, c1, 0].
1471+ static Value expandAccumulator (OpBuilder &builder, Location loc, Value acc) {
1472+ auto accType = cast<VectorType>(acc.getType ());
1473+ Value zero =
1474+ arith::ConstantOp::create (builder, loc, builder.getZeroAttr (accType));
1475+
1476+ return vector::ShuffleOp::create (builder, loc, acc, zero,
1477+ ArrayRef<int64_t >{0 , 2 , 1 , 3 });
1478+ }
1479+
1480+ // Collapse expanded ACC [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1481+ static Value collapseAccumulator (OpBuilder &builder, Location loc, Value acc) {
1482+ auto accType = cast<VectorType>(acc.getType ());
1483+ Type elementType = accType.getElementType ();
1484+
1485+ Value evens = vector::ShuffleOp::create (builder, loc, acc, acc,
1486+ ArrayRef<int64_t >{0 , 2 });
1487+ Value odds = vector::ShuffleOp::create (builder, loc, acc, acc,
1488+ ArrayRef<int64_t >{1 , 3 });
1489+
1490+ if (isa<FloatType>(elementType)) {
1491+ return arith::AddFOp::create (builder, loc, evens, odds);
1492+ }
1493+ return arith::AddIOp::create (builder, loc, evens, odds);
1494+ }
1495+
// Struct with consolidated info necessary for sparse trick invocation as a
// VDMFMA.
struct VDMFMAConfig {
  // Native SMFMAC instruction shape (physical M x N x K).
  int64_t m, n, nativeK;
  // Number of chained SMFMAC calls, i.e. virtual K / native K.
  int64_t unrollFactor;
  // Type of the constant sparsity-index operand (e.g. vector<4xi8> for
  // 16-bit source data).
  VectorType sparseIndexVectorType;
  // Selector bits used on even lanes (selects K-positions {0,1} per group).
  int64_t evenSparseIndex;
  // Selector bits used on odd lanes (selects K-positions {2,3} per group).
  int64_t oddSparseIndex;
  int64_t aSliceWidth; // Elements per A slice per SMFMAC call.
  // One shuffle mask per unroll iteration, interleaving the B operand for
  // that SMFMAC call; size must equal unrollFactor.
  SmallVector<SmallVector<int64_t, 8>, 2> bInterleaveIndices;
};
1507+
// Virtual Dense MFMA (VDMFMA) ops represent invocations of the sparse trick
// targeting skinny GEMMs (M=8).
//
// === The sparse trick ===
//
// Sparse MFMA (V_SMFMAC) instructions perform MMA on an imbalanced pair of
// operands: a 2:4 structured-sparse matrix A and a dense matrix B. The
// instruction also takes a sparsity index that encodes which 2 of every 4
// elements along K are non-zero within the sparse matrix A. The trick exploits
// this by pairing even/odd lanes to jointly describe a full dense row.
//
// The lane-pairing layout maps each of the 8 logical M-rows to a pair of
// adjacent physical rows (row 2i and 2i+1 for logical row i). Within each pair,
// the even lane supplies positions {0,1} from each K-group of 4 and the odd
// lane supplies positions {2,3}. The hardware interprets each physical row as
// having 2:4 structured sparsity and computes a partial dot product over only
// its non-zero elements. Summing the two physical rows' results reconstructs
// the full dense dot product for the logical row. This yields a semantic M=8
// matmul from a physical 16x16 instruction.
//
// Each lane loads unique A data via broadcastFactor distribution. Even lanes
// receive K[0:aSliceWidth*unrollFactor/2], odd lanes receive
// K[aSliceWidth*unrollFactor/2:aSliceWidth*unrollFactor]. A is sliced
// sequentially into per-SMFMAC chunks of aSliceWidth elements.
//
// === Accumulator expand/collapse ===
//
// Because the sparse trick maps two hardware rows to one logical row, adjacent
// register pairs in the output hold partial sums for the same dense row.
// Collapsing sums each pair (v0+v1, v2+v3) to produce the 2-element semantic
// result: one complete value per logical row.
//
// The layout and distribution infrastructure operate on the collapsed vector
// (e.g., vector<2xf32>). buildVDMFMAOps handles the translation: it expands
// a collapsed accumulator into the 4-element physical form before the smfmac
// chain, then collapses the result back afterward.
static LogicalResult buildVDMFMAOps(OpBuilder &builder, Location loc,
                                    const VDMFMAConfig &config,
                                    ValueRange inputs, Value acc,
                                    SmallVectorImpl<Value> &results) {
  // Translate the collapsed logical accumulator into physical SMFMAC form.
  Value smfmacAcc = expandAccumulator(builder, loc, acc);
  VectorType expandedAccType = cast<VectorType>(smfmacAcc.getType());

  Value isOddLane = createLaneParityPredicate(builder, loc);

  // Even lanes select K-positions {0,1}; odd lanes select {2,3}.
  Value sparseIndex = arith::SelectOp::create(
      builder, loc, isOddLane,
      createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
                             config.oddSparseIndex),
      createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
                             config.evenSparseIndex));

  Value lhs = inputs[0];
  Value rhs = inputs[1];

  assert(static_cast<int64_t>(config.bInterleaveIndices.size()) ==
             config.unrollFactor &&
         "must provide B interleave indices for each unroll iteration");

  for (int64_t i = 0; i < config.unrollFactor; ++i) {
    // Consecutive aSliceWidth-element chunks of A feed each SMFMAC call.
    int64_t aOffset = config.aSliceWidth * i;
    Value aSlice = vector::ExtractStridedSliceOp::create(
        builder, loc, lhs, /*offsets=*/ArrayRef<int64_t>{aOffset},
        /*sizes=*/ArrayRef<int64_t>{config.aSliceWidth},
        /*strides=*/ArrayRef<int64_t>{1});

    // Interleave B per the config so each physical row pairs with the
    // K-groups its lane is responsible for.
    Value bSlice = vector::ShuffleOp::create(builder, loc, rhs, rhs,
                                             config.bInterleaveIndices[i]);

    // Chain the accumulator through each SMFMAC invocation.
    smfmacAcc = amdgpu::SparseMFMAOp::create(
        builder, loc, expandedAccType,
        /*m=*/config.m, /*n=*/config.n, /*k=*/config.nativeK,
        /*sourceA=*/aSlice, /*sourceB=*/bSlice, /*destC=*/smfmacAcc,
        /*sparseIdx=*/sparseIndex, /*cbsz=*/0, /*abid=*/0);
  }

  // Sum each even/odd partial pair back into the collapsed logical form.
  Value result = collapseAccumulator(builder, loc, smfmacAcc);
  results.push_back(result);
  return success();
}
1588+
13971589// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
13981590// type.
13991591LogicalResult VirtualMMAAttr::buildUnderlyingOperations (
@@ -1450,6 +1642,24 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
14501642 results.push_back (acc);
14511643 return success ();
14521644 }
1645+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
1646+ if (getColMajor ()) {
1647+ return failure ();
1648+ }
1649+ VDMFMAConfig config{
1650+ /* m=*/ 16 ,
1651+ /* n=*/ 16 ,
1652+ /* nativeK=*/ 32 ,
1653+ /* unrollFactor=*/ 2 ,
1654+ /* sparseIndexVectorType=*/
1655+ VectorType::get ({4 }, builder.getIntegerType (8 )),
1656+ /* evenSparseIndex=*/ 0x44 ,
1657+ /* oddSparseIndex=*/ 0xEE ,
1658+ /* aSliceWidth=*/ 4 ,
1659+ /* bInterleaveIndices=*/
1660+ {{0 , 1 , 8 , 9 , 2 , 3 , 10 , 11 }, {4 , 5 , 12 , 13 , 6 , 7 , 14 , 15 }}};
1661+ return buildVDMFMAOps (builder, loc, config, inputs, outputs[0 ], results);
1662+ }
14531663 }
14541664 return failure ();
14551665}
@@ -1459,7 +1669,8 @@ int64_t VirtualMMAAttr::getBlockSize() const {
14591669 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
14601670 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
14611671 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1462- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1672+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1673+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
14631674 return 1 ;
14641675 }
14651676 }
@@ -1525,6 +1736,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
15251736 return {/* outer=*/ {4 , 1 }, /* thread=*/ {2 , 32 }, /* tstrides=*/ {32 , 1 },
15261737 /* element=*/ {4 , 1 }};
15271738 }
1739+ case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1740+ switch (operandIndex) {
1741+ case kMMAOperandLhs :
1742+ return {/* outer=*/ {1 , 1 }, /* thread=*/ {8 , 4 }, /* tstrides=*/ {2 , 16 },
1743+ /* element=*/ {1 , 16 }};
1744+ case kMMAOperandRhs :
1745+ return {/* outer=*/ {1 , 1 }, /* thread=*/ {4 , 16 }, /* tstrides=*/ {16 , 1 },
1746+ /* element=*/ {16 , 1 }};
1747+ case kMMAOperandAcc :
1748+ return {/* outer=*/ {1 , 1 }, /* thread=*/ {4 , 16 }, /* tstrides=*/ {16 , 1 },
1749+ /* element=*/ {2 , 1 }};
1750+ }
15281751 }
15291752 assert (false && " unhandled virtual mma layout type." );
15301753 return {};
0 commit comments