Skip to content

Commit 97dd26b

Browse files
efricclaude
and committed
nits
Applies the VDMFMA-1 changes: renames FP8/BF8 enum variants to F8E4M3FNUZ/F8E5M2FNUZ, switches expand/collapse accumulator to use vector.interleave/deinterleave, adds isVDMFMAIntrinsic helper and header declarations, and fixes getDistributedTileTypes broadcastFactor logic. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Eric Feng <Eric.Feng@amd.com>
1 parent 88891b5 commit 97dd26b

5 files changed

Lines changed: 153 additions & 111 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ OpFoldResult MMAAttr::getDistributionWorkerCount(OpBuilder &, Location,
754754
return getAsIndexOpFoldResult(getContext(), getSubgroupSize());
755755
}
756756

757-
// Get virtual intrinsics that is composed/based on queried op.
757+
// Returns virtual intrinsics that are composed from this concrete MMA op.
758758
SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
759759
switch (getIntrinsic()) {
760760
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
@@ -1269,15 +1269,15 @@ getMNKShape(VirtualMMAIntrinsic type) {
12691269
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12701270
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
12711271
return {32, 32, 16};
1272-
// Sparse trick VDMFMAs for skinny GEMMs.
1272+
// Sparse trick VDMFMAs for skinny GEMMs: semantically 8x16xK.
12731273
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
12741274
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
12751275
return {8, 16, 64};
12761276
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1277-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1278-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1279-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1280-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8:
1277+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1278+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1279+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1280+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
12811281
return {8, 16, 128};
12821282
}
12831283
assert(false && "unhandled virtual mma layout type.");
@@ -1312,13 +1312,13 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
13121312
return {bf16, bf16, f32};
13131313
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
13141314
return {i8, i8, i32};
1315-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1315+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
13161316
return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
1317-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1317+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
13181318
return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
1319-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1319+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
13201320
return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
1321-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8:
1321+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
13221322
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
13231323
}
13241324
assert(false && "unhandled virtual mma layout type.");
@@ -1374,15 +1374,29 @@ void VirtualMMAAttr::getDistributedTileTypes(
13741374
SmallVectorImpl<VectorType> &result) const {
13751375
MLIRContext *context = getContext();
13761376
VirtualMMAIntrinsic intrinsic = getIntrinsic();
1377+
// VDMFMA layouts pair adjacent lanes to emulate a wider tile, so
1378+
// threadProduct < subgroupSize (broadcastFactor > 1). We need
1379+
// getPerLaneElements to divide out the broadcast and compute the correct
1380+
// per-lane element count.
1381+
auto lhsLayout = getSingleSubgroupLayout(intrinsic, kMMAOperandLhs);
13771382
int64_t subgroupSize = getSubgroupSize();
1378-
OpaqueMmaLayout o = getOpaqueMMALayout(context, intrinsic);
1379-
auto lhs = getSingleSubgroupLayout(intrinsic, kMMAOperandLhs);
1380-
auto rhs = getSingleSubgroupLayout(intrinsic, kMMAOperandRhs);
1381-
auto acc = getSingleSubgroupLayout(intrinsic, kMMAOperandAcc);
1382-
result.assign(
1383-
{VectorType::get({getPerLaneElements(lhs, subgroupSize)}, o.aType),
1384-
VectorType::get({getPerLaneElements(rhs, subgroupSize)}, o.bType),
1385-
VectorType::get({getPerLaneElements(acc, subgroupSize)}, o.cType)});
1383+
int64_t broadcastFactor = subgroupSize / llvm::product_of(lhsLayout.thread);
1384+
if (isVDMFMAIntrinsic(intrinsic) && broadcastFactor > 1) {
1385+
OpaqueMmaLayout o = getOpaqueMMALayout(context, intrinsic);
1386+
auto rhsLayout = getSingleSubgroupLayout(intrinsic, kMMAOperandRhs);
1387+
auto accLayout = getSingleSubgroupLayout(intrinsic, kMMAOperandAcc);
1388+
result.assign(
1389+
{VectorType::get({getPerLaneElements(lhsLayout, subgroupSize)},
1390+
o.aType),
1391+
VectorType::get({getPerLaneElements(rhsLayout, subgroupSize)},
1392+
o.bType),
1393+
VectorType::get({getPerLaneElements(accLayout, subgroupSize)},
1394+
o.cType)});
1395+
} else {
1396+
result.assign({getThreadVectorType(context, intrinsic, kMMAOperandLhs),
1397+
getThreadVectorType(context, intrinsic, kMMAOperandRhs),
1398+
getThreadVectorType(context, intrinsic, kMMAOperandAcc)});
1399+
}
13861400
}
13871401

