Skip to content

Commit 441faac

Browse files
authored
[RELAND] Infer src/dst of allowReorder reshape (#9997)
Reland of #9926. Always infer the src/dst encodings of reshapes, even if allowReorder is set. The result is valid for allowReorder reshapes, even if there isn't a single canonical encoding. When the existing encoding is one of the possible results, we prefer it, to minimize changes. This allows inference to always succeed on reshapes; any heuristics on whether to use the inferred value can be maintained by the caller. One example I identified while looking at this: allowReorder reshapes will currently fail backward rematerialization in RemoveLayoutConversions if the reshape cannot be rematerialized with the same source encoding. This PR instead changes RemoveLayoutConversions to check specifically whether the reshape has been marked as efficient, and otherwise just does the remat. (This is a potentially perf-sensitive change.)
1 parent ab1f012 commit 441faac

File tree

9 files changed

+81
-29
lines changed

9 files changed

+81
-29
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ class DialectInferLayoutInterface
5959
// makes the reshape a "nop", i.e. the same GPU threads contain the same
6060
// elements as before the reshape using legacy layouts. This is not always
6161
// possible (in which case we fallback to using LinearLayouts)
62+
// If allowReorder is set, an existing value in dstEnc is preferred when it
63+
// still yields a non-expensive view.
6264
// In the future we'll always use LinearLayouts
6365
virtual LogicalResult
6466
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
6567
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
68+
bool allowReorder,
6669
std::optional<Location> loc) const = 0;
6770

6871
// Check if two layouts are structurally the same, even if their names are

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
264264
bool isLegalCatEncoding(CatOp cat, Attribute targetEncoding);
265265

266266
// Return true if a view between the two types cannot be implemented as a no-op.
267-
bool isExpensiveView(Type srcType, Type dstType);
267+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
268+
ArrayRef<int64_t> dstShape, Attribute dstEncoding);
269+
inline bool isExpensiveView(Type srcType, Type dstType) {
270+
auto tensorSrcType = cast<RankedTensorType>(srcType);
271+
auto tensorDstType = cast<RankedTensorType>(dstType);
272+
return isExpensiveView(tensorSrcType.getShape(), tensorSrcType.getEncoding(),
273+
tensorDstType.getShape(), tensorDstType.getEncoding());
274+
}
268275

269276
// Return a blocked encoding where the shape is distributed contiguously amongst
270277
// the threads, warps, CTAs with 1 element per threads.

lib/Dialect/Gluon/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {
7878

7979
LogicalResult
8080
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
81-
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
81+
ArrayRef<int64_t> dstShape, Attribute &dstEnc, bool,
8282
std::optional<Location> loc) const override {
8383
return inferAutoEncoding(srcEnc, dstEnc);
8484
}

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -895,9 +895,10 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
895895
auto srcEnc = srcTy.getEncoding();
896896
Attribute dstEnc;
897897
if (srcEnc) {
898-
auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
899-
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
900-
dstEnc, state.location);
898+
auto result =
899+
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
900+
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, dstEnc,
901+
allowReorder, state.location);
901902
assert(succeeded(result));
902903
}
903904
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
@@ -967,7 +968,8 @@ LogicalResult ReshapeOp::verify() {
967968
auto layoutInterface =
968969
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
969970
auto result = layoutInterface->inferReshapeOpEncoding(
970-
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
971+
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
972+
/*allowReorder=*/false, getLoc());
971973
if (failed(result))
972974
return failure();
973975
return layoutInterface->verifyLayoutsAreEqual(

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ SmallVector<unsigned> getContigPerThread(RankedTensorType type) {
114114
return toLinearEncoding(type).getContigPerThread();
115115
}
116116

117-
bool isExpensiveView(Type srcType, Type dstType) {
118-
auto tensorSrcType = cast<RankedTensorType>(srcType);
119-
auto tensorDstType = cast<RankedTensorType>(dstType);
120-
auto llSrc = toLinearLayout(tensorSrcType);
121-
auto llDst = toLinearLayout(tensorDstType);
117+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
118+
ArrayRef<int64_t> dstShape, Attribute dstEncoding) {
119+
auto llSrc = toLinearLayout(srcShape, srcEncoding);
120+
auto llDst = toLinearLayout(dstShape, dstEncoding);
122121
// In case there are replicated value we need to make sure the new and old
123122
// layout have matching masks.
124123
for (auto [srcMask, dstMask] :
@@ -127,7 +126,8 @@ bool isExpensiveView(Type srcType, Type dstType) {
127126
if (srcMask.second != dstMask.second)
128127
return true;
129128
}
130-
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
129+
return getTotalElemsPerThread(srcEncoding, srcShape) !=
130+
getTotalElemsPerThread(dstEncoding, dstShape);
131131
}
132132

133133
/* Utility function used by get.*Order methods of SliceEncodingAttr.
@@ -3309,11 +3309,17 @@ struct TritonGPUInferLayoutInterface
33093309
LogicalResult
33103310
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
33113311
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
3312+
bool allowReorder,
33123313
std::optional<Location> loc) const override {
33133314
if (product(srcShape) != product(dstShape)) {
33143315
return emitOptionalError(loc, "numel of dst shape does not match "
33153316
"numel of src shape");
33163317
}
3318+
// If allowReorder is true, there are multiple valid encodings. Prefer the
3319+
// hint if it is set and valid.
3320+
if (allowReorder && dstEnc)
3321+
if (!isExpensiveView(srcShape, srcEnc, dstShape, dstEnc))
3322+
return success();
33173323
auto result =
33183324
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
33193325
if (succeeded(result)) {

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,8 @@ bool canBeRemat(Operation *op) {
657657
return false;
658658
if (auto gather = dyn_cast<GatherOp>(op))
659659
return !gather.getEfficientLayout();
660+
if (auto reshape = dyn_cast<ReshapeOp>(op))
661+
return !reshape.getEfficientLayout();
660662

661663
if (isa<scf::WhileOp, scf::ConditionOp>(op))
662664
return false;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -465,26 +465,22 @@ static Attribute inferSrcEncoding(triton::TransposeOpInterface op,
465465
static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
466466
Attribute srcEnc,
467467
ArrayRef<int64_t> dstShape,
468-
bool allowReorder) {
469-
// We don't do anything smart to allow-reorder reshapes here. They are
470-
// handled in OptimizeThreadLocality.
471-
if (allowReorder)
472-
return {};
473-
474-
Attribute dstEnc;
468+
Attribute dstEncHint = {},
469+
bool allowReorder = false) {
470+
Attribute dstEnc = dstEncHint;
475471
auto result =
476472
srcEnc.getDialect()
477473
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
478474
->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
479-
/*loc=*/std::nullopt);
475+
allowReorder, /*loc=*/std::nullopt);
480476
assert(succeeded(result));
481477
return dstEnc;
482478
}
483479

