Skip to content

Commit 5f96878

Browse files
authored
Infer src/dst of allowReorder reshapes (#9926)
Always infer the src/dst of reshapes, even if allowReorder is set. The result is valid for allowReorder reshapes, even if there isn't a single canonical encoding. When the existing encoding is one of the possible results, we prefer it to minimize changes. This allows inference to always succeed on reshapes, and any heuristics on whether to use the inferred value can be maintained by the caller. One example I identified while looking at this was that allowReorder reshapes will currently fail backward remat in RemoveLayoutConversions if the reshape cannot be rematerialised with the same source encoding. This PR instead changes RemoveLayoutConversions to check specifically for whether the reshape has been marked as efficient, and otherwise to just do the remat. (This is a potentially perf-sensitive change.)
1 parent 42e54cb commit 5f96878

9 files changed

Lines changed: 81 additions & 29 deletions

File tree

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ class DialectInferLayoutInterface
5959
// makes the reshape a "nop", i.e. the same GPU threads contain the same
6060
// elements as before the reshape using legacy layouts. This is not always
6161
// possible (in which case we fallback to using LinearLayouts)
62+
// If allowReorder is set, an existing value in dstEnc is preferred when it
63+
// still yields a non-expensive view.
6264
// In the future we'll always use LinearLayouts
6365
virtual LogicalResult
6466
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
6567
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
68+
bool allowReorder,
6669
std::optional<Location> loc) const = 0;
6770

6871
// Check if two layouts are structurally the same, even if their names are

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
263263
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
264264

265265
// Return true if a view between the two types cannot be implemented as a no-op.
266-
bool isExpensiveView(Type srcType, Type dstType);
266+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
267+
ArrayRef<int64_t> dstShape, Attribute dstEncoding);
268+
inline bool isExpensiveView(Type srcType, Type dstType) {
269+
auto tensorSrcType = cast<RankedTensorType>(srcType);
270+
auto tensorDstType = cast<RankedTensorType>(dstType);
271+
return isExpensiveView(tensorSrcType.getShape(), tensorSrcType.getEncoding(),
272+
tensorDstType.getShape(), tensorDstType.getEncoding());
273+
}
267274

268275
// Return a blocked encoding where the shape is distributed contiguously amongst
269276
// the threads, warps, CTAs with 1 element per threads.

lib/Dialect/Gluon/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {
7474

7575
LogicalResult
7676
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
77-
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
77+
ArrayRef<int64_t> dstShape, Attribute &dstEnc, bool,
7878
std::optional<Location> loc) const override {
7979
return inferAutoEncoding(srcEnc, dstEnc);
8080
}

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,9 +850,10 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
850850
auto srcEnc = srcTy.getEncoding();
851851
Attribute dstEnc;
852852
if (srcEnc) {
853-
auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
854-
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
855-
dstEnc, state.location);
853+
auto result =
854+
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
855+
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, dstEnc,
856+
allowReorder, state.location);
856857
assert(succeeded(result));
857858
}
858859
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
@@ -922,7 +923,8 @@ LogicalResult ReshapeOp::verify() {
922923
auto layoutInterface =
923924
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
924925
auto result = layoutInterface->inferReshapeOpEncoding(
925-
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
926+
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
927+
/*allowReorder=*/false, getLoc());
926928
if (failed(result))
927929
return failure();
928930
return layoutInterface->verifyLayoutsAreEqual(

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ SmallVector<unsigned> getContigPerThread(RankedTensorType type) {
114114
return toLinearEncoding(type).getContigPerThread();
115115
}
116116

117-
bool isExpensiveView(Type srcType, Type dstType) {
118-
auto tensorSrcType = cast<RankedTensorType>(srcType);
119-
auto tensorDstType = cast<RankedTensorType>(dstType);
120-
auto llSrc = toLinearLayout(tensorSrcType);
121-
auto llDst = toLinearLayout(tensorDstType);
117+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
118+
ArrayRef<int64_t> dstShape, Attribute dstEncoding) {
119+
auto llSrc = toLinearLayout(srcShape, srcEncoding);
120+
auto llDst = toLinearLayout(dstShape, dstEncoding);
122121
// In case there are replicated value we need to make sure the new and old
123122
// layout have matching masks.
124123
for (auto [srcMask, dstMask] :
@@ -127,7 +126,8 @@ bool isExpensiveView(Type srcType, Type dstType) {
127126
if (srcMask.second != dstMask.second)
128127
return true;
129128
}
130-
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
129+
return getTotalElemsPerThread(srcEncoding, srcShape) !=
130+
getTotalElemsPerThread(dstEncoding, dstShape);
131131
}
132132

133133
/* Utility function used by get.*Order methods of SliceEncodingAttr.
@@ -3285,11 +3285,17 @@ struct TritonGPUInferLayoutInterface
32853285
LogicalResult
32863286
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
32873287
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
3288+
bool allowReorder,
32883289
std::optional<Location> loc) const override {
32893290
if (product(srcShape) != product(dstShape)) {
32903291
return emitOptionalError(loc, "numel of dst shape does not match "
32913292
"numel of src shape");
32923293
}
3294+
// If allowReorder is true, there are multiple valid encodings. Prefer the
3295+
// hint if it is set and valid.
3296+
if (allowReorder && dstEnc)
3297+
if (!isExpensiveView(srcShape, srcEnc, dstShape, dstEnc))
3298+
return success();
32933299
auto result =
32943300
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
32953301
if (succeeded(result)) {

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,8 @@ bool canBeRemat(Operation *op) {
657657
return false;
658658
if (auto gather = dyn_cast<GatherOp>(op))
659659
return !gather.getEfficientLayout();
660+
if (auto reshape = dyn_cast<ReshapeOp>(op))
661+
return !reshape.getEfficientLayout();
660662

661663
if (isa<scf::WhileOp, scf::ConditionOp>(op))
662664
return false;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -465,26 +465,22 @@ static Attribute inferSrcEncoding(triton::TransposeOpInterface op,
465465
static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
466466
Attribute srcEnc,
467467
ArrayRef<int64_t> dstShape,
468-
bool allowReorder) {
469-
// We don't do anything smart to allow-reorder reshapes here. They are
470-
// handled in OptimizeThreadLocality.
471-
if (allowReorder)
472-
return {};
473-
474-
Attribute dstEnc;
468+
Attribute dstEncHint = {},
469+
bool allowReorder = false) {
470+
Attribute dstEnc = dstEncHint;
475471
auto result =
476472
srcEnc.getDialect()
477473
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
478474
->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
479-
/*loc=*/std::nullopt);
475+
allowReorder, /*loc=*/std::nullopt);
480476
assert(succeeded(result));
481477
return dstEnc;
482478
}
483479

