@@ -1798,27 +1798,32 @@ static Value createLaneParityPredicate(OpBuilder &builder, Location loc) {
17981798 zero);
17991799}
18001800
1801- // Creates a constant sparse index vector for SMFMAC operations.
1801+ // Creates a constant sparse index for SMFMAC operations.
18021802//
18031803// The sparse index encodes which 2 positions out of each group of 4
18041804// K-elements are selected for 2:4 structured sparsity. Each 4-bit
18051805// field within selectorBits selects positions for one K-group:
18061806// 0x4 (0100b) -> positions {0,1}; 0xE (1110b) -> positions {2,3}.
18071807//
1808- // For 16-bit source data (f16/bf16): vector<4xi8>, 2 groups per i8.
1809- // For 8-bit source data (i8/f8*): vector<2xi16>, 4 groups per i16.
1808+ // gfx942 16-bit (k=32): vector<4xi8>, 2 groups per i8.
1809+ // gfx950 16-bit / gfx942 8-bit (k=64): vector<2xi16>, 4 groups per i16.
1810+ // gfx950 8-bit (k=128): i32 scalar, 8 groups packed into 32 bits.
18101811//
1811- // Only the first element carries active selector bits; remaining
1812- // elements are padding zeros.
1812+ // For vector types, only the first element carries active selector bits;
1813+ // remaining elements are padding zeros.
18131814static Value createConstSparseIndex (OpBuilder &builder, Location loc,
1814- VectorType sparseIndexVectorType,
1815- int64_t selectorBits) {
1816- Type elemTy = sparseIndexVectorType.getElementType ();
1817- Value zero = arith::ConstantOp::create (
1818- builder, loc, builder.getZeroAttr (sparseIndexVectorType));
1819- Value selector = arith::ConstantOp::create (
1820- builder, loc, builder.getIntegerAttr (elemTy, selectorBits));
1821- return vector::InsertOp::create (builder, loc, selector, zero, 0 );
1815+ Type sparseIndexType,
1816+ uint32_t selectorBits) {
1817+ if (auto vecTy = dyn_cast<VectorType>(sparseIndexType)) {
1818+ Type elemTy = vecTy.getElementType ();
1819+ Value zero =
1820+ arith::ConstantOp::create (builder, loc, builder.getZeroAttr (vecTy));
1821+ Value selector = arith::ConstantOp::create (
1822+ builder, loc, builder.getIntegerAttr (elemTy, selectorBits));
1823+ return vector::InsertOp::create (builder, loc, selector, zero, 0 );
1824+ }
1825+ return arith::ConstantOp::create (
1826+ builder, loc, builder.getIntegerAttr (sparseIndexType, selectorBits));
18221827}
18231828
18241829// Returns the number of native intrinsics chained along K per virtual
@@ -1869,9 +1874,9 @@ int64_t VirtualMMAAttr::getIntrinsicsK() const {
18691874struct VDMFMAConfig {
18701875 int64_t m, n, nativeK;
18711876 int64_t unrollFactor;
1872- VectorType sparseIndexVectorType ;
1873- int64_t evenSparseIndex;
1874- int64_t oddSparseIndex;
1877+ Type sparseIndexType ;
1878+ uint32_t evenSparseIndex;
1879+ uint32_t oddSparseIndex;
18751880 int64_t aSliceWidth;
18761881};
18771882
@@ -1987,9 +1992,9 @@ static LogicalResult buildVDMFMAOps(OpBuilder &builder, Location loc,
19871992
19881993 Value sparseIndex = arith::SelectOp::create (
19891994 builder, loc, isOddLane,
1990- createConstSparseIndex (builder, loc, config.sparseIndexVectorType ,
1995+ createConstSparseIndex (builder, loc, config.sparseIndexType ,
19911996 config.oddSparseIndex ),
1992- createConstSparseIndex (builder, loc, config.sparseIndexVectorType ,
1997+ createConstSparseIndex (builder, loc, config.sparseIndexType ,
19931998 config.evenSparseIndex ));
19941999
19952000 Value lhs = inputs[0 ];
@@ -2104,7 +2109,7 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21042109 /* n=*/ 16 ,
21052110 /* nativeK=*/ 32 ,
21062111 /* unrollFactor=*/ getIntrinsicsK (),
2107- /* sparseIndexVectorType =*/
2112+ /* sparseIndexType =*/
21082113 VectorType::get ({4 }, builder.getIntegerType (8 )),
21092114 /* evenSparseIndex=*/ 0x44 ,
21102115 /* oddSparseIndex=*/ 0xEE ,
@@ -2123,7 +2128,7 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21232128 /* n=*/ 16 ,
21242129 /* nativeK=*/ 64 ,
21252130 /* unrollFactor=*/ getIntrinsicsK (),
2126- /* sparseIndexVectorType =*/
2131+ /* sparseIndexType =*/
21272132 VectorType::get ({2 }, builder.getIntegerType (16 )),
21282133 /* evenSparseIndex=*/ 0x4444 ,
21292134 /* oddSparseIndex=*/ 0xEEEE ,
@@ -2140,10 +2145,10 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21402145 /* n=*/ 16 ,
21412146 /* nativeK=*/ 64 ,
21422147 /* unrollFactor=*/ getIntrinsicsK (),
2143- /* sparseIndexVectorType =*/
2144- VectorType::get ({4 }, builder.getIntegerType (8 )),
2145- /* evenSparseIndex=*/ 0x44 ,
2146- /* oddSparseIndex=*/ 0xEE ,
2148+ /* sparseIndexType =*/
2149+ VectorType::get ({2 }, builder.getIntegerType (16 )),
2150+ /* evenSparseIndex=*/ 0x4444 ,
2151+ /* oddSparseIndex=*/ 0xEEEE ,
21472152 /* aSliceWidth=*/ 8 };
21482153 return buildVDMFMAOps (builder, loc, config, inputs, outputs[0 ], results);
21492154 }
@@ -2160,10 +2165,9 @@ LogicalResult VirtualMMAAttr::buildUnderlyingOperations(
21602165 /* n=*/ 16 ,
21612166 /* nativeK=*/ 128 ,
21622167 /* unrollFactor=*/ getIntrinsicsK (),
2163- /* sparseIndexVectorType=*/
2164- VectorType::get ({2 }, builder.getIntegerType (16 )),
2165- /* evenSparseIndex=*/ 0x4444 ,
2166- /* oddSparseIndex=*/ 0xEEEE ,
2168+ /* sparseIndexType=*/ builder.getI32Type (),
2169+ /* evenSparseIndex=*/ 0x44444444 ,
2170+ /* oddSparseIndex=*/ 0xEEEEEEEE ,
21672171 /* aSliceWidth=*/ 16 };
21682172 return buildVDMFMAOps (builder, loc, config, inputs, outputs[0 ], results);
21692173 }
0 commit comments