484480
static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {
485-
return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding,
486-
op.getType().getShape(),
487-
op.getAllowReorder());
481+
return inferReshapeOpDstEncoding(
482+
op.getSrc().getType().getShape(), encoding, op.getType().getShape(),
483+
op.getType().getEncoding(), op.getAllowReorder());
488484
}
489485

490486
static Attribute inferDstEncoding(GatherOp op, Attribute encoding) {
@@ -499,9 +495,9 @@ static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) {
499495
// as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an
500496
// invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this
501497
// way.
502-
return inferReshapeOpDstEncoding(op.getType().getShape(), encoding,
503-
op.getSrc().getType().getShape(),
504-
op.getAllowReorder());
498+
return inferReshapeOpDstEncoding(
499+
op.getType().getShape(), encoding, op.getSrc().getType().getShape(),
500+
op.getSrc().getType().getEncoding(), op.getAllowReorder());
505501
}
506502

507503
static bool isSingleValue(Value value) {

test/TritonGPU/combine.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
22172217

22182218
// -----
22192219

2220+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
2221+
#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
2222+
#blocked4 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
2223+
#blocked1 = #ttg.slice<{dim = 0, parent = #blocked}>
2224+
2225+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2226+
// CHECK-LABEL: @permuting_reshape_backward_remat
2227+
// CHECK-NOT: ttg.convert_layout
2228+
// CHECK: tt.return
2229+
tt.func public @permuting_reshape_backward_remat(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> {
2230+
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
2231+
%1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<16x!tt.ptr<i32>, #blocked1>
2232+
%2 = tt.addptr %1, %0 : tensor<16x!tt.ptr<i32>, #blocked1>, tensor<16xi32, #blocked1>
2233+
%3 = tt.load %2 : tensor<16x!tt.ptr<i32>, #blocked1>
2234+
%4 = tt.reshape %3 allow_reorder : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4>
2235+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3>
2236+
tt.return %5 : tensor<8x2xi32, #blocked3>
2237+
}
2238+
2239+
// CHECK-LABEL: @permuting_reshape_no_backward_remat_efficient_layout
2240+
// CHECK: ttg.convert_layout
2241+
// CHECK: tt.return
2242+
tt.func public @permuting_reshape_no_backward_remat_efficient_layout(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> {
2243+
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
2244+
%1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<16x!tt.ptr<i32>, #blocked1>
2245+
%2 = tt.addptr %1, %0 : tensor<16x!tt.ptr<i32>, #blocked1>, tensor<16xi32, #blocked1>
2246+
%3 = tt.load %2 : tensor<16x!tt.ptr<i32>, #blocked1>
2247+
%4 = tt.reshape %3 allow_reorder efficient_layout : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4>
2248+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3>
2249+
tt.return %5 : tensor<8x2xi32, #blocked3>
2250+
}
2251+
}
2252+
2253+
// -----
2254+
22202255
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
22212256
#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}>
22222257
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy,
139139
ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); });
140140
result = inferLayout->inferReshapeOpEncoding(
141141
srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc,
142-
UnknownLoc::get(ctx));
142+
/*allowReorder=*/false, UnknownLoc::get(ctx));
143143
}
144144

145145
// We expect the reshape to succeed as long as the inputs have the same
@@ -164,7 +164,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy,
164164
Attribute inferredSrcEnc;
165165
auto result = inferLayout->inferReshapeOpEncoding(
166166
dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc,
167-
UnknownLoc::get(ctx));
167+
/*allowReorder=*/false, UnknownLoc::get(ctx));
168168
EXPECT_TRUE(succeeded(result))
169169
<< "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x")
170170
<< " " << stringifyLLVMType(inferredEnc) << " -> "
@@ -439,7 +439,8 @@ TEST_F(JoinOpTest, JoinOpLayoutPropagation) {
439439
}
440440
Attribute reshapedEnc;
441441
result = inferLayout->inferReshapeOpEncoding(
442-
transShape, transEnc, newShape, reshapedEnc, std::nullopt);
442+
transShape, transEnc, newShape, reshapedEnc,
443+
/*allowReorder=*/false, std::nullopt);
443444
assert(succeeded(result));
444445
// The layouts should be structurally the same
445446
// but reshapeEnc will likely be a LinearEncodingAttr

0 commit comments

Comments
 (0)