Skip to content

Commit c7877d4

Browse files
committed
more clean up
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
1 parent 7742c66 commit c7877d4

1 file changed

Lines changed: 14 additions & 40 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -966,43 +966,24 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
966966

967967
// Build broadcastIndex from unused delinearize results. Unused results
968968
// (not mapped by dimToVtid) represent broadcast lanes that see duplicate
969-
// data. When a splitDim exists, we linearize these unused components into
970-
// a single index in [0, broadcastFactor) to differentiate those lanes.
969+
// data. When a splitDim exists, we use the single unused component as an
970+
// index in [0, broadcastFactor) to differentiate those lanes.
971971
Value broadcastIndex;
972972
if (splitDim) {
973973
llvm::SmallDenseSet<size_t, 4> usedResults(dimToVtid.begin(),
974974
dimToVtid.end());
975-
976-
// Collect unused basis entries and compute overflow size.
977-
int64_t unusedBasisProduct = 1;
978-
SmallVector<Value> unusedLaneIdComponents;
979-
SmallVector<int64_t> unusedBasisSizes;
975+
// Find the single unused delinearize result representing broadcast lanes.
976+
std::optional<size_t> unusedResultIdx;
980977
for (size_t i = 1, e = vtidBasis.size(); i <= e; ++i) {
981978
if (!usedResults.contains(i)) {
982-
unusedBasisProduct *= vtidBasis[i - 1];
983-
unusedLaneIdComponents.push_back(splitLaneId.getResult(i));
984-
unusedBasisSizes.push_back(vtidBasis[i - 1]);
979+
assert(!unusedResultIdx && "expected exactly one unused basis entry");
980+
unusedResultIdx = i;
985981
}
986982
}
987-
int64_t overflowSize = broadcastFactor / unusedBasisProduct;
988-
989-
// Collect in most-significant-first order: overflow, then basis entries.
990-
if (overflowSize > 1) {
991-
unusedLaneIdComponents.insert(unusedLaneIdComponents.begin(),
992-
splitLaneId.getResult(0));
993-
unusedBasisSizes.insert(unusedBasisSizes.begin(), overflowSize);
994-
}
995-
996-
assert(llvm::product_of(unusedBasisSizes) == broadcastFactor &&
997-
"unused basis sizes product must equal broadcast factor");
998-
999-
if (unusedLaneIdComponents.size() == 1) {
1000-
broadcastIndex = unusedLaneIdComponents[0];
1001-
} else if (unusedLaneIdComponents.size() > 1) {
1002-
broadcastIndex = affine::AffineLinearizeIndexOp::create(
1003-
builder, loc, unusedLaneIdComponents, unusedBasisSizes,
1004-
/*disjoint=*/true);
1005-
}
983+
assert(unusedResultIdx && "expected one unused basis entry for broadcast");
984+
assert(vtidBasis[*unusedResultIdx - 1] == broadcastFactor &&
985+
"unused basis size must equal broadcast factor");
986+
broadcastIndex = splitLaneId.getResult(*unusedResultIdx);
1006987
}
1007988

1008989
// Each thread grabs `element` contiguous data, so the vtid needs to be
@@ -1026,17 +1007,10 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
10261007
if (splitDim && dimIdx == *splitDim) {
10271008
// offset = vtid * element + broadcastIndex * perLaneElement.
10281009
int64_t perLaneElement = element / broadcastFactor;
1029-
if (perLaneElement != 1) {
1030-
vtid = affine::AffineLinearizeIndexOp::create(
1031-
builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1032-
ArrayRef<int64_t>{vtidLen, broadcastFactor, perLaneElement},
1033-
/*disjoint=*/true);
1034-
} else {
1035-
vtid = affine::AffineLinearizeIndexOp::create(
1036-
builder, loc, ValueRange{vtid, broadcastIndex},
1037-
ArrayRef<int64_t>{vtidLen, broadcastFactor},
1038-
/*disjoint=*/true);
1039-
}
1010+
vtid = affine::AffineLinearizeIndexOp::create(
1011+
builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1012+
ArrayRef<int64_t>{vtidLen, broadcastFactor, perLaneElement},
1013+
/*disjoint=*/true);
10401014
} else if (element != 1) {
10411015
vtid = affine::AffineLinearizeIndexOp::create(
10421016
builder, loc, ValueRange{vtid, cZero},

0 commit comments

Comments
 (0)