Skip to content

Commit 95b6efc

Browse files
committed
f16
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
1 parent ea5cc11 commit 95b6efc

4 files changed

Lines changed: 292 additions & 10 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 233 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
8+
#include <cstdint>
89

910
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1011
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
@@ -34,6 +35,7 @@
3435
#include "mlir/Dialect/Utils/StaticValueUtils.h"
3536
#include "mlir/Dialect/Vector/IR/VectorOps.h"
3637
#include "mlir/IR/Attributes.h"
38+
#include "mlir/IR/Builders.h"
3739
#include "mlir/IR/BuiltinAttributes.h"
3840
#include "mlir/IR/BuiltinTypes.h"
3941
#include "mlir/IR/OpDefinition.h"
@@ -1269,6 +1271,10 @@ getMNKShape(VirtualMMAIntrinsic type) {
12691271
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12701272
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12711273
return {32, 32, 16};
1274+
// Sparse trick VSMFMAs for skinny GEMMs: semantically 8x16xK.
1275+
// TODO(#XXXX): Add I8 VDMFMA variant (VDMFMA_I32_8x16x128_I8).
1276+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1277+
return {8, 16, 64};
12721278
}
12731279
assert(false && "unhandled virtual mma layout type.");
12741280
return {};
@@ -1285,12 +1291,13 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
12851291
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
12861292
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12871293
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
1288-
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
1289-
// along the k dimension.
12901294
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
12911295
return {f16, f16, f32};
12921296
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12931297
return {f16, f16, f32};
1298+
// Sparse trick VSMFMAs for skinny GEMMs.
1299+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1300+
return {f16, f16, f32};
12941301
}
12951302
assert(false && "unhandled virtual mma layout type.");
12961303
return {};
@@ -1326,21 +1333,43 @@ void VirtualMMAAttr::getUndistributedTileTypes(
13261333
VectorType::get({o.mSize, o.nSize}, o.cType)});
13271334
}
13281335

1336+
// Returns the number of elements held per lane for a given operand layout,
1337+
// accounting for broadcastFactor when threadProduct < subgroupSize.
1338+
static int64_t getPerLaneElements(MMASingleSubgroupLayout layout,
1339+
int64_t subgroupSize) {
1340+
int64_t threadProduct = llvm::product_of(layout.thread);
1341+
assert(subgroupSize % threadProduct == 0 &&
1342+
"subgroup size must be a multiple of thread product");
1343+
int64_t broadcastFactor = subgroupSize / threadProduct;
1344+
int64_t totalElements =
1345+
llvm::product_of(layout.element) * llvm::product_of(layout.outer);
1346+
assert(totalElements % broadcastFactor == 0 &&
1347+
"total elements must be divisible by broadcast factor");
1348+
return totalElements / broadcastFactor;
1349+
}
1350+
13291351
void VirtualMMAAttr::getDistributedTileTypes(
13301352
SmallVectorImpl<VectorType> &result) const {
13311353
MLIRContext *context = getContext();
13321354
VirtualMMAIntrinsic intrinsic = getIntrinsic();
1333-
result.assign({getThreadVectorType(context, intrinsic, kMMAOperandLhs),
1334-
getThreadVectorType(context, intrinsic, kMMAOperandRhs),
1335-
getThreadVectorType(context, intrinsic, kMMAOperandAcc)});
1355+
int64_t subgroupSize = getSubgroupSize();
1356+
OpaqueMmaLayout o = getOpaqueMMALayout(context, intrinsic);
1357+
auto lhs = getSingleSubgroupLayout(intrinsic, kMMAOperandLhs);
1358+
auto rhs = getSingleSubgroupLayout(intrinsic, kMMAOperandRhs);
1359+
auto acc = getSingleSubgroupLayout(intrinsic, kMMAOperandAcc);
1360+
result.assign(
1361+
{VectorType::get({getPerLaneElements(lhs, subgroupSize)}, o.aType),
1362+
VectorType::get({getPerLaneElements(rhs, subgroupSize)}, o.bType),
1363+
VectorType::get({getPerLaneElements(acc, subgroupSize)}, o.cType)});
13361364
}
13371365

