Skip to content

Commit 6786088

Browse files
authored
Verify allowReorder reshapes (#9905)
allowReorder reshapes still have a restriction that they cannot imply moving elements between threads and warps. Now that inferring the encoding is guaranteed to produce the given hint encoding if it is valid, we can check this with the existing verification code.
1 parent 5f96878 commit 6786088

6 files changed

Lines changed: 38 additions & 37 deletions

File tree

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,7 @@ struct ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
307307
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
308308
ConversionPatternRewriter &rewriter) const override {
309309
Location loc = op->getLoc();
310-
if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) {
311-
return emitOptionalError(loc,
312-
"expensive view not supported on reshape op");
313-
}
310+
assert(!isExpensiveView(op.getSrc().getType(), op.getType()));
314311
auto resultTy = cast<RankedTensorType>(op.getType());
315312
auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
316313
auto typeConverter = getTypeConverter();

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -913,18 +913,20 @@ LogicalResult ReshapeOp::verify() {
913913
"encodings, or (b) neither does.");
914914
}
915915

916-
if (!srcEnc || getAllowReorder()) {
916+
if (!srcEnc) {
917917
return success();
918918
}
919919

920-
// Check that we can infer the dst encoding from the src encoding
921-
// and that the inferred dst encoding is the same as the given dst encoding
922-
Attribute inferredDstEnc;
920+
// Check that we can infer the dst encoding from the src encoding and that the
921+
// inferred dst encoding is the same as the given dst encoding. We pass the
922+
// current dst encoding as a hint so that allowReorder reshapes are guaranteed
923+
// to produce the current encoding iff it is valid.
924+
Attribute inferredDstEnc = dstEnc;
923925
auto layoutInterface =
924926
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
925927
auto result = layoutInterface->inferReshapeOpEncoding(
926928
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
927-
/*allowReorder=*/false, getLoc());
929+
getAllowReorder(), getLoc());
928930
if (failed(result))
929931
return failure();
930932
return layoutInterface->verifyLayoutsAreEqual(

test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,26 +248,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
248248

249249
// -----
250250

251+
#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
251252
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
252253
#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
253-
#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
254-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
254+
#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
255+
#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
255256
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
256257
// COMMON-LABEL: unary_triton_ops_transitive_nonneg
257258
tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
258259
%c10_i32 = arith.constant 5 : i32
259260
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
260-
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
261-
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
262-
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
263-
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
264-
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
261+
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
262+
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
263+
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
264+
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
265+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
265266
%6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
266267
%7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
267-
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
268-
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
269-
%10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
270-
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
268+
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
269+
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
270+
%10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
271+
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
271272
%12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
272273
%13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
273274
%14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -293,7 +294,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
293294
// -----
294295

295296

296-
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
297+
#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
297298
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
298299
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
299300
// COMMON-LABEL: join_cat_transitive_nonneg

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -288,26 +288,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
288288

289289
// -----
290290

291+
#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
291292
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
292293
#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
293-
#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
294-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
294+
#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
295+
#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
295296
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
296297
// COMMON-LABEL: unary_triton_ops_transitive_nonneg
297298
tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
298299
%c10_i32 = arith.constant 5 : i32
299300
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
300-
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
301-
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
302-
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
303-
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
304-
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
301+
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
302+
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
303+
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
304+
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
305+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
305306
%6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
306307
%7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
307-
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
308-
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
309-
%10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
310-
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
308+
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
309+
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
310+
%10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
311+
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
311312
%12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
312313
%13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
313314
%14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -333,7 +334,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
333334
// -----
334335

335336

336-
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
337+
#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
337338
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
338339
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
339340
// COMMON-LABEL: join_cat_transitive_nonneg

test/TritonGPU/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK: tt.return %[[V]]
99
#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
1010
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
11-
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
11+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
1212

1313
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
1414
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
@@ -68,7 +68,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
6868
// CHECK: tt.return %[[V]]
6969
#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
7070
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
71-
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
71+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
7272

7373
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
7474
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {

test/TritonGPU/combine.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,8 +2199,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
21992199
// -----
22002200

22012201
#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}>
2202-
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2203-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2202+
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2203+
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
22042204

22052205
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
22062206
// CHECK-LABEL: @permuting_reshape_propagate

0 commit comments

Comments (0)