Skip to content

Commit dc2dfbc

Browse files
committed
[GPU] Update MemDescSubsliceOp verification to handle CTA dimensions; modify test cases for clarity
1 parent 52c1bff commit dc2dfbc

2 files changed

Lines changed: 15 additions & 9 deletions

File tree

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1017,19 +1017,25 @@ LogicalResult MemDescSubsliceOp::verify() {
10171017
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
10181018
namedOffsets.push_back({d, 0});
10191019
}
1020-
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
1021-
dimSize *= 2) {
1022-
namedOffsets[dim] = {kDim, dimSize};
1020+
// Taking a subslice of size `dstTy.getDimSize(dim)` is valid as long as all
1021+
// offsets in [0, dstTy.getDimSize(dim)) along `dim` stay within the same CTA.
1022+
for (int splitOffset = 0; splitOffset < dstTy.getDimSize(dim);
1023+
++splitOffset) {
1024+
namedOffsets[dim] = {kDim, splitOffset};
10231025
for (auto [inDim, val] : llInv.apply(namedOffsets)) {
1024-
if (inDim == kOffset && !llvm::isPowerOf2_32(val)) {
1025-
return emitError(
1026-
"We don't support splitting along the swizzling pattern");
1027-
}
10281026
if (inDim == kBlock && val != 0) {
10291027
return emitError("We don't support splitting along CTA dimensions");
10301028
}
10311029
}
10321030
}
1031+
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
1032+
dimSize *= 2) {
1033+
namedOffsets[dim] = {kDim, dimSize};
1034+
if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
1035+
return emitError(
1036+
"We don't support splitting along the swizzling pattern");
1037+
}
1038+
}
10331039
}
10341040
return success();
10351041
}

test/TritonGPU/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
#smem = #ttg.shared_memory
55
module attributes {"ttg.num-ctas" = 2 : i32} {
66
tt.func public @subslice_non_broadcast_cga_dim(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) {
7-
// expected-error @+1 {{non-broadcast CGA dimensions}}
7+
// expected-error @+1 {{CTA dimensions}}
88
%a = ttg.memdesc_subslice %arg0 [0, 0] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x8xf32, #shared, #smem>
99
tt.return
1010
}
@@ -98,7 +98,7 @@ tt.func public @result_1d_to_1d(%arg0: !ttg.memdesc<8xf32, #shared, #smem>) {
9898

9999
// -----
100100

101-
#shared = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [1, 0], [2, 0]], block = [[4, 0]]}, alignment = 16>
101+
#shared = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [1, 0], [4, 0]], block = [[2, 0]]}, alignment = 16>
102102
#smem = #ttg.shared_memory
103103
module attributes {"ttg.num-ctas" = 2 : i32} {
104104
tt.func public @subview_split_on_cta_dim(%arg0: !ttg.memdesc<8x4xf32, #shared, #smem>) {

0 commit comments

Comments
 (0)