13881402
int64_t VirtualMMAAttr::getSubgroupSize() const {
@@ -1394,10 +1408,10 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
13941408
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
13951409
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
13961410
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1397-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1398-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1399-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1400-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8: {
1411+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1412+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1413+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1414+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
14011415
return 64;
14021416
}
14031417
}
@@ -1485,10 +1499,10 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
14851499
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
14861500
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
14871501
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1488-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1489-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1490-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1491-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8: {
1502+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1503+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1504+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1505+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
14921506
return 2;
14931507
}
14941508
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
@@ -1500,26 +1514,23 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
15001514
return 0;
15011515
}
15021516

1503-
// Expand collapsed ACC [c0, c1] -> [c0, 0, c1, 0].
1504-
static Value expandAccumulator(OpBuilder &builder, Location loc, Value acc) {
1517+
// Expands a collapsed 2-element ACC into the 4-element native SMFMAC form
1518+
// by interleaving with zeros: [c0, c1] -> [c0, 0, c1, 0].
1519+
Value expandAccumulator(OpBuilder &builder, Location loc, Value acc) {
15051520
auto accType = cast<VectorType>(acc.getType());
15061521
Value zero =
15071522
arith::ConstantOp::create(builder, loc, builder.getZeroAttr(accType));
1508-
1509-
return vector::ShuffleOp::create(builder, loc, acc, zero,
1510-
ArrayRef<int64_t>{0, 2, 1, 3});
1523+
return vector::InterleaveOp::create(builder, loc, acc, zero);
15111524
}
15121525

1513-
// Collapse expanded ACC [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1514-
static Value collapseAccumulator(OpBuilder &builder, Location loc, Value acc) {
1515-
auto accType = cast<VectorType>(acc.getType());
1516-
Type elementType = accType.getElementType();
1517-
1518-
Value evens = vector::ShuffleOp::create(builder, loc, acc, acc,
1519-
ArrayRef<int64_t>{0, 2});
1520-
Value odds = vector::ShuffleOp::create(builder, loc, acc, acc,
1521-
ArrayRef<int64_t>{1, 3});
1522-
1526+
// Collapses a 4-element native SMFMAC ACC back to the 2-element semantic form.
1527+
// Deinterleaves into evens [d0, d2] and odds [d1, d3], then sums pairwise:
1528+
// [d0, d1, d2, d3] -> [d0+d1, d2+d3].
1529+
Value collapseAccumulator(OpBuilder &builder, Location loc, Value acc) {
1530+
Type elementType = cast<VectorType>(acc.getType()).getElementType();
1531+
auto deinterleave = vector::DeinterleaveOp::create(builder, loc, acc);
1532+
Value evens = deinterleave.getRes1();
1533+
Value odds = deinterleave.getRes2();
15231534
if (isa<FloatType>(elementType)) {
15241535
return arith::AddFOp::create(builder, loc, evens, odds);
15251536
}
@@ -1574,6 +1585,27 @@ struct VDMFMAConfig {
15741585
// (e.g., vector<2xf32>). buildVDMFMAOps handles the translation: it expands
15751586
// a collapsed accumulator into the 4-element physical form before the smfmac
15761587
// chain, then collapses the result back afterward.
1588+
1589+
bool isVDMFMAIntrinsic(VirtualMMAIntrinsic intrinsic) {
1590+
switch (intrinsic) {
1591+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
1592+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
1593+
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1594+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1595+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1596+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1597+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
1598+
return true;
1599+
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1600+
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1601+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1602+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1603+
return false;
1604+
}
1605+
assert(false && "unhandled virtual mma intrinsic type");
1606+
return false;
1607+
}
1608+
15771609
static LogicalResult buildVDMFMAOps(OpBuilder &builder, Location loc,
15781610
const VDMFMAConfig &config,
15791611
ValueRange inputs, Value acc,
@@ -1695,10 +1727,10 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
16951727
return buildVDMFMAOps(builder, loc, config, inputs, outputs[0], results);
16961728
}
16971729
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1698-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1699-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1700-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1701-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8: {
1730+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1731+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1732+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1733+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
17021734
if (getColMajor()) {
17031735
return failure();
17041736
}
@@ -1730,10 +1762,10 @@ int64_t VirtualMMAAttr::getBlockSize() const {
17301762
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_F16:
17311763
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x64_BF16:
17321764
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1733-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1734-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1735-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1736-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8: {
1765+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1766+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1767+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1768+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ: {
17371769
return 1;
17381770
}
17391771
}
@@ -1813,10 +1845,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
18131845
/*element=*/{2, 1}};
18141846
}
18151847
case VirtualMMAIntrinsic::VDMFMA_I32_8x16x128_I8:
1816-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8:
1817-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_BF8_FP8:
1818-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8_BF8:
1819-
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_FP8:
1848+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ:
1849+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ:
1850+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ:
1851+
case VirtualMMAIntrinsic::VDMFMA_F32_8x16x128_F8E4M3FNUZ:
18201852
switch (operandIndex) {
18211853
case kMMAOperandLhs:
18221854
return {/*outer=*/{1, 1}, /*thread=*/{8, 4}, /*tstrides=*/{2, 16},

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(ScaledMMAIntrinsic intrinsic,
297297
/// attribute.
298298
StringRef getTilingLevelName(GPU::TilingLevel level);
299299

300+
//===----------------------------------------------------------------------===//
301+
// VDMFMA accumulator utilities
302+
//===----------------------------------------------------------------------===//
303+
304+
/// Returns true if the given VirtualMMAIntrinsic is a VDMFMA (virtual dense
305+
/// MFMA via sparse trick) intrinsic.
306+
bool isVDMFMAIntrinsic(VirtualMMAIntrinsic intrinsic);
307+
308+
/// Expands a collapsed 2-element ACC into the 4-element native SMFMAC form
309+
/// by interleaving with zeros: [c0, c1] -> [c0, 0, c1, 0].
310+
Value expandAccumulator(OpBuilder &builder, Location loc, Value acc);
311+
312+
/// Collapses a 4-element native SMFMAC ACC back to the 2-element semantic
313+
/// form. Deinterleaves into evens [d0, d2] and odds [d1, d3], then sums
314+
/// pairwise: [d0, d1, d2, d3] -> [d0+d1, d2+d3].
315+
Value collapseAccumulator(OpBuilder &builder, Location loc, Value acc);
316+
300317
//===----------------------------------------------------------------------===//
301318
// Implementations for operand promotion
302319
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ def VMFMA_F32_32x32x16_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_32x32x16_F8E4M3F
376376
def VDMFMA_F32_8x16x64_F16 : I32EnumAttrCase<"VDMFMA_F32_8x16x64_F16", 4>;
377377
def VDMFMA_I32_8x16x128_I8 : I32EnumAttrCase<"VDMFMA_I32_8x16x128_I8", 5>;
378378
def VDMFMA_F32_8x16x64_BF16 : I32EnumAttrCase<"VDMFMA_F32_8x16x64_BF16", 6>;
379-
def VDMFMA_F32_8x16x128_BF8 : I32EnumAttrCase<"VDMFMA_F32_8x16x128_BF8", 7>;
380-
def VDMFMA_F32_8x16x128_BF8_FP8 : I32EnumAttrCase<"VDMFMA_F32_8x16x128_BF8_FP8", 8>;
381-
def VDMFMA_F32_8x16x128_FP8_BF8 : I32EnumAttrCase<"VDMFMA_F32_8x16x128_FP8_BF8", 9>;
382-
def VDMFMA_F32_8x16x128_FP8 : I32EnumAttrCase<"VDMFMA_F32_8x16x128_FP8", 10>;
379+
def VDMFMA_F32_8x16x128_F8E5M2FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E5M2FNUZ", 7>;
380+
def VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ", 8>;
381+
def VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ", 9>;
382+
def VDMFMA_F32_8x16x128_F8E4M3FNUZ : I32EnumAttrCase<"VDMFMA_F32_8x16x128_F8E4M3FNUZ", 10>;
383383

384384
def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32EnumAttr<"VirtualMMAIntrinsic",
385385
"Descriptor for different Virtual MMA intrinsics", [
@@ -392,10 +392,10 @@ def IREEGPU_VirtualMMAIntrinsic : IREEGPU_I32EnumAttr<"VirtualMMAIntrinsic",
392392
VDMFMA_F32_8x16x64_BF16,
393393
// 8-bit VDMFMA variants.
394394
VDMFMA_I32_8x16x128_I8,
395-
VDMFMA_F32_8x16x128_BF8,
396-
VDMFMA_F32_8x16x128_BF8_FP8,
397-
VDMFMA_F32_8x16x128_FP8_BF8,
398-
VDMFMA_F32_8x16x128_FP8,
395+
VDMFMA_F32_8x16x128_F8E5M2FNUZ,
396+
VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ,
397+
VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ,
398+
VDMFMA_F32_8x16x128_F8E4M3FNUZ,
399399
]>;
400400

401401
// Enum for scaled mma intrinsic, loosely matching the MMAIntrinsic enum above

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,40 +64,40 @@ module {
6464
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x64_BF16>
6565

6666
module {
67-
func.func @test_vdmfma_bf8_8x16x128() attributes {
68-
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_BF8>} {
67+
func.func @test_vdmfma_f8E5M2FNUZ_8x16x128() attributes {
68+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ>} {
6969
return
7070
}
7171
}
72-
// CHECK-LABEL: func @test_vdmfma_bf8_8x16x128
73-
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_BF8>
72+
// CHECK-LABEL: func @test_vdmfma_f8E5M2FNUZ_8x16x128
73+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ>
7474

7575
module {
76-
func.func @test_vdmfma_bf8_fp8_8x16x128() attributes {
77-
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_BF8_FP8>} {
76+
func.func @test_vdmfma_f8E5M2FNUZ_f8E4M3FNUZ_8x16x128() attributes {
77+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ>} {
7878
return
7979
}
8080
}
81-
// CHECK-LABEL: func @test_vdmfma_bf8_fp8_8x16x128
82-
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_BF8_FP8>
81+
// CHECK-LABEL: func @test_vdmfma_f8E5M2FNUZ_f8E4M3FNUZ_8x16x128
82+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E5M2FNUZ_F8E4M3FNUZ>
8383

8484
module {
85-
func.func @test_vdmfma_fp8_bf8_8x16x128() attributes {
86-
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_FP8_BF8>} {
85+
func.func @test_vdmfma_f8E4M3FNUZ_f8E5M2FNUZ_8x16x128() attributes {
86+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ>} {
8787
return
8888
}
8989
}
90-
// CHECK-LABEL: func @test_vdmfma_fp8_bf8_8x16x128
91-
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_FP8_BF8>
90+
// CHECK-LABEL: func @test_vdmfma_f8E4M3FNUZ_f8E5M2FNUZ_8x16x128
91+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ_F8E5M2FNUZ>
9292

9393
module {
94-
func.func @test_vdmfma_fp8_8x16x128() attributes {
95-
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_FP8>} {
94+
func.func @test_vdmfma_f8E4M3FNUZ_8x16x128() attributes {
95+
mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ>} {
9696
return
9797
}
9898
}
99-
// CHECK-LABEL: func @test_vdmfma_fp8_8x16x128
100-
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_FP8>
99+
// CHECK-LABEL: func @test_vdmfma_f8E4M3FNUZ_8x16x128
100+
// CHECK-SAME: mma_types = #iree_gpu.virtual_mma_layout<VDMFMA_F32_8x16x128_F8E4M3FNUZ>
101101

102102
module {
103103
func.func @test_WMMAR3_f16_16x16x16_f32() attributes {

0 commit comments

Comments
 (0)