Skip to content

Commit e86144b

Browse files
authored
Remove TritonIR dependence on TritonGPUIR (#9392)
1 parent 743f178 commit e86144b

11 files changed

Lines changed: 154 additions & 60 deletions

File tree

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);
2828

2929
LogicalResult verifySameOperandsEncoding(Operation *op,
3030
bool allowTensorPointerType = false);
31-
LogicalResult verifyEquivalentType(Type typeA, Type typeB);
31+
LogicalResult verifyEquivalentTensorType(Type typeA, Type typeB);
3232
LogicalResult
3333
verifySameOperandsAndResultEncoding(Operation *op,
3434
bool allowTensorPointerType = false);

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,13 @@ def AsyncRegions : NativeOpTrait<"AsyncRegions">;
1616

1717
// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
1818
// equivalence of the layouts of the result rather than just layout equality.
19-
def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
19+
def InferTensorTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
2020
static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
2121
if (lhs.size() != rhs.size())
2222
return false;
2323
return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
2424
auto [lhs, rhs] = tup;
25-
return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
25+
return succeeded(OpTrait::impl::verifyEquivalentTensorType(lhs, rhs));
2626
});
2727
}
2828
}]>;

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -545,7 +545,7 @@ def TT_JoinOp : TT_Op<"join", [
545545

546546
def TT_SplitOp : TT_Op<"split", [
547547
Pure,
548-
InferTypeOpWithLayoutEquivalence,
548+
InferTensorTypeOpWithLayoutEquivalence,
549549
TypesMatchWith<"outLHS and outRHS types match",
550550
"outLHS", "outRHS", "$_self">,
551551
]> {
@@ -565,7 +565,7 @@ def TT_SplitOp : TT_Op<"split", [
565565

566566
def TT_TransOp : TT_Op<"trans", [Pure,
567567
TransposeOpInterface,
568-
InferTypeOpWithLayoutEquivalence,
568+
InferTensorTypeOpWithLayoutEquivalence,
569569
SameOperandsAndResultElementType]> {
570570

571571
let summary = "rearrange the dimensions of a tensor";

include/triton/Dialect/TritonGPU/IR/Traits.h

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,22 @@
1010
namespace mlir {
1111
namespace OpTrait {
1212

13+
namespace impl {
14+
LogicalResult verifyEquivalentMemDescType(Type typeA, Type typeB);
15+
LogicalResult verifyMemDescLayouts(Operation *op);
16+
} // namespace impl
17+
18+
// Trait applied to all Triton GPU MLIR ops. Checks that the layouts of
19+
// MemDescs are valid.
20+
template <class ConcreteType>
21+
class VerifyMemDescLayoutsTrait
22+
: public TraitBase<ConcreteType, VerifyMemDescLayoutsTrait> {
23+
public:
24+
static LogicalResult verifyTrait(Operation *op) {
25+
return impl::verifyMemDescLayouts(op);
26+
}
27+
};
28+
1329
template <typename ConcreteType>
1430
class MemDescViewTrait
1531
: public mlir::OpTrait::TraitBase<ConcreteType, MemDescViewTrait> {

include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#define TRITONGPU_OP_INTERFACES
33

44
include "mlir/IR/OpBase.td"
5+
include "mlir/Interfaces/InferTypeOpInterface.td"
56

67
def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
78
let description = [{
@@ -26,4 +27,19 @@ def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
2627
];
2728
}
2829

30+
def VerifyMemDescLayoutsTrait : NativeOpTrait<"VerifyMemDescLayoutsTrait">;
31+
32+
// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
33+
// equivalence of the layouts of the result rather than just layout equality.
34+
def InferMemDescTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
35+
static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
36+
if (lhs.size() != rhs.size())
37+
return false;
38+
return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
39+
auto [lhs, rhs] = tup;
40+
return succeeded(OpTrait::impl::verifyEquivalentMemDescType(lhs, rhs));
41+
});
42+
}
43+
}]>;
44+
2945
#endif // TRITONGPU_OP_INTERFACES

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUEnums.td"
66
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
77
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
88
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
9+
include "triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td"
910
include "mlir/Dialect/Arith/IR/ArithBase.td"
1011
include "triton/Dialect/Triton/IR/TritonTypes.td"
1112
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
@@ -25,7 +26,7 @@ def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
2526

2627
class TTG_Op<string mnemonic, list<Trait> traits = []> :
2728
Op<TritonGPU_Dialect, mnemonic,
28-
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
29+
!listconcat(traits, [VerifyTensorLayoutsTrait, VerifyMemDescLayoutsTrait])> {
2930
}
3031

3132
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
@@ -271,7 +272,7 @@ def TTG_MemDescSubsliceOp : TTG_Op<"memdesc_subslice", [Pure, MemDescViewTrait]>
271272
def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
272273
MemDescViewTrait,
273274
TransposeOpInterface,
274-
InferTypeOpWithLayoutEquivalence,
275+
InferMemDescTypeOpWithLayoutEquivalence,
275276
SameOperandsAndResultElementType]> {
276277
let summary = "transpose the descriptor";
277278

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
3333
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
3434
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
3535
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
36+
include "triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td"
3637
include "mlir/IR/OpBase.td"
3738
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
3839
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -45,7 +46,7 @@ def TensorMemory : Resource<"::mlir::triton::nvidia_gpu::TensorMemory">;
4546

4647
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
4748
Op<TritonNvidiaGPU_Dialect, mnemonic,
48-
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
49+
!listconcat(traits, [VerifyTensorLayoutsTrait, VerifyMemDescLayoutsTrait])> {
4950
}
5051

5152
def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {

lib/Dialect/Triton/IR/Traits.cpp

Lines changed: 6 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -6,41 +6,12 @@
66
#include "triton/Dialect/Triton/IR/Dialect.h"
77
#include "triton/Dialect/Triton/IR/Types.h"
88
#include "triton/Dialect/Triton/IR/Utility.h"
9-
#include "triton/Dialect/TritonGPU/IR/Types.h"
109
#include "llvm/Support/ErrorHandling.h"
1110

1211
using namespace mlir;
13-
using namespace mlir::triton::gpu;
1412

15-
LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) {
16-
auto memdescA = dyn_cast<MemDescType>(typeA);
17-
auto memdescB = dyn_cast<MemDescType>(typeB);
18-
if (memdescA || memdescB) {
19-
if (!memdescA || !memdescB)
20-
return failure();
21-
if (memdescA.getShape() != memdescB.getShape())
22-
return failure();
23-
if (memdescA.getAllocShape() != memdescB.getAllocShape())
24-
return failure();
25-
if (memdescA.getElementType() != memdescB.getElementType())
26-
return failure();
27-
if (memdescA.getMemorySpace() != memdescB.getMemorySpace())
28-
return failure();
29-
if (memdescA.getMutableMemory() != memdescB.getMutableMemory())
30-
return failure();
31-
32-
Attribute encodingA = memdescA.getEncoding();
33-
Attribute encodingB = memdescB.getEncoding();
34-
if (encodingA == encodingB)
35-
return success();
36-
if (static_cast<bool>(encodingA) != static_cast<bool>(encodingB))
37-
return failure();
38-
39-
auto layoutInterface =
40-
cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect());
41-
return layoutInterface->verifyLayoutsAreEqual(memdescA.getShape(),
42-
encodingA, encodingB, {});
43-
}
13+
LogicalResult OpTrait::impl::verifyEquivalentTensorType(Type typeA,
14+
Type typeB) {
4415
auto tensorTypeA = dyn_cast<RankedTensorType>(typeA);
4516
auto tensorTypeB = dyn_cast<RankedTensorType>(typeB);
4617
if (!(bool(tensorTypeA) && bool(tensorTypeB)))
@@ -162,35 +133,19 @@ LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) {
162133
auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult {
163134
// Only ranked tensors can have layouts.
164135
auto rankedTy = dyn_cast<RankedTensorType>(val.getType());
165-
if (rankedTy) {
166-
mlir::Attribute layout = rankedTy.getEncoding();
167-
if (!layout)
168-
return success();
169-
170-
Dialect &dialect = layout.getDialect();
171-
auto verifyLayoutInterface =
172-
dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
173-
if (verifyLayoutInterface) {
174-
return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op,
175-
makeErr);
176-
}
177-
return success();
178-
}
179-
180-
auto memDescTy = dyn_cast<MemDescType>(val.getType());
181-
if (!memDescTy)
136+
if (!rankedTy)
182137
return success();
183138

184-
mlir::Attribute layout = memDescTy.getEncoding();
139+
mlir::Attribute layout = rankedTy.getEncoding();
185140
if (!layout)
186141
return success();
187142

188143
Dialect &dialect = layout.getDialect();
189144
auto verifyLayoutInterface =
190145
dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect);
191146
if (verifyLayoutInterface) {
192-
return verifyLayoutInterface->verifyMemDescLayout(layout, memDescTy, op,
193-
makeErr);
147+
return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op,
148+
makeErr);
194149
}
195150

196151
return success();

lib/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ add_triton_library(TritonGPUIR
33
LinearLayoutConversions.cpp
44
Ops.cpp
55
Types.cpp
6+
Traits.cpp
67

78
DEPENDS
89
TritonGPUCGAAttrIncGen

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -571,7 +571,7 @@ LogicalResult MemDescReshapeOp::verify() {
571571
if (failed(inferReturnTypes(getContext(), getLoc(), srcType,
572572
dstType.getShape(), expectedTy)))
573573
return failure();
574-
return OpTrait::impl::verifyEquivalentType(expectedTy, dstType);
574+
return OpTrait::impl::verifyEquivalentMemDescType(expectedTy, dstType);
575575
}
576576

577577
static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,

0 commit comments

Comments
 (0)