Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,10 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
ArrayRef<int64_t> permutation, MMASingleSubgroupLayout subgroupLayout,
SmallVectorImpl<OpFoldResult> &canonicalOffsets,
SmallVectorImpl<OpFoldResult> &canonicalSizes,
SmallVectorImpl<OpFoldResult> &canonicalStrides) {
SmallVectorImpl<OpFoldResult> &canonicalStrides,
int64_t physicalLanesPerThread = 1) {
assert(physicalLanesPerThread >= 1 &&
"physicalLanesPerThread must be at least 1");
SmallVector<int64_t> rankReducedShape;

for (auto [outer, thread, element] :
Expand Down Expand Up @@ -946,6 +949,44 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
SmallVector<Value> hintedSplitLaneId = createTransposeLoadIndexHint(
builder, loc, splitLaneId.getResults(), vtidBasis);

// Find the unique element dimension eligible for lane differentiation:
// exactly one dimension with element[dim] > 1 and divisible by
// physicalLanesPerThread.
std::optional<size_t> splitDim;
if (physicalLanesPerThread > 1) {
for (auto [dimIdx, element] : llvm::enumerate(subgroupLayout.element)) {
if (element > 1 && element % physicalLanesPerThread == 0) {
if (splitDim) {
splitDim = std::nullopt;
break;
}
splitDim = dimIdx;
}
}
}

// Build laneGroupIndex from unused delinearize results. Unused results
// (not mapped by dimToVtid) represent grouped lanes that see duplicate
// data. When a splitDim exists, we use the single unused component as an
// index in [0, physicalLanesPerThread) to differentiate those lanes.
Value laneGroupIndex;
if (splitDim) {
llvm::SmallDenseSet<size_t, 4> usedResults(dimToVtid.begin(),
dimToVtid.end());
// Find the single unused delinearize result representing grouped lanes.
std::optional<size_t> unusedResultIdx;
for (size_t i = 1, e = vtidBasis.size(); i <= e; ++i) {
if (!usedResults.contains(i)) {
assert(!unusedResultIdx && "expected exactly one unused basis entry");
unusedResultIdx = i;
}
}
assert(unusedResultIdx && "expected one unused basis entry for lane group");
assert(vtidBasis[*unusedResultIdx - 1] == physicalLanesPerThread &&
"unused basis size must equal physicalLanesPerThread");
laneGroupIndex = splitLaneId.getResult(*unusedResultIdx);
}

// Each thread grabs `element` contiguous data, so the vtid needs to be
// multiplied by `element` to get the next bunch of data.
// vtid: virtual thread id
Expand All @@ -955,11 +996,23 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
// Instead of computing those maps, we use one big `delinearize` expression
// in order to prevent unwanted "simplifications" on affine maps that
// worsen the generated code quality.
for (auto [splitResultIdx, element] :
llvm::zip_equal(dimToVtid, subgroupLayout.element)) {
//
// When physicalLanesPerThread > 1, the splitDim offset also incorporates
// laneGroupIndex so each grouped lane reads a disjoint slice instead of
// duplicate data.
for (auto [dimIdx, vtidAndElement] :
llvm::enumerate(llvm::zip_equal(dimToVtid, subgroupLayout.element))) {
auto [splitResultIdx, element] = vtidAndElement;
Value vtid = hintedSplitLaneId[splitResultIdx];
int64_t vtidLen = vtidBasis[splitResultIdx - 1];
if (element != 1) {

if (splitDim && dimIdx == *splitDim) {
// offset = vtid * element + laneGroupIndex * perLaneElement.
int64_t perLaneElement = element / physicalLanesPerThread;
vtid = affine::AffineLinearizeIndexOp::create(
builder, loc, ValueRange{vtid, laneGroupIndex, cZero},
ArrayRef<int64_t>{vtidLen, physicalLanesPerThread, perLaneElement},
/*disjoint=*/true);
} else if (element != 1) {
vtid = affine::AffineLinearizeIndexOp::create(
builder, loc, ValueRange{vtid, cZero},
ArrayRef<int64_t>{vtidLen, element},
Expand All @@ -968,15 +1021,20 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
vtids.push_back(vtid);
}

int64_t idx = 0;
for (auto [element, outer] :
llvm::zip_equal(subgroupLayout.element, subgroupLayout.outer)) {
int64_t vtidIdx = 0;
for (auto [dimIdx, elementAndOuter] : llvm::enumerate(
llvm::zip_equal(subgroupLayout.element, subgroupLayout.outer))) {
auto [element, outer] = elementAndOuter;
int64_t perLaneElement = element;
if (splitDim && dimIdx == *splitDim) {
perLaneElement = element / physicalLanesPerThread;
}
if (outer != 1) {
canonicalSizes.push_back(builder.getIndexAttr(outer));
canonicalOffsets.push_back(zero);
}
canonicalSizes.push_back(builder.getIndexAttr(element));
canonicalOffsets.push_back(vtids[idx++]);
canonicalSizes.push_back(builder.getIndexAttr(perLaneElement));
canonicalOffsets.push_back(vtids[vtidIdx++]);
}
canonicalOffsets.assign(applyPermutation(canonicalOffsets, permutation));
canonicalSizes.assign(applyPermutation(canonicalSizes, permutation));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ namespace mlir::iree_compiler::IREE::GPU {
// 2. The product of all the outer[i] times all the element[i] equals the
// length of the vector operand to the intrinsic. It is the number of
// elements that one intrinsic consumes on one thread.
// 3. The product of all the thread[i] is a divisor of subgroup size. It is
// almost always equal to subgroup size. If not, then it is a strict divisor
// of subgroup size and that means that multiple threads get the exact same
// data, i.e., there is an implied broadcasting, as will be seen in the
// modulo (t % thread[0]) below.
// 3. The product of all the thread[i] is a divisor of subgroup size. Let
// physicalLanesPerThread = subgroupSize / product(thread[i]). It is
// almost always 1. If not, then it is greater than 1 and that means
// that multiple threads get the exact same data, i.e., there is an
// implied broadcasting, as will be seen in the modulo (t % thread[0])
// below. When greater than 1, that many physical lanes share the same
// position in the thread[i] decomposition. Which specific lanes are
// grouped is determined by tstrides. The element dimension may be split
// by this factor so each grouped lane loads a disjoint slice.
//
// Detailed semantics: case of semantic rank 1
// -------------------------------------------
Expand Down
Loading