@@ -908,7 +908,9 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
908908 ArrayRef<int64_t > permutation, MMASingleSubgroupLayout subgroupLayout,
909909 SmallVectorImpl<OpFoldResult> &canonicalOffsets,
910910 SmallVectorImpl<OpFoldResult> &canonicalSizes,
911- SmallVectorImpl<OpFoldResult> &canonicalStrides) {
911+ SmallVectorImpl<OpFoldResult> &canonicalStrides,
912+ int64_t broadcastFactor = 1 ) {
913+ assert (broadcastFactor >= 1 && " broadcast factor must be at least 1" );
912914 SmallVector<int64_t > rankReducedShape;
913915
914916 for (auto [outer, thread, element] :
@@ -946,6 +948,44 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
946948 SmallVector<Value> hintedSplitLaneId = createTransposeLoadIndexHint (
947949 builder, loc, splitLaneId.getResults (), vtidBasis);
948950
951+ // Find the unique element dimension eligible for lane differentiation:
952+ // exactly one dimension with element[dim] > 1 and divisible by
953+ // broadcastFactor.
954+ std::optional<size_t > splitDim;
955+ if (broadcastFactor > 1 ) {
956+ for (auto [dimIdx, element] : llvm::enumerate (subgroupLayout.element )) {
957+ if (element > 1 && element % broadcastFactor == 0 ) {
958+ if (splitDim) {
959+ splitDim = std::nullopt ;
960+ break ;
961+ }
962+ splitDim = dimIdx;
963+ }
964+ }
965+ }
966+
967+ // Build broadcastIndex from unused delinearize results. Unused results
968+ // (not mapped by dimToVtid) represent broadcast lanes that see duplicate
969+ // data. When a splitDim exists, we use the single unused component as an
970+ // index in [0, broadcastFactor) to differentiate those lanes.
971+ Value broadcastIndex;
972+ if (splitDim) {
973+ llvm::SmallDenseSet<size_t , 4 > usedResults (dimToVtid.begin (),
974+ dimToVtid.end ());
975+ // Find the single unused delinearize result representing broadcast lanes.
976+ std::optional<size_t > unusedResultIdx;
977+ for (size_t i = 1 , e = vtidBasis.size (); i <= e; ++i) {
978+ if (!usedResults.contains (i)) {
979+ assert (!unusedResultIdx && " expected exactly one unused basis entry" );
980+ unusedResultIdx = i;
981+ }
982+ }
983+ assert (unusedResultIdx && " expected one unused basis entry for broadcast" );
984+ assert (vtidBasis[*unusedResultIdx - 1 ] == broadcastFactor &&
985+ " unused basis size must equal broadcast factor" );
986+ broadcastIndex = splitLaneId.getResult (*unusedResultIdx);
987+ }
988+
949989 // Each thread grabs `element` contiguous data, so the vtid needs to be
950990 // multiplied by `element` to get the next bunch of data.
951991 // vtid: virtual thread id
@@ -955,11 +995,23 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
955995 // Instead of computing those maps, we use one big `delinearize` expression
956996 // in order to prevent unwanted "simplifications" on affine maps that
957997 // worsen the generated code quality.
958- for (auto [splitResultIdx, element] :
959- llvm::zip_equal (dimToVtid, subgroupLayout.element )) {
998+ //
999+ // When broadcastFactor > 1, the splitDim offset also incorporates
1000+ // broadcastIndex so each broadcast lane gets a disjoint slice instead.
1001+ for (auto [dimIdx, vtidAndElement] :
1002+ llvm::enumerate (llvm::zip_equal (dimToVtid, subgroupLayout.element ))) {
1003+ auto [splitResultIdx, element] = vtidAndElement;
9601004 Value vtid = hintedSplitLaneId[splitResultIdx];
9611005 int64_t vtidLen = vtidBasis[splitResultIdx - 1 ];
962- if (element != 1 ) {
1006+
1007+ if (splitDim && dimIdx == *splitDim) {
1008+ // offset = vtid * element + broadcastIndex * perLaneElement.
1009+ int64_t perLaneElement = element / broadcastFactor;
1010+ vtid = affine::AffineLinearizeIndexOp::create (
1011+ builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1012+ ArrayRef<int64_t >{vtidLen, broadcastFactor, perLaneElement},
1013+ /* disjoint=*/ true );
1014+ } else if (element != 1 ) {
9631015 vtid = affine::AffineLinearizeIndexOp::create (
9641016 builder, loc, ValueRange{vtid, cZero},
9651017 ArrayRef<int64_t >{vtidLen, element},
@@ -968,15 +1020,20 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
9681020 vtids.push_back (vtid);
9691021 }
9701022
971- int64_t idx = 0 ;
972- for (auto [element, outer] :
973- llvm::zip_equal (subgroupLayout.element , subgroupLayout.outer )) {
1023+ int64_t vtidIdx = 0 ;
1024+ for (auto [dimIdx, elementAndOuter] : llvm::enumerate (
1025+ llvm::zip_equal (subgroupLayout.element , subgroupLayout.outer ))) {
1026+ auto [element, outer] = elementAndOuter;
1027+ int64_t perLaneElement = element;
1028+ if (splitDim && dimIdx == *splitDim) {
1029+ perLaneElement = element / broadcastFactor;
1030+ }
9741031 if (outer != 1 ) {
9751032 canonicalSizes.push_back (builder.getIndexAttr (outer));
9761033 canonicalOffsets.push_back (zero);
9771034 }
978- canonicalSizes.push_back (builder.getIndexAttr (element ));
979- canonicalOffsets.push_back (vtids[idx ++]);
1035+ canonicalSizes.push_back (builder.getIndexAttr (perLaneElement ));
1036+ canonicalOffsets.push_back (vtids[vtidIdx ++]);
9801037 }
9811038 canonicalOffsets.assign (applyPermutation (canonicalOffsets, permutation));
9821039 canonicalSizes.assign (applyPermutation (canonicalSizes, permutation));
0 commit comments