Skip to content

Commit ba1ed62

Browse files
authored
[RELAND] Verify allowReorder reshapes (#9998)
Reland of #9905. allowReorder reshapes still have a restriction that they cannot imply moving elements between threads and warps. Now that inferring the encoding is guaranteed to produce the given hint encoding if it is valid, we can check this with the existing verification code.
1 parent 87c7072 commit ba1ed62

6 files changed

Lines changed: 38 additions & 37 deletions

File tree

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,7 @@ struct ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
299299
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
300300
ConversionPatternRewriter &rewriter) const override {
301301
Location loc = op->getLoc();
302-
if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) {
303-
return emitOptionalError(loc,
304-
"expensive view not supported on reshape op");
305-
}
302+
assert(!isExpensiveView(op.getSrc().getType(), op.getType()));
306303
auto resultTy = cast<RankedTensorType>(op.getType());
307304
auto typeConverter = getTypeConverter();
308305
auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -958,18 +958,20 @@ LogicalResult ReshapeOp::verify() {
958958
"encodings, or (b) neither does.");
959959
}
960960

961-
if (!srcEnc || getAllowReorder()) {
961+
if (!srcEnc) {
962962
return success();
963963
}
964964

965-
// Check that we can infer the dst encoding from the src encoding
966-
// and that the inferred dst encoding is the same as the given dst encoding
967-
Attribute inferredDstEnc;
965+
// Check that we can infer the dst encoding from the src encoding and that the
966+
// inferred dst encoding is the same as the given dst encoding. We pass the
967+
// current dst encoding as a hint so that allowReorder reshapes are guaranteed
968+
// to produce the current encoding iff it is valid.
969+
Attribute inferredDstEnc = dstEnc;
968970
auto layoutInterface =
969971
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
970972
auto result = layoutInterface->inferReshapeOpEncoding(
971973
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
972-
/*allowReorder=*/false, getLoc());
974+
getAllowReorder(), getLoc());
973975
if (failed(result))
974976
return failure();
975977
return layoutInterface->verifyLayoutsAreEqual(

test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,26 +248,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
248248

249249
// -----
250250

251+
#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
251252
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
252253
#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
253-
#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
254-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
254+
#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
255+
#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
255256
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
256257
// COMMON-LABEL: unary_triton_ops_transitive_nonneg
257258
tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
258259
%c10_i32 = arith.constant 5 : i32
259260
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
260-
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
261-
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
262-
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
263-
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
264-
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
261+
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
262+
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
263+
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
264+
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
265+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
265266
%6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
266267
%7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
267-
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
268-
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
269-
%10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
270-
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
268+
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
269+
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
270+
%10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
271+
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
271272
%12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
272273
%13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
273274
%14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -293,7 +294,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
293294
// -----
294295

295296

296-
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
297+
#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
297298
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
298299
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
299300
#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -288,26 +288,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
288288

289289
// -----
290290

291+
#blockedsrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
291292
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
292293
#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
293-
#blocked1 = #ttg.slice<{dim=0, parent=#blocked}>
294-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
294+
#blocked1 = #ttg.slice<{dim=0, parent=#blockedsrc}>
295+
#blocked2 = #ttg.slice<{dim=0, parent=#blockedtrans}>
295296
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
296297
// COMMON-LABEL: unary_triton_ops_transitive_nonneg
297298
tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
298299
%c10_i32 = arith.constant 5 : i32
299300
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
300-
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked>
301-
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked>
302-
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked>
303-
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans>
304-
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
301+
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blockedsrc>
302+
%2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<8x2xi32, #blocked>
303+
%3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blockedsrc> -> tensor<2x8xi32, #blockedtrans>
304+
%4 = tt.trans %3 {order = array<i32: 1, 0>} : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
305+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
305306
%6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked>
306307
%7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2>
307-
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1>
308-
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked>
309-
%10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked>
310-
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked>
308+
%8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked2>
309+
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked2> -> tensor<1x8xi32, #blockedtrans>
310+
%10 = tt.broadcast %9 : tensor<1x8xi32, #blockedtrans> -> tensor<2x8xi32, #blockedtrans>
311+
%11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blockedtrans> -> tensor<8x2xi32, #blocked>
311312
%12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked>
312313
%13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked>
313314
%14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked>
@@ -333,7 +334,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
333334
// -----
334335

335336

336-
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
337+
#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
337338
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
338339
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
339340
#blocked3 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>

test/TritonGPU/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// CHECK: tt.return %[[V]]
99
#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
1010
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
11-
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
11+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
1212

1313
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
1414
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
@@ -68,7 +68,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
6868
// CHECK: tt.return %[[V]]
6969
#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
7070
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
71-
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
71+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [0, 1]}>
7272

7373
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
7474
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {

test/TritonGPU/combine.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,8 +2199,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
21992199
// -----
22002200

22012201
#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}>
2202-
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2203-
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2202+
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
2203+
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
22042204

22052205
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
22062206
// CHECK-LABEL: @permuting_reshape_propagate

0 commit comments

Comments (0)