Skip to content

Commit f0b7641

Browse files
authored
[BACKEND] Consider TMA variants in Triton passes (#10014)
Replacing a TMA load with a gather, or a TMA store with a scatter, should not change the behaviour of our passes in most cases.
1 parent 035bcb5 commit f0b7641

26 files changed

Lines changed: 156 additions & 139 deletions

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,14 @@ def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterf
114114
];
115115
}
116116

117+
def TT_DescriptorLoadLikeOpInterface : OpInterface<"DescriptorLoadLikeOpInterface", [TT_DescriptorOpInterface]> {
118+
let description = [{
119+
Common marker interface for operations that load from tensor descriptors.
120+
}];
121+
122+
let cppNamespace = "::mlir::triton";
123+
}
124+
117125
def PredicatedOpInterface : OpInterface<"PredicatedOpInterface"> {
118126
let description = [{
119127
Common interface for operations that carry a predicate or mask operand that

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1222,7 +1222,7 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
12221222
}
12231223

12241224

1225-
def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> {
1225+
def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorLoadLikeOpInterface]> {
12261226
let summary = "Load from descriptor";
12271227
let description = [{
12281228
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
@@ -1291,7 +1291,7 @@ def TT_DescriptorReduceOp : TT_Op<"descriptor_reduce", [TT_DescriptorStoreLikeOp
12911291
let hasVerifier = 1;
12921292
}
12931293

1294-
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [TT_DescriptorOpInterface]> {
1294+
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [TT_DescriptorLoadLikeOpInterface]> {
12951295
let summary = "gather multiple rows from a descriptor into a single tensor";
12961296
let description = [{
12971297
The `tt.descriptor_gather` op will be lowered to NVIDIA TMA

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_
33

44
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
56
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
67
#include <optional>
78
#include <utility>
@@ -123,7 +124,9 @@ Value createAlloc(Operation *insertBefore, RankedTensorType ty, Location loc,
123124
gpu::SharedEncodingTrait sharedEnc, unsigned distance);
124125

125126
// Determine if the operation is a TMA load.
126-
bool isTMALoad(Operation *op);
127+
inline bool isTMALoad(Operation *op) {
128+
return isa<DescriptorLoadLikeOpInterface>(op);
129+
}
127130

128131
// Determine if the operation can be lowered to an async load.
129132
bool canBeAsyncLoad(Operation *op);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,69 @@
33

44
include "mlir/IR/OpBase.td"
55

6+
def TMAOpInterface : OpInterface<"TMAOpInterface"> {
7+
let description = [{
8+
Common interface for asynchronous TMA operations.
9+
}];
10+
11+
let cppNamespace = "::mlir::triton::nvidia_gpu";
12+
13+
let methods = [
14+
InterfaceMethod<
15+
/*desc=*/"Get the tensor descriptor",
16+
/*retType=*/"::mlir::Value",
17+
/*methodName=*/"getDesc",
18+
/*args=*/(ins)>,
19+
];
20+
}
21+
22+
def TMALoadLikeOpInterface : OpInterface<"TMALoadLikeOpInterface", [TMAOpInterface]> {
23+
let description = [{
24+
Common interface for asynchronous TMA operations that write shared memory.
25+
}];
26+
27+
let cppNamespace = "::mlir::triton::nvidia_gpu";
28+
29+
let methods = [
30+
InterfaceMethod<
31+
/*desc=*/"Get the destination memory descriptor",
32+
/*retType=*/"::mlir::Value",
33+
/*methodName=*/"getResult",
34+
/*args=*/(ins)>,
35+
InterfaceMethod<
36+
/*desc=*/"Get the completion barrier",
37+
/*retType=*/"::mlir::Value",
38+
/*methodName=*/"getBarrier",
39+
/*args=*/(ins)>,
40+
InterfaceMethod<
41+
/*desc=*/"Get the predicate",
42+
/*retType=*/"::mlir::Value",
43+
/*methodName=*/"getPred",
44+
/*args=*/(ins)>,
45+
];
46+
}
47+
48+
def TMAStoreLikeOpInterface : OpInterface<"TMAStoreLikeOpInterface", [TMAOpInterface]> {
49+
let description = [{
50+
Common interface for asynchronous TMA operations that read shared memory.
51+
}];
52+
53+
let cppNamespace = "::mlir::triton::nvidia_gpu";
54+
55+
let methods = [
56+
InterfaceMethod<
57+
/*desc=*/"Get the source memory descriptor",
58+
/*retType=*/"::mlir::Value",
59+
/*methodName=*/"getSrc",
60+
/*args=*/(ins)>,
61+
InterfaceMethod<
62+
/*desc=*/"Get mutable source memory descriptor",
63+
/*retType=*/"::mlir::OpOperand&",
64+
/*methodName=*/"getSrcMutable",
65+
/*args=*/(ins)>,
66+
];
67+
}
68+
669
def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
770
let description = [{
871
This interface is implemented by MMAv5 dot and dot scaled ops.

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -411,7 +411,7 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive", [
411411

412412
def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [
413413
AttrSizedOperandSegments, DeclareOpInterfaceMethods<MBarrierOpInterface>,
414-
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
414+
DeclareOpInterfaceMethods<PredicatedOpInterface>, TMALoadLikeOpInterface]> {
415415
let summary = "copy data based on descriptor from global memory to local memory asynchronously";
416416

417417
let description = [{
@@ -474,7 +474,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
474474

475475
}
476476

477-
def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global"> {
477+
def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global", [
478+
TMAStoreLikeOpInterface]> {
478479
let summary = "copy data based on descriptor from local memory to global memory asynchronously";
479480

480481
let description = [{
@@ -498,7 +499,9 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global">
498499
let hasVerifier = 1;
499500
}
500501

501-
def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
502+
def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [
503+
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
504+
TMAStoreLikeOpInterface]> {
502505
let summary = "reduce result in gmem based on a TMA descriptor";
503506

504507
let description = [{
@@ -524,7 +527,7 @@ def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead<
524527

525528
def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [
526529
DeclareOpInterfaceMethods<MBarrierOpInterface>,
527-
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
530+
DeclareOpInterfaceMethods<PredicatedOpInterface>, TMALoadLikeOpInterface]> {
528531
let summary = "gather data based on descriptor from global memory to local memory asynchronously";
529532

530533
let description = [{
@@ -550,7 +553,8 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [
550553
let hasVerifier = 1;
551554
}
552555

553-
def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> {
556+
def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [
557+
TMAStoreLikeOpInterface]> {
554558
let summary = "scatter data from local memory into global memory based on a descriptor asynchronously";
555559

556560
let description = [{

lib/Analysis/BufferRegion.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,7 @@ void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) {
312312
bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) {
313313
if (isa<ttg::LocalLoadOp, ttg::LocalStoreOp, ttng::TMEMLoadOp,
314314
ttng::TMEMStoreOp, ttng::TMEMCopyOp, ttg::AsyncCopyGlobalToLocalOp,
315-
ttng::AsyncTMACopyLocalToGlobalOp, ttng::AsyncTMAScatterOp>(op)) {
315+
ttng::TMAOpInterface>(op)) {
316316
return true;
317317
}
318318
if (isa<ttg::MBarrierOpInterface>(op)) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -264,8 +264,8 @@ getWarpsPerTile(DotOpInterface dotOp, const ArrayRef<int64_t> shape,
264264
static bool bwdFilter(Operation *op) {
265265
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
266266
isView(op) ||
267-
isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, BroadcastOp, ConvertLayoutOp>(
268-
op);
267+
isa<Fp4ToFpOp, LoadOp, DescriptorLoadLikeOpInterface, BroadcastOp,
268+
ConvertLayoutOp>(op);
269269
}
270270

271271
// Finds the bitwidth with which the value x is loaded
@@ -284,7 +284,7 @@ static int computeOrigBitWidth(Value x) {
284284

285285
int origBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
286286
for (auto op : slice) {
287-
if (isa<LoadOp, DescriptorLoadOp>(op)) {
287+
if (isa<LoadOp, DescriptorLoadLikeOpInterface>(op)) {
288288
if (auto tensorTy =
289289
dyn_cast<RankedTensorType>(op->getResultTypes().front())) {
290290
origBitWidth =
@@ -473,8 +473,9 @@ static bool canUseTwoCTAs(triton::DotOp dotOp) {
473473
// Skip convert layouts.
474474
while (auto cvtOp = b.getDefiningOp<ConvertLayoutOp>())
475475
b = cvtOp.getSrc();
476-
return llvm::isa_and_nonnull<triton::LoadOp, triton::DescriptorLoadOp,
477-
triton::DescriptorGatherOp>(b.getDefiningOp());
476+
return llvm::isa_and_nonnull<triton::LoadOp,
477+
triton::DescriptorLoadLikeOpInterface>(
478+
b.getDefiningOp());
478479
}
479480

480481
static DistributedEncodingTrait
@@ -501,8 +502,7 @@ static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) {
501502
while (auto cvtOp = b.getDefiningOp<ConvertLayoutOp>())
502503
b = cvtOp.getSrc();
503504
auto loadOp = b.getDefiningOp();
504-
assert((isa<triton::LoadOp, triton::DescriptorLoadOp,
505-
triton::DescriptorGatherOp>(loadOp)) &&
505+
assert((isa<triton::LoadOp, triton::DescriptorLoadLikeOpInterface>(loadOp)) &&
506506
"expected LoadOp");
507507
RankedTensorType bType = cast<RankedTensorType>(b.getType());
508508
auto currentLayout = cast<DistributedEncodingTrait>(bType.getEncoding());
@@ -627,7 +627,7 @@ Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
627627
if (!op)
628628
return scale;
629629

630-
while (!isa<LoadOp, DescriptorLoadOp>(op)) {
630+
while (!isa<LoadOp, DescriptorLoadLikeOpInterface>(op)) {
631631
if (auto reshape = dyn_cast<ReshapeOp>(op)) {
632632
op = reshape.getSrc().getDefiningOp();
633633
loadConsumer = reshape;

lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -276,25 +276,15 @@ std::optional<UseInfo>
276276
AssignDescriptorMemoryLayouts::getUseInfo(Operation *op) {
277277
UseInfo info;
278278
info.use = op;
279-
if (auto load = dyn_cast<DescriptorLoadOp>(op)) {
279+
if (auto load = dyn_cast<DescriptorLoadLikeOpInterface>(op)) {
280280
info.descriptor = load.getDesc();
281281
info.desiredSharedEncoding = findLoadEncodingFromUsers(op);
282+
auto resultTy = cast<RankedTensorType>(op->getResult(0).getType());
282283
auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
283-
: load.getType().getEncoding();
284+
: resultTy.getEncoding();
284285
info.cgaLayout = getCGALayout(encoding);
285-
auto shape = load.getResult().getType().getShape();
286-
auto rank = load.getDesc().getType().getShape().size();
287-
info.shape = expandToRank(shape, rank);
288-
return info;
289-
}
290-
if (auto gather = dyn_cast<DescriptorGatherOp>(op)) {
291-
info.descriptor = gather.getDesc();
292-
info.desiredSharedEncoding = findLoadEncodingFromUsers(op);
293-
auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
294-
: gather.getType().getEncoding();
295-
info.cgaLayout = getCGALayout(encoding);
296-
auto shape = gather.getResult().getType().getShape();
297-
auto rank = gather.getDesc().getType().getShape().size();
286+
auto shape = resultTy.getShape();
287+
auto rank = info.descriptor.getType().getShape().size();
298288
info.shape = expandToRank(shape, rank);
299289
return info;
300290
}

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -255,8 +255,8 @@ class UseShmemForScales
255255
}
256256
auto localAlloc = getNextOp<LocalAllocOp>(localLoad.getSrc());
257257
bool usesTMAload =
258-
(localAlloc && localAlloc.getSrc() &&
259-
(getNextOp<DescriptorLoadOp>(localAlloc.getSrc()) != nullptr));
258+
localAlloc && localAlloc.getSrc() &&
259+
getNextOp<DescriptorLoadLikeOpInterface>(localAlloc.getSrc());
260260
if (!isTmemCopyCompatible(localLoad.getSrc().getType(), usesTMAload))
261261
return failure();
262262

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ class AssignLoadLatencies {
107107
return false;
108108
}
109109
}
110-
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
110+
if (isa<tt::DescriptorLoadLikeOpInterface>(op))
111111
return true;
112112
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
113113
LDBG("Load " << *op << " cannot have shared encoding");
@@ -291,7 +291,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
291291
[&](Operation *op, Operation *finalUser, int distance) {
292292
if (!seen.insert(op).second || excluded.count(op))
293293
return;
294-
if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
294+
if (isa<tt::LoadOp, tt::DescriptorLoadLikeOpInterface>(op)) {
295295
if (!AssignLoadLatencies::isPipeliningBeneficial(
296296
op, finalUser, axisInfoAnalysis, filterSmall))
297297
return;
@@ -342,7 +342,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
342342
// that are not directly used by dot ops.
343343
if (pipelineWithoutDot) {
344344
for (Operation &op : forOp.getBody()->without_terminator()) {
345-
if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
345+
if (!isa<tt::LoadOp, tt::DescriptorLoadLikeOpInterface>(op))
346346
dfs(&op, &op, 0);
347347
}
348348
}

0 commit comments

Comments (0)