Skip to content

Commit cd7b094

Browse files
efric and claude committed
gap-aware element-wise split for lane loads
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b1d73d0 commit cd7b094

2 files changed

Lines changed: 69 additions & 10 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,9 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
908908
ArrayRef<int64_t> permutation, MMASingleSubgroupLayout subgroupLayout,
909909
SmallVectorImpl<OpFoldResult> &canonicalOffsets,
910910
SmallVectorImpl<OpFoldResult> &canonicalSizes,
911-
SmallVectorImpl<OpFoldResult> &canonicalStrides) {
911+
SmallVectorImpl<OpFoldResult> &canonicalStrides,
912+
int64_t broadcastFactor = 1) {
913+
assert(broadcastFactor >= 1 && "broadcast factor must be at least 1");
912914
SmallVector<int64_t> rankReducedShape;
913915

914916
for (auto [outer, thread, element] :
@@ -946,6 +948,44 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
946948
SmallVector<Value> hintedSplitLaneId = createTransposeLoadIndexHint(
947949
builder, loc, splitLaneId.getResults(), vtidBasis);
948950

951+
// Find the unique element dimension eligible for lane differentiation:
952+
// exactly one dimension with element[dim] > 1 and divisible by
953+
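// broadcastFactor. If no dimension or more than one dimension qualifies,
// splitDim stays unset and no element-wise split is applied.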
954+
std::optional<size_t> splitDim;
955+
if (broadcastFactor > 1) {
956+
for (auto [dimIdx, element] : llvm::enumerate(subgroupLayout.element)) {
957+
if (element > 1 && element % broadcastFactor == 0) {
958+
if (splitDim) {
959+
splitDim = std::nullopt;
960+
break;
961+
}
962+
splitDim = dimIdx;
963+
}
964+
}
965+
}
966+
967+
// Build broadcastIndex from unused delinearize results. Unused results
968+
// (not mapped by dimToVtid) represent broadcast lanes that see duplicate
969+
// data. When a splitDim exists, we use the single unused component as an
970+
// index in [0, broadcastFactor) to differentiate those lanes.
971+
Value broadcastIndex;
972+
if (splitDim) {
973+
llvm::SmallDenseSet<size_t, 4> usedResults(dimToVtid.begin(),
974+
dimToVtid.end());
975+
// Find the single unused delinearize result representing broadcast lanes.
976+
std::optional<size_t> unusedResultIdx;
977+
for (size_t i = 1, e = vtidBasis.size(); i <= e; ++i) {
978+
if (!usedResults.contains(i)) {
979+
assert(!unusedResultIdx && "expected exactly one unused basis entry");
980+
unusedResultIdx = i;
981+
}
982+
}
983+
assert(unusedResultIdx && "expected one unused basis entry for broadcast");
984+
assert(vtidBasis[*unusedResultIdx - 1] == broadcastFactor &&
985+
"unused basis size must equal broadcast factor");
986+
broadcastIndex = splitLaneId.getResult(*unusedResultIdx);
987+
}
988+
949989
// Each thread grabs `element` contiguous data, so the vtid needs to be
950990
// multiplied by `element` to get the next bunch of data.
951991
// vtid: virtual thread id
@@ -955,11 +995,23 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
955995
// Instead of computing those maps, we use one big `delinearize` expression
956996
// in order to prevent unwanted "simplifications" on affine maps that
957997
// worsen the generated code quality.
958-
for (auto [splitResultIdx, element] :
959-
llvm::zip_equal(dimToVtid, subgroupLayout.element)) {
998+
//
999+
// When broadcastFactor > 1, the splitDim offset also incorporates
1000+
// broadcastIndex so each broadcast lane gets a disjoint slice, not a duplicate.
1001+
for (auto [dimIdx, vtidAndElement] :
1002+
llvm::enumerate(llvm::zip_equal(dimToVtid, subgroupLayout.element))) {
1003+
auto [splitResultIdx, element] = vtidAndElement;
9601004
Value vtid = hintedSplitLaneId[splitResultIdx];
9611005
int64_t vtidLen = vtidBasis[splitResultIdx - 1];
962-
if (element != 1) {
1006+
1007+
if (splitDim && dimIdx == *splitDim) {
1008+
// offset = vtid * element + broadcastIndex * perLaneElement.
1009+
int64_t perLaneElement = element / broadcastFactor;
1010+
vtid = affine::AffineLinearizeIndexOp::create(
1011+
builder, loc, ValueRange{vtid, broadcastIndex, cZero},
1012+
ArrayRef<int64_t>{vtidLen, broadcastFactor, perLaneElement},
1013+
/*disjoint=*/true);
1014+
} else if (element != 1) {
9631015
vtid = affine::AffineLinearizeIndexOp::create(
9641016
builder, loc, ValueRange{vtid, cZero},
9651017
ArrayRef<int64_t>{vtidLen, element},
@@ -968,15 +1020,20 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
9681020
vtids.push_back(vtid);
9691021
}
9701022

971-
int64_t idx = 0;
972-
for (auto [element, outer] :
973-
llvm::zip_equal(subgroupLayout.element, subgroupLayout.outer)) {
1023+
int64_t vtidIdx = 0;
1024+
for (auto [dimIdx, elementAndOuter] : llvm::enumerate(
1025+
llvm::zip_equal(subgroupLayout.element, subgroupLayout.outer))) {
1026+
auto [element, outer] = elementAndOuter;
1027+
int64_t perLaneElement = element;
1028+
if (splitDim && dimIdx == *splitDim) {
1029+
perLaneElement = element / broadcastFactor;
1030+
}
9741031
if (outer != 1) {
9751032
canonicalSizes.push_back(builder.getIndexAttr(outer));
9761033
canonicalOffsets.push_back(zero);
9771034
}
978-
canonicalSizes.push_back(builder.getIndexAttr(element));
979-
canonicalOffsets.push_back(vtids[idx++]);
1035+
canonicalSizes.push_back(builder.getIndexAttr(perLaneElement));
1036+
canonicalOffsets.push_back(vtids[vtidIdx++]);
9801037
}
9811038
canonicalOffsets.assign(applyPermutation(canonicalOffsets, permutation));
9821039
canonicalSizes.assign(applyPermutation(canonicalSizes, permutation));

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ namespace mlir::iree_compiler::IREE::GPU {
7474
// almost always equal to subgroup size. If not, then it is a strict divisor
7575
// of subgroup size and that means that multiple threads get the exact same
7676
// data, i.e., there is an implied broadcasting, as will be seen in the
77-
// modulo (t % thread [0]) below.
77+
// modulo (t % thread [0]) below. Callers may opt to split an element
78+
// dimension by the broadcast factor to give each broadcast lane distinct data;
79+
// see broadcastFactor in populateCanonicalOffsetsSizesAndStrides.
7880
//
7981
// Detailed semantics: case of semantic rank 1
8082
// -------------------------------------------

0 commit comments

Comments
 (0)