Skip to content

Commit ab5cfdf

Browse files
Bump LLVM to llvm-project@d48c8411dfe6 (#24267)
Carries fixes for upstream LLVM commit ef739b97b108 ([AMDGPU] Correct gfx950 smfmac sparse index verifier, llvm/llvm-project#193541), which tightened verification of amdgpu.sparse_mfma sparse index types based on m/k/element-type:
- gfx942 16-bit (k=32): vector<4xi8>
- gfx950 16-bit / gfx942 8-bit (k=64): vector<2xi16>
- gfx950 8-bit (k=128): i32 (hardware ignores CBSZ/ABID)

The two CDNA4 (gfx950) x1 VDMFMA lowerings were using the wrong types:
- VDMFMA_F32_8x16x64x1_{F16,BF16}: was vector<4xi8>, now vector<2xi16> with selector bits 0x4444/0xEEEE (4 groups per i16 element).
- VDMFMA_{I32,F32}_8x16x128x1_*: was vector<2xi16>, now i32 scalar with all 8 groups packed: 0x44444444/0xEEEEEEEE.

Also generalizes VDMFMAConfig::sparseIndexType from VectorType to Type and updates createConstSparseIndex to handle both the vector and scalar i32 cases.

Carrying local patch for stablehlo because of llvm-project@1823355d06b8:
- iree-org/stablehlo@fb869da

---------

Co-authored-by: Claude Sonnet 4 (1M context) <noreply@anthropic.com>
1 parent 3490788 commit ab5cfdf

2 files changed

Lines changed: 33 additions & 29 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 32 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1798,27 +1798,32 @@ static Value createLaneParityPredicate(OpBuilder &builder, Location loc) {
17981798
zero);
17991799
}
18001800

1801-
// Creates a constant sparse index vector for SMFMAC operations.
1801+
// Creates a constant sparse index for SMFMAC operations.
18021802
//
18031803
// The sparse index encodes which 2 positions out of each group of 4
18041804
// K-elements are selected for 2:4 structured sparsity. Each 4-bit
18051805
// field within selectorBits selects positions for one K-group:
18061806
// 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
18071807
//
1808-
// For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1809-
// For 8-bit source data (i8/f8*): vector<2xi16>, 4 groups per i16.
1808+
// gfx942 16-bit (k=32): vector<4xi8>, 2 groups per i8.
1809+
// gfx950 16-bit / gfx942 8-bit (k=64): vector<2xi16>, 4 groups per i16.
1810+
// gfx950 8-bit (k=128): i32 scalar, 8 groups packed into 32 bits.
18101811
//
1811-
// Only the first element carries active selector bits; remaining
1812-
// elements are padding zeros.
1812+
// For vector types, only the first element carries active selector bits;
1813+
// remaining elements are padding zeros.
18131814
static Value createConstSparseIndex(OpBuilder &builder, Location loc,
1814-
VectorType sparseIndexVectorType,
1815-
int64_t selectorBits) {
1816-
Type elemTy = sparseIndexVectorType.getElementType();
1817-
Value zero = arith::ConstantOp::create(
1818-
builder, loc, builder.getZeroAttr(sparseIndexVectorType));
1819-
Value selector = arith::ConstantOp::create(
1820-
builder, loc, builder.getIntegerAttr(elemTy, selectorBits));
1821-
return vector::InsertOp::create(builder, loc, selector, zero, 0);
1815+
Type sparseIndexType,
1816+
uint32_t selectorBits) {
1817+
if (auto vecTy = dyn_cast<VectorType>(sparseIndexType)) {
1818+
Type elemTy = vecTy.getElementType();
1819+
Value zero =
1820+
arith::ConstantOp::create(builder, loc, builder.getZeroAttr(vecTy));
1821+
Value selector = arith::ConstantOp::create(
1822+
builder, loc, builder.getIntegerAttr(elemTy, selectorBits));
1823+
return vector::InsertOp::create(builder, loc, selector, zero, 0);
1824+
}
1825+
return arith::ConstantOp::create(
1826+
builder, loc, builder.getIntegerAttr(sparseIndexType, selectorBits));
18221827
}
18231828

18241829
// Returns the number of native intrinsics chained along K per virtual
@@ -1869,9 +1874,9 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
18691874
struct VDMFMAConfig {
18701875
int64_t m, n, nativeK;
18711876
int64_t unrollFactor;
1872-
VectorType sparseIndexVectorType;
1873-
int64_t evenSparseIndex;
1874-
int64_t oddSparseIndex;
1877+
Type sparseIndexType;
1878+
uint32_t evenSparseIndex;
1879+
uint32_t oddSparseIndex;
18751880
int64_t aSliceWidth;
18761881
};
18771882

@@ -1987,9 +1992,9 @@ static LogicalResult buildVDMFMAOps(OpBuilder &builder, Location loc,
19871992

19881993
Value sparseIndex = arith::SelectOp::create(
19891994
builder, loc, isOddLane,
1990-
createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
1995+
createConstSparseIndex(builder, loc, config.sparseIndexType,
19911996
config.oddSparseIndex),
1992-
createConstSparseIndex(builder, loc, config.sparseIndexVectorType,
1997+
createConstSparseIndex(builder, loc, config.sparseIndexType,
19931998
config.evenSparseIndex));
19941999

19952000
Value lhs = inputs[0];
@@ -2104,7 +2109,7 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21042109
/*n=*/16,
21052110
/*nativeK=*/32,
21062111
/*unrollFactor=*/getIntrinsicsK(),
2107-
/*sparseIndexVectorType=*/
2112+
/*sparseIndexType=*/
21082113
VectorType::get({4}, builder.getIntegerType(8)),
21092114
/*evenSparseIndex=*/0x44,
21102115
/*oddSparseIndex=*/0xEE,
@@ -2123,7 +2128,7 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21232128
/*n=*/16,
21242129
/*nativeK=*/64,
21252130
/*unrollFactor=*/getIntrinsicsK(),
2126-
/*sparseIndexVectorType=*/
2131+
/*sparseIndexType=*/
21272132
VectorType::get({2}, builder.getIntegerType(16)),
21282133
/*evenSparseIndex=*/0x4444,
21292134
/*oddSparseIndex=*/0xEEEE,
@@ -2140,10 +2145,10 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21402145
/*n=*/16,
21412146
/*nativeK=*/64,
21422147
/*unrollFactor=*/getIntrinsicsK(),
2143-
/*sparseIndexVectorType=*/
2144-
VectorType::get({4}, builder.getIntegerType(8)),
2145-
/*evenSparseIndex=*/0x44,
2146-
/*oddSparseIndex=*/0xEE,
2148+
/*sparseIndexType=*/
2149+
VectorType::get({2}, builder.getIntegerType(16)),
2150+
/*evenSparseIndex=*/0x4444,
2151+
/*oddSparseIndex=*/0xEEEE,
21472152
/*aSliceWidth=*/8};
21482153
return buildVDMFMAOps(builder, loc, config, inputs, outputs[0], results);
21492154
}
@@ -2160,10 +2165,9 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21602165
/*n=*/16,
21612166
/*nativeK=*/128,
21622167
/*unrollFactor=*/getIntrinsicsK(),
2163-
/*sparseIndexVectorType=*/
2164-
VectorType::get({2}, builder.getIntegerType(16)),
2165-
/*evenSparseIndex=*/0x4444,
2166-
/*oddSparseIndex=*/0xEEEE,
2168+
/*sparseIndexType=*/builder.getI32Type(),
2169+
/*evenSparseIndex=*/0x44444444,
2170+
/*oddSparseIndex=*/0xEEEEEEEE,
21672171
/*aSliceWidth=*/16};
21682172
return buildVDMFMAOps(builder, loc, config, inputs, outputs[0], results);
21692173
}

third_party/llvm-project

Submodule llvm-project updated 1165 files

0 commit comments

Comments
 (0)