Skip to content

Commit 4965168

Browse files
authored
Merge branch 'main' into CRobeck-patch-1
2 parents 588ced4 + 9c446b4 commit 4965168

35 files changed

Lines changed: 1657 additions & 642 deletions

File tree

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ class DialectInferLayoutInterface
5959
// makes the reshape a "nop", i.e. the same GPU threads contain the same
6060
// elements as before the reshape using legacy layouts. This is not always
6161
// possible (in which case we fall back to using LinearLayouts)
62+
// If allowReorder is set, an existing value in dstEnc is preferred when it
63+
// still yields a non-expensive view.
6264
// In the future we'll always use LinearLayouts
6365
virtual LogicalResult
6466
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
6567
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
68+
bool allowReorder,
6669
std::optional<Location> loc) const = 0;
6770

6871
// Check if two layouts are structurally the same, even if their names are

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,32 @@ def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterf
114114
];
115115
}
116116

117+
// Op interface exposing an operation's predicate/mask operand through three
// methods: read the current operand, replace it, and query a type whose shape
// determines the predicate operand's type.
def PredicatedOpInterface : OpInterface<"PredicatedOpInterface"> {
118+
let description = [{
119+
Common interface for operations that carry a predicate or mask operand that
120+
can be combined with a pipeline predicate.
121+
}];
122+
123+
let cppNamespace = "::mlir::triton";
124+
125+
let methods = [
126+
InterfaceMethod<
127+
/*desc=*/"Return the current predicate or mask operand.",
128+
/*retType=*/"::mlir::Value",
129+
/*methodName=*/"getPredicateOperand",
130+
/*args=*/(ins)>,
131+
InterfaceMethod<
132+
/*desc=*/"Update the predicate or mask operand.",
133+
/*retType=*/"void",
134+
/*methodName=*/"setPredicateOperand",
135+
/*args=*/(ins "::mlir::Value":$pred)>,
136+
InterfaceMethod<
137+
/*desc=*/"Return a type whose shape determines the predicate operand type.",
138+
/*retType=*/"::mlir::Type",
139+
/*methodName=*/"getPredicateOperandTypeLike",
140+
/*args=*/(ins)>
141+
];
142+
}
143+
117144