13381366
int64_t VirtualMMAAttr::getSubgroupSize() const {
13391367
switch (getIntrinsic()) {
13401368
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13411369
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
13421370
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1343-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1371+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1372+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
13441373
return 64;
13451374
}
13461375
}
@@ -1366,23 +1395,67 @@ LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides(
13661395
MMASingleSubgroupLayout subgroupLayout =
13671396
getSingleSubgroupLayout(getIntrinsic(), operandIndex,
13681397
operandIndex == kMMAOperandAcc && getColMajor());
1398+
1399+
// Compute broadcast factor: when thread product < subgroup size, multiple
1400+
// physical lanes share a logical thread position. broadcastFactor tells
1401+
// populateCanonicalOffsetsSizesAndStrides to split the element dimension
1402+
// so each physical lane gets a unique slice.
1403+
int64_t threadProduct = llvm::product_of(subgroupLayout.thread);
1404+
assert(getSubgroupSize() % threadProduct == 0 &&
1405+
"subgroup size must be a multiple of thread product");
1406+
int64_t broadcastFactor = getSubgroupSize() / threadProduct;
1407+
13691408
SmallVector<OpFoldResult> canonicalOffsets;
13701409
SmallVector<OpFoldResult> canonicalSizes;
13711410
if (failed(populateCanonicalOffsetsSizesAndStrides(
13721411
builder, loc, laneId, permutation, subgroupLayout, canonicalOffsets,
1373-
canonicalSizes, strides))) {
1412+
canonicalSizes, strides, broadcastFactor))) {
13741413
return failure();
13751414
}
13761415
offsets.append(canonicalOffsets);
13771416
sizes.append(canonicalSizes);
1378-
13791417
return success();
13801418
}
13811419

1420+
// Returns true on odd lanes and false on even lanes.
1421+
static Value createLaneParityPredicate(OpBuilder &builder, Location loc) {
1422+
Value laneId = gpu::LaneIdOp::create(builder, loc, /*upper_bound=*/nullptr);
1423+
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
1424+
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
1425+
Value lowBit = arith::AndIOp::create(builder, loc, laneId, one);
1426+
return arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, lowBit,
1427+
zero);
1428+
}
1429+
1430+
// Creates a constant sparse index vector for SMFMAC operations.
1431+
//
1432+
// The sparse index encodes which 2 positions out of each group of 4
1433+
// K-elements are selected for 2:4 structured sparsity. Each 4-bit
1434+
// field within selectorBits selects positions for one K-group:
1435+
// 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
1436+
//
1437+
// For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1438+
// For 8-bit source data (i8): vector<2xi16>, 4 groups per i16.
1439+
//
1440+
// Only the first element carries active selector bits; remaining
1441+
// elements are padding zeros.
1442+
static Value createConstSparseIndex(OpBuilder &builder, Location loc,
1443+
VectorType sparseIndexVectorType,
1444+
int64_t selectorBits) {
1445+
Type elemTy = sparseIndexVectorType.getElementType();
1446+
Value zero = arith::ConstantOp::create(
1447+
builder, loc, builder.getZeroAttr(sparseIndexVectorType));
1448+
Value selector = arith::ConstantOp::create(
1449+
builder, loc, builder.getIntegerAttr(elemTy, selectorBits));
1450+
return vector::InsertOp::create(builder, loc, selector, zero, 0);
1451+
}
1452+
1453+
// Returns the K unroll factor: virtual_K / native_K.
13821454
int64_t VirtualMMAAttr::getIntrinsicsK() const {
13831455
switch (getIntrinsic()) {
13841456
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1385-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1457+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1458+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
13861459
return 2;
13871460
}
13881461
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1394,6 +1467,125 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
13941467
return 0;
13951468
}
13961469