484480
static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {
485-
return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding,
486-
op.getType().getShape(),
487-
op.getAllowReorder());
481+
return inferReshapeOpDstEncoding(
482+
op.getSrc().getType().getShape(), encoding, op.getType().getShape(),
483+
op.getType().getEncoding(), op.getAllowReorder());
488484
}
489485

490486
static Attribute inferDstEncoding(GatherOp op, Attribute encoding) {
@@ -499,9 +495,9 @@ static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) {
499495
// as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an
500496
// invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this
501497
// way.
502-
return inferReshapeOpDstEncoding(op.getType().getShape(), encoding,
503-
op.getSrc().getType().getShape(),
504-
op.getAllowReorder());
498+
return inferReshapeOpDstEncoding(
499+
op.getType().getShape(), encoding, op.getSrc().getType().getShape(),
500+
op.getSrc().getType().getEncoding(), op.getAllowReorder());
505501
}
506502

507503
static bool isSingleValue(Value value) {

test/TritonGPU/combine.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
22172217

22182218
// -----
22192219

2220+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
2221+
#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
2222+
#blocked4 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
2223+
#blocked1 = #ttg.slice<{dim = 0, parent = #blocked}>
2224+
2225+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2226+
// CHECK-LABEL: @permuting_reshape_backward_remat
2227+
// CHECK-NOT: ttg.convert_layout
2228+
// CHECK: tt.return
2229+
tt.func public @permuting_reshape_backward_remat(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> {
2230+
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
2231+
%1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<16x!tt.ptr<i32>, #blocked1>
2232+
%2 = tt.addptr %1, %0 : tensor<16x!tt.ptr<i32>, #blocked1>, tensor<16xi32, #blocked1>
2233+
%3 = tt.load %2 : tensor<16x!tt.ptr<i32>, #blocked1>
2234+
%4 = tt.reshape %3 allow_reorder : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4>
2235+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3>
2236+
tt.return %5 : tensor<8x2xi32, #blocked3>
2237+
}
2238+
2239+
// CHECK-LABEL: @permuting_reshape_no_backward_remat_efficient_layout
2240+
// CHECK: ttg.convert_layout
2241+
// CHECK: tt.return
2242+
tt.func public @permuting_reshape_no_backward_remat_efficient_layout(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<8x2xi32, #blocked3> {
2243+
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1>
2244+
%1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<16x!tt.ptr<i32>, #blocked1>
2245+
%2 = tt.addptr %1, %0 : tensor<16x!tt.ptr<i32>, #blocked1>, tensor<16xi32, #blocked1>
2246+
%3 = tt.load %2 : tensor<16x!tt.ptr<i32>, #blocked1>
2247+
%4 = tt.reshape %3 allow_reorder efficient_layout : tensor<16xi32, #blocked1> -> tensor<8x2xi32, #blocked4>
2248+
%5 = ttg.convert_layout %4 : tensor<8x2xi32, #blocked4> -> tensor<8x2xi32, #blocked3>
2249+
tt.return %5 : tensor<8x2xi32, #blocked3>
2250+
}
2251+
}
2252+
2253+
// -----
2254+
22202255
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
22212256
#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}>
22222257
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy,
139139
ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); });
140140
result = inferLayout->inferReshapeOpEncoding(
141141
srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc,
142-
UnknownLoc::get(ctx));
142+
/*allowReorder=*/false, UnknownLoc::get(ctx));
143143
}
144144

145145
// We expect the reshape to succeed as long as the inputs have the same
@@ -164,7 +164,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy,
164164
Attribute inferredSrcEnc;
165165
auto result = inferLayout->inferReshapeOpEncoding(
166166
dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc,
167-
UnknownLoc::get(ctx));
167+
/*allowReorder=*/false, UnknownLoc::get(ctx));
168168
EXPECT_TRUE(succeeded(result))
169169
<< "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x")
170170
<< " " << stringifyLLVMType(inferredEnc) << " -> "
@@ -439,7 +439,8 @@ TEST_F(JoinOpTest, JoinOpLayoutPropagation) {
439439
}
440440
Attribute reshapedEnc;
441441
result = inferLayout->inferReshapeOpEncoding(
442-
transShape, transEnc, newShape, reshapedEnc, std::nullopt);
442+
transShape, transEnc, newShape, reshapedEnc,
443+
/*allowReorder=*/false, std::nullopt);
443444
assert(succeeded(result));
444445
// The layouts should be structurally the same
445446
// but reshapeEnc will likely be a LinearEncodingAttr

0 commit comments

Comments
 (0)