118145
#endif // TRITON_OP_INTERFACES

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def TT_LoadOp : TT_Op<"load", [
212212
SameLoadStoreOperandsAndResultShape,
213213
SameLoadStoreOperandsAndResultEncoding,
214214
AttrSizedOperandSegments,
215+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
215216
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
216217
DeclareOpInterfaceMethods<InferTypeOpInterface>,
217218
TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
@@ -271,6 +272,7 @@ def TT_LoadOp : TT_Op<"load", [
271272
def TT_StoreOp : TT_Op<"store", [
272273
SameLoadStoreOperandsShape,
273274
SameLoadStoreOperandsEncoding,
275+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
274276
TypesMatchWith<"value type matches ptr type", "ptr", "value",
275277
"getPointeeType($_self)">,
276278
TypesMatchWith<"mask type matches ptr type", "ptr", "mask",
@@ -314,6 +316,7 @@ def TT_StoreOp : TT_Op<"store", [
314316
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
315317
SameOperandsAndResultShape,
316318
SameOperandsAndResultEncoding,
319+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
317320
TypesMatchWith<"ptr type matches value type", "val", "ptr",
318321
"getPointerTypeSameShape($_self)">,
319322
TypesMatchWith<"mask type matches value type",
@@ -345,6 +348,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
345348
$atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:`
346349
functional-type(operands, $result)
347350
}];
351+
348352
}
349353

350354
def TT_AtomicCASOp : TT_Op<"atomic_cas", [

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
263263
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
264264

265265
// Return true if a view between the two types cannot be implemented as a no-op.
266-
bool isExpensiveView(Type srcType, Type dstType);
266+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
267+
ArrayRef<int64_t> dstShape, Attribute dstEncoding);
268+
// Type-based convenience overload: casts both types to RankedTensorType and
// forwards their shapes and encodings to the shape/encoding overload.
inline bool isExpensiveView(Type srcType, Type dstType) {
  auto srcTensorTy = cast<RankedTensorType>(srcType);
  auto dstTensorTy = cast<RankedTensorType>(dstType);
  auto srcShape = srcTensorTy.getShape();
  auto srcEncoding = srcTensorTy.getEncoding();
  auto dstShape = dstTensorTy.getShape();
  auto dstEncoding = dstTensorTy.getEncoding();
  return isExpensiveView(srcShape, srcEncoding, dstShape, dstEncoding);
}
267274

268275
// Return a blocked encoding where the shape is distributed contiguously amongst
269276
// the threads, warps, CTAs with 1 element per threads.

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
8989

9090
def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
9191
AttrSizedOperandSegments,
92+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
9293
OptionalTypesMatchWith<"infer mask type from src type",
9394
"src", "mask", "getI1SameShape($_self)">,
9495
OptionalTypesMatchWith<"infer other type from src type",

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ def TTNG_InvalBarrierOp : TTNG_Op<"inval_barrier", [
289289
}
290290

291291
def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect", [
292-
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>]> {
292+
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>,
293+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
293294
let summary = "Signal a barrier of an expected number of bytes to be copied.";
294295

295296
let description = [{
@@ -307,10 +308,12 @@ def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect", [
307308
let assemblyFormat = [{
308309
$alloc `,` $size attr-dict `,` $pred `:` qualified(type($alloc))
309310
}];
311+
310312
}
311313

312314
def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [AttrSizedOperandSegments,
313-
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>]> {
315+
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>,
316+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
314317
let summary = "wait until the mbarrier phase completes.";
315318

316319
let description = [{
@@ -357,7 +360,8 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [AttrSizedOperandSegments,
357360
}
358361

359362
def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier", [
360-
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>]> {
363+
DeclareOpInterfaceMethods<MBarrierOpInterface, ["getBarrier"]>,
364+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
361365
let summary = "perform the arrive operation on an mbarrier";
362366
let description = [{
363367
The `ttng.arrive_barrier` operation performs the "arrive" operation on an
@@ -406,7 +410,8 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive", [
406410

407411

408412
def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [
409-
AttrSizedOperandSegments, DeclareOpInterfaceMethods<MBarrierOpInterface>]> {
413+
AttrSizedOperandSegments, DeclareOpInterfaceMethods<MBarrierOpInterface>,
414+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
410415
let summary = "copy data based on descriptor from global memory to local memory asynchronously";
411416

412417
let description = [{
@@ -466,6 +471,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
466471
oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
467472
attr-dict `:` qualified(type($desc)) `,` qualified(type($barrier)) `->` qualified(type($result))
468473
}];
474+
469475
}
470476

471477
def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global"> {
@@ -517,7 +523,8 @@ def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead<
517523
}
518524

519525
def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [
520-
DeclareOpInterfaceMethods<MBarrierOpInterface>]> {
526+
DeclareOpInterfaceMethods<MBarrierOpInterface>,
527+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
521528
let summary = "gather data based on descriptor from global memory to local memory asynchronously";
522529

523530
let description = [{
@@ -584,6 +591,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
584591
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
585592
DeclareOpInterfaceMethods<DotOpInterface>,
586593
DeclareOpInterfaceMethods<MMAv5OpInterface>,
594+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
587595
DeclareOpInterfaceMethods<MBarrierOpInterface>,
588596
AttrSizedOperandSegments
589597
]> {
@@ -645,6 +653,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
645653
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
646654
DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
647655
DeclareOpInterfaceMethods<MMAv5OpInterface>,
656+
DeclareOpInterfaceMethods<PredicatedOpInterface>,
648657
DeclareOpInterfaceMethods<MBarrierOpInterface>,
649658
AttrSizedOperandSegments
650659
]> {
@@ -715,7 +724,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
715724
}
716725

717726
def TTNG_TCGen5CommitOp : TTNG_Op<"tc_gen5_commit", [AttrSizedOperandSegments,
718-
DeclareOpInterfaceMethods<MBarrierOpInterface>]> {
727+
DeclareOpInterfaceMethods<MBarrierOpInterface>,
728+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
719729
let summary = "make an mbarrier track completion of all prior async tcgen5 ops";
720730

721731
let description = [{
@@ -858,7 +868,8 @@ def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load", [AttrSizedResultSegments]> {
858868
}];
859869
}
860870

861-
def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
871+
def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store", [
872+
DeclareOpInterfaceMethods<PredicatedOpInterface>]> {
862873
let summary = "Store a distributed tensor into a buffer in tensor memory";
863874

864875
let description = [{

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,7 @@ struct ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
307307
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
308308
ConversionPatternRewriter &rewriter) const override {
309309
Location loc = op->getLoc();
310-
if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) {
311-
return emitOptionalError(loc,
312-
"expensive view not supported on reshape op");
313-
}
310+
assert(!isExpensiveView(op.getSrc().getType(), op.getType()));
314311
auto resultTy = cast<RankedTensorType>(op.getType());
315312
auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
316313
auto typeConverter = getTypeConverter();

lib/Dialect/Gluon/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface {
7474

7575
LogicalResult
7676
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
77-
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
77+
ArrayRef<int64_t> dstShape, Attribute &dstEnc, bool,
7878
std::optional<Location> loc) const override {
7979
return inferAutoEncoding(srcEnc, dstEnc);
8080
}

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
5353
isVolatile);
5454
}
5555

56+
// PredicatedOpInterface: the load's mask operand serves as its predicate.
Value LoadOp::getPredicateOperand() {
  return getMask();
}
57+
58+
// PredicatedOpInterface: assigns `pred` to the load's mask operand.
void LoadOp::setPredicateOperand(Value pred) {
  auto maskOperand = getMaskMutable();
  maskOperand.assign(pred);
}
59+
60+
// PredicatedOpInterface: the pointer operand's type is the one whose shape
// determines the predicate type.
Type LoadOp::getPredicateOperandTypeLike() {
  return getPtr().getType();
}
61+
5662
// load(ptr, splat(1), ...) -> load(ptr, ...)
5763
// load(ptr, splat(0), other, ...) -> other
5864
struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
@@ -103,6 +109,12 @@ void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr,
103109
return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, cache, evict);
104110
}
105111

112+
// PredicatedOpInterface: the store's mask operand serves as its predicate.
Value StoreOp::getPredicateOperand() {
  return getMask();
}
113+
114+
// PredicatedOpInterface: assigns `pred` to the store's mask operand.
void StoreOp::setPredicateOperand(Value pred) {
  auto maskOperand = getMaskMutable();
  maskOperand.assign(pred);
}
115+
116+
// PredicatedOpInterface: the pointer operand's type is the one whose shape
// determines the predicate type.
Type StoreOp::getPredicateOperandTypeLike() {
  return getPtr().getType();
}
117+
106118
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
107119
// store(ptr, value, splat(0), ...) -> [none]
108120
struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
@@ -136,6 +148,14 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
136148
}
137149
};
138150

151+
// PredicatedOpInterface: the atomic RMW's mask operand serves as its predicate.
Value AtomicRMWOp::getPredicateOperand() {
  return getMask();
}
152+
153+
// PredicatedOpInterface: assigns `pred` to the atomic RMW's mask operand.
void AtomicRMWOp::setPredicateOperand(Value pred) {
  auto maskOperand = getMaskMutable();
  maskOperand.assign(pred);
}
156+
157+
// PredicatedOpInterface: the pointer operand's type is the one whose shape
// determines the predicate type.
Type AtomicRMWOp::getPredicateOperandTypeLike() {
  return getPtr().getType();
}
158+
139159
void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
140160
MLIRContext *context) {
141161
results.add<CanonicalizeMaskedStorePattern>(context);
@@ -850,9 +870,10 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
850870
auto srcEnc = srcTy.getEncoding();
851871
Attribute dstEnc;
852872
if (srcEnc) {
853-
auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
854-
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
855-
dstEnc, state.location);
873+
auto result =
874+
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
875+
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, dstEnc,
876+
allowReorder, state.location);
856877
assert(succeeded(result));
857878
}
858879
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
@@ -912,17 +933,20 @@ LogicalResult ReshapeOp::verify() {
912933
"encodings, or (b) neither does.");
913934
}
914935

915-
if (!srcEnc || getAllowReorder()) {
936+
if (!srcEnc) {
916937
return success();
917938
}
918939

919-
// Check that we can infer the dst encoding from the src encoding
920-
// and that the inferred dst encoding is the same as the given dst encoding
921-
Attribute inferredDstEnc;
940+
// Check that we can infer the dst encoding from the src encoding and that the
941+
// inferred dst encoding is the same as the given dst encoding. We pass the
942+
// current dst encoding as a hint so that allowReorder reshapes are guaranteed
943+
// to produce the current encoding iff it is valid.
944+
Attribute inferredDstEnc = dstEnc;
922945
auto layoutInterface =
923946
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
924947
auto result = layoutInterface->inferReshapeOpEncoding(
925-
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
948+
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc,
949+
getAllowReorder(), getLoc());
926950
if (failed(result))
927951
return failure();
928952
return layoutInterface->verifyLayoutsAreEqual(

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ SmallVector<unsigned> getContigPerThread(RankedTensorType type) {
114114
return toLinearEncoding(type).getContigPerThread();
115115
}
116116

117-
bool isExpensiveView(Type srcType, Type dstType) {
118-
auto tensorSrcType = cast<RankedTensorType>(srcType);
119-
auto tensorDstType = cast<RankedTensorType>(dstType);
120-
auto llSrc = toLinearLayout(tensorSrcType);
121-
auto llDst = toLinearLayout(tensorDstType);
117+
bool isExpensiveView(ArrayRef<int64_t> srcShape, Attribute srcEncoding,
118+
ArrayRef<int64_t> dstShape, Attribute dstEncoding) {
119+
auto llSrc = toLinearLayout(srcShape, srcEncoding);
120+
auto llDst = toLinearLayout(dstShape, dstEncoding);
122121
// In case there are replicated values, we need to make sure the new and old
123122
// layout have matching masks.
124123
for (auto [srcMask, dstMask] :
@@ -127,7 +126,8 @@ bool isExpensiveView(Type srcType, Type dstType) {
127126
if (srcMask.second != dstMask.second)
128127
return true;
129128
}
130-
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
129+
return getTotalElemsPerThread(srcEncoding, srcShape) !=
130+
getTotalElemsPerThread(dstEncoding, dstShape);
131131
}
132132

133133
/* Utility function used by get.*Order methods of SliceEncodingAttr.
@@ -3285,11 +3285,17 @@ struct TritonGPUInferLayoutInterface
32853285
LogicalResult
32863286
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
32873287
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
3288+
bool allowReorder,
32883289
std::optional<Location> loc) const override {
32893290
if (product(srcShape) != product(dstShape)) {
32903291
return emitOptionalError(loc, "numel of dst shape does not match "
32913292
"numel of src shape");
32923293
}
3294+
// If allowReorder is true, there are multiple valid encodings. Prefer the
3295+
// hint if it is set and valid.
3296+
if (allowReorder && dstEnc)
3297+
if (!isExpensiveView(srcShape, srcEnc, dstShape, dstEnc))
3298+
return success();
32933299
auto result =
32943300
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
32953301
if (succeeded(result)) {

0 commit comments

Comments
 (0)