@@ -966,43 +966,24 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
966966
967967 // Build broadcastIndex from unused delinearize results. Unused results
968968 // (not mapped by dimToVtid) represent broadcast lanes that see duplicate
969- // data. When a splitDim exists, we linearize these unused components into
970- // a single index in [0, broadcastFactor) to differentiate those lanes.
969+ // data. When a splitDim exists, we use the single unused component as an
970+ // index in [0, broadcastFactor) to differentiate those lanes.
971971 Value broadcastIndex;
972972 if (splitDim) {
973973 llvm::SmallDenseSet<size_t , 4 > usedResults (dimToVtid.begin (),
974974 dimToVtid.end ());
975-
976- // Collect unused basis entries and compute overflow size.
977- int64_t unusedBasisProduct = 1 ;
978- SmallVector<Value> unusedLaneIdComponents;
979- SmallVector<int64_t > unusedBasisSizes;
975+ // Find the single unused delinearize result representing broadcast lanes.
976+ std::optional<size_t > unusedResultIdx;
980977 for (size_t i = 1 , e = vtidBasis.size (); i <= e; ++i) {
981978 if (!usedResults.contains (i)) {
982- unusedBasisProduct *= vtidBasis[i - 1 ];
983- unusedLaneIdComponents.push_back (splitLaneId.getResult (i));
984- unusedBasisSizes.push_back (vtidBasis[i - 1 ]);
979+ assert (!unusedResultIdx && " expected exactly one unused basis entry" );
980+ unusedResultIdx = i;
985981 }
986982 }
987- int64_t overflowSize = broadcastFactor / unusedBasisProduct;
988-
989- // Collect in most-significant-first order: overflow, then basis entries.
990- if (overflowSize > 1 ) {
991- unusedLaneIdComponents.insert (unusedLaneIdComponents.begin (),
992- splitLaneId.getResult (0 ));
993- unusedBasisSizes.insert (unusedBasisSizes.begin (), overflowSize);
994- }
995-
996- assert (llvm::product_of (unusedBasisSizes) == broadcastFactor &&
997- " unused basis sizes product must equal broadcast factor" );
998-
999- if (unusedLaneIdComponents.size () == 1 ) {
1000- broadcastIndex = unusedLaneIdComponents[0 ];
1001- } else if (unusedLaneIdComponents.size () > 1 ) {
1002- broadcastIndex = affine::AffineLinearizeIndexOp::create (
1003- builder, loc, unusedLaneIdComponents, unusedBasisSizes,
1004- /* disjoint=*/ true );
1005- }
983+ assert (unusedResultIdx && " expected one unused basis entry for broadcast" );
984+ assert (vtidBasis[*unusedResultIdx - 1 ] == broadcastFactor &&
985+ " unused basis size must equal broadcast factor" );
986+ broadcastIndex = splitLaneId.getResult (*unusedResultIdx);
1006987 }
1007988
1008989 // Each thread grabs `element` contiguous data, so the vtid needs to be
@@ -1026,17 +1007,10 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
10261007 if (splitDim && dimIdx == *splitDim) {
10271008 // offset = vtid * element + broadcastIndex * perLaneElement.
10281009 int64_t perLaneElement = element / broadcastFactor;
1029- if (perLaneElement != 1 ) {
1030- vtid = affine::AffineLinearizeIndexOp::create (
1031- builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1032- ArrayRef<int64_t >{vtidLen, broadcastFactor, perLaneElement},
1033- /* disjoint=*/ true );
1034- } else {
1035- vtid = affine::AffineLinearizeIndexOp::create (
1036- builder, loc, ValueRange{vtid, broadcastIndex},
1037- ArrayRef<int64_t >{vtidLen, broadcastFactor},
1038- /* disjoint=*/ true );
1039- }
1010+ vtid = affine::AffineLinearizeIndexOp::create (
1011+ builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1012+ ArrayRef<int64_t >{vtidLen, broadcastFactor, perLaneElement},
1013+ /* disjoint=*/ true );
10401014 } else if (element != 1 ) {
10411015 vtid = affine::AffineLinearizeIndexOp::create (
10421016 builder, loc, ValueRange{vtid, cZero},
0 commit comments