1470+
// Expand collapsed ACC [c0, c1] -> [c0, 0, c1, 0].
1471+
static Value expandAccumulator(OpBuilder &builder, Location loc, Value acc) {
1472+
auto accType = cast<VectorType>(acc.getType());
1473+
Value zero =
1474+
arith::ConstantOp::create(builder, loc, builder.getZeroAttr(accType));
1475+
1476+
return vector::ShuffleOp::create(builder, loc, acc, zero,
1477+
ArrayRef<int64_t>{0, 2, 1, 3});
1478+
}
1479+
1480+
// Collapse expanded ACC [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1481+
static Value collapseAccumulator(OpBuilder &builder, Location loc, Value acc) {
1482+
auto accType = cast<VectorType>(acc.getType());
1483+
Type elementType = accType.getElementType();
1484+
1485+
Value evens = vector::ShuffleOp::create(builder, loc, acc, acc,
1486+
ArrayRef<int64_t>{0, 2});
1487+
Value odds = vector::ShuffleOp::create(builder, loc, acc, acc,
1488+
ArrayRef<int64_t>{1, 3});
1489+
1490+
if (isa<FloatType>(elementType)) {
1491+
return arith::AddFOp::create(builder, loc, evens, odds);
1492+
}
1493+
return arith::AddIOp::create(builder, loc, evens, odds);
1494+
}
1495+
1496+
// Struct with consolidated info necessary for sparse trick invocation as a
1497+
// VDMFMA.
1498+
struct VDMFMAConfig {
1499+
int64_t m, n, nativeK;
1500+
int64_t unrollFactor;
1501+
VectorType sparseIndexVectorType;
1502+
int64_t evenSparseIndex;
1503+
int64_t oddSparseIndex;
1504+
int64_t aSliceWidth; // Elements per A slice per SMFMAC call.
1505+
SmallVector<SmallVector<int64_t, 8>, 2> bInterleaveIndices;
1506+
};
1507+
1508+
// Virtual Dense MFMA (VDMFMA) ops represent invocations of the sparse trick
1509+
// targeting skinny GEMMs (M=8).
1510+
//
1511+
// === The sparse trick ===
1512+
//
1513+
// Sparse MFMA (V_SMFMAC) instructions perform MMA on an imbalanced pair of
1514+
// operands: a 4:2 structured-sparse matrix A and a dense matrix B. The
1515+
// instruction also takes a sparsity index that encodes which 2 of every 4
1516+
// elements along K are non-zero within the sparse matrix A. The trick exploits
1517+
// this by pairing even/odd lanes to jointly describe a full dense row.
1518+
//
1519+
// The lane-pairing layout maps each of the 8 logical M-rows to a pair of
1520+
// adjacent physical rows (row 2i and 2i+1 for logical row i). Within each pair,
1521+
// the even lane supplies positions {0,1} from each K-group of 4 and the odd
1522+
// lane supplies positions {2,3}. The hardware interprets each physical row as
1523+
// having 2:4 structured sparsity and computes a partial dot product over only
1524+
// its non-zero elements. Summing the two physical rows' results reconstructs
1525+
// the full dense dot product for the logical row. This yields a semantic M=8
1526+
// matmul from a physical 16x16 instruction.
1527+
//
1528+
// Each lane loads unique A data via broadcastFactor distribution. Even lanes
1529+
// receive K[0:aSliceWidth*unrollFactor/2], odd lanes receive
1530+
// K[aSliceWidth*unrollFactor/2:aSliceWidth*unrollFactor]. A is sliced
1531+
// sequentially into per-SMFMAC chunks of aSliceWidth elements.
1532+
//
1533+
// === Accumulator expand/collapse ===
1534+
//
1535+
// Because the sparse trick maps two hardware rows to one logical row, adjacent
1536+
// register pairs in the output hold partial sums for the same dense row.
1537+
// Collapsing sums each pair (v0+v1, v2+v3) to produce the 2-element semantic
1538+
// result: one complete value per logical row.
1539+
//
1540+
// The layout and distribution infrastructure operate on the collapsed vector
1541+
// (e.g., vector<2xf32>). buildVDMFMAOps handles the translation: it expands
1542+
// a collapsed accumulator into the 4-element physical form before the smfmac
1543+
// chain, then collapses the result back afterward.
1544+
static LogicalResult buildVDMFMAOps(OpBuilder &builder, Location loc,
1545+
const VDMFMAConfig &config,
1546+
ValueRange inputs, Value acc,
1547+
SmallVectorImpl<Value> &results) {
1548+
Value smfmacAcc = expandAccumulator(builder, loc, acc);
1549+
VectorType expandedAccType = cast<VectorType>(smfmacAcc.getType());
1550+
1551+
Value isOddLane = createLaneParityPredicate(builder, loc);
1552+
1553+
Value sparseIndex = arith::SelectOp::create(
1554+
builder, loc, isOddLane,
1555+
createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
1556+
config.oddSparseIndex),
1557+
createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
1558+
config.evenSparseIndex));
1559+
1560+
Value lhs = inputs[0];
1561+
Value rhs = inputs[1];
1562+
1563+
assert(static_cast<int64_t>(config.bInterleaveIndices.size()) ==
1564+
config.unrollFactor &&
1565+
"must provide B interleave indices for each unroll iteration");
1566+
1567+
for (int64_t i = 0; i < config.unrollFactor; ++i) {
1568+
int64_t aOffset = config.aSliceWidth * i;
1569+
Value aSlice = vector::ExtractStridedSliceOp::create(
1570+
builder, loc, lhs, /*offsets=*/ArrayRef<int64_t>{aOffset},
1571+
/*sizes=*/ArrayRef<int64_t>{config.aSliceWidth},
1572+
/*strides=*/ArrayRef<int64_t>{1});
1573+
1574+
Value bSlice = vector::ShuffleOp::create(builder, loc, rhs, rhs,
1575+
config.bInterleaveIndices[i]);
1576+
1577+
smfmacAcc = amdgpu::SparseMFMAOp::create(
1578+
builder, loc, expandedAccType,
1579+
/*m=*/config.m, /*n=*/config.n, /*k=*/config.nativeK,
1580+
/*sourceA=*/aSlice, /*sourceB=*/bSlice, /*destC=*/smfmacAcc,
1581+
/*sparseIdx=*/sparseIndex, /*cbsz=*/0, /*abid=*/0);
1582+
}
1583+
1584+
Value result = collapseAccumulator(builder, loc, smfmacAcc);
1585+
results.push_back(result);
1586+
return success();
1587+
}
1588+
13971589
// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
13981590
// type.
13991591
LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
@@ -1450,6 +1642,24 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
14501642
results.push_back(acc);
14511643
return success();
14521644
}
1645+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
1646+
if (getColMajor()) {
1647+
return failure();
1648+
}
1649+
VDMFMAConfig config{
1650+
/*m=*/16,
1651+
/*n=*/16,
1652+
/*nativeK=*/32,
1653+
/*unrollFactor=*/2,
1654+
/*sparseIndexVectorType=*/
1655+
VectorType::get({4}, builder.getIntegerType(8)),
1656+
/*evenSparseIndex=*/0x44,
1657+
/*oddSparseIndex=*/0xEE,
1658+
/*aSliceWidth=*/4,
1659+
/*bInterleaveIndices=*/
1660+
{{0, 1, 8, 9, 2, 3, 10, 11}, {4, 5, 12, 13, 6, 7, 14, 15}}};
1661+
return buildVDMFMAOps(builder, loc, config, inputs, outputs[0], results);
1662+
}
14531663
}
14541664
return failure();
14551665
}
@@ -1459,7 +1669,8 @@ int64_t VirtualMMAAttr::getBlockSize() const {
14591669
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
14601670
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
14611671
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1462-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1672+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1673+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16: {
14631674
return 1;
14641675
}
14651676
}
@@ -1525,6 +1736,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
15251736
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
15261737
/*element=*/{4, 1}};
15271738
}
1739+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1740+
switch (operandIndex) {
1741+
case kMMAOperandLhs:
1742+
return {/*outer=*/{1, 1}, /*thread=*/{8, 4}, /*tstrides=*/{2, 16},
1743+
/*element=*/{1, 16}};
1744+
case kMMAOperandRhs:
1745+
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
1746+
/*element=*/{16, 1}};
1747+
case kMMAOperandAcc:
1748+
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
1749+
/*element=*/{2, 1}};
1750+
}
15281751
}
15291752
assert(false && "unhandled virtual mma layout type.");
15301753
return {};

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,13 +373,16 @@ def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0>;
373373
def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 1>;
374374
def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 2>;
375375
def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3FNUZ", 3>;
376+
def VDMFMA_F32_8x16x64_F16 : I32EnumAttrCase<"VDMFMA_F32_8x16x64_F16", 4>;
377+
// TODO(#XXXX): Add I8 VDMFMA variant (VDMFMA_I32_8x16x128_I8).
376378

377379
def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32EnumAttr<"VirtualMMAIntrinsic",
378380
"Descriptor for different Virtual MMA intrinsics", [
379381
VMFMA_F32_16x16x32_F16,
380382
VMFMA_F32_32x32x16_F16,
381383
VMFMA_F32_16x16x32_F8E4M3FNUZ,
382384
VMFMA_F32_32x32x16_F8E4M3FNUZ,
385+
VDMFMA_F32_8x16x64_F16,
383386
]>;
384387

385388
// Enum for scaled mma intrinsic, loosely matching the MMAIntrinsic enum above

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ module {
3636
// CHECK-LABEL: func @test_col_major_vmfma_f16_16x16x32_f32
3737
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VMFMA_F32_16x16x32_F16, col_major = true>
3838

39+
module {
40+
func.func @test_vdmfma_f16_8x16x64() attributes {
41+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x64_F16>} {
42+
return
43+
}
44+
}
45+
// CHECK-LABEL: func @test_vdmfma_f16_8x16x64
46+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x64_F16>
47+
3948
module {
4049
func.func @test_WMMAR3_f16_16x16x16_f32() attributes {
4150
mma_types = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>} {

0 commit comments

Comments (0)