|
6 | 6 | #include "triton/Dialect/Triton/IR/Dialect.h" |
7 | 7 | #include "triton/Dialect/Triton/IR/Types.h" |
8 | 8 | #include "triton/Dialect/Triton/IR/Utility.h" |
9 | | -#include "triton/Dialect/TritonGPU/IR/Types.h" |
10 | 9 | #include "llvm/Support/ErrorHandling.h" |
11 | 10 |
|
12 | 11 | using namespace mlir; |
13 | | -using namespace mlir::triton::gpu; |
14 | 12 |
|
15 | | -LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { |
16 | | - auto memdescA = dyn_cast<MemDescType>(typeA); |
17 | | - auto memdescB = dyn_cast<MemDescType>(typeB); |
18 | | - if (memdescA || memdescB) { |
19 | | - if (!memdescA || !memdescB) |
20 | | - return failure(); |
21 | | - if (memdescA.getShape() != memdescB.getShape()) |
22 | | - return failure(); |
23 | | - if (memdescA.getAllocShape() != memdescB.getAllocShape()) |
24 | | - return failure(); |
25 | | - if (memdescA.getElementType() != memdescB.getElementType()) |
26 | | - return failure(); |
27 | | - if (memdescA.getMemorySpace() != memdescB.getMemorySpace()) |
28 | | - return failure(); |
29 | | - if (memdescA.getMutableMemory() != memdescB.getMutableMemory()) |
30 | | - return failure(); |
31 | | - |
32 | | - Attribute encodingA = memdescA.getEncoding(); |
33 | | - Attribute encodingB = memdescB.getEncoding(); |
34 | | - if (encodingA == encodingB) |
35 | | - return success(); |
36 | | - if (static_cast<bool>(encodingA) != static_cast<bool>(encodingB)) |
37 | | - return failure(); |
38 | | - |
39 | | - auto layoutInterface = |
40 | | - cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect()); |
41 | | - return layoutInterface->verifyLayoutsAreEqual(memdescA.getShape(), |
42 | | - encodingA, encodingB, {}); |
43 | | - } |
| 13 | +LogicalResult OpTrait::impl::verifyEquivalentTensorType(Type typeA, |
| 14 | + Type typeB) { |
44 | 15 | auto tensorTypeA = dyn_cast<RankedTensorType>(typeA); |
45 | 16 | auto tensorTypeB = dyn_cast<RankedTensorType>(typeB); |
46 | 17 | if (!(bool(tensorTypeA) && bool(tensorTypeB))) |
@@ -162,35 +133,19 @@ LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { |
162 | 133 | auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { |
163 | 134 | // Only ranked tensors can have layouts. |
164 | 135 | auto rankedTy = dyn_cast<RankedTensorType>(val.getType()); |
165 | | - if (rankedTy) { |
166 | | - mlir::Attribute layout = rankedTy.getEncoding(); |
167 | | - if (!layout) |
168 | | - return success(); |
169 | | - |
170 | | - Dialect &dialect = layout.getDialect(); |
171 | | - auto verifyLayoutInterface = |
172 | | - dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect); |
173 | | - if (verifyLayoutInterface) { |
174 | | - return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, |
175 | | - makeErr); |
176 | | - } |
177 | | - return success(); |
178 | | - } |
179 | | - |
180 | | - auto memDescTy = dyn_cast<MemDescType>(val.getType()); |
181 | | - if (!memDescTy) |
| 136 | + if (!rankedTy) |
182 | 137 | return success(); |
183 | 138 |
|
184 | | - mlir::Attribute layout = memDescTy.getEncoding(); |
| 139 | + mlir::Attribute layout = rankedTy.getEncoding(); |
185 | 140 | if (!layout) |
186 | 141 | return success(); |
187 | 142 |
|
188 | 143 | Dialect &dialect = layout.getDialect(); |
189 | 144 | auto verifyLayoutInterface = |
190 | 145 | dyn_cast<mlir::triton::DialectVerifyTensorLayoutInterface>(&dialect); |
191 | 146 | if (verifyLayoutInterface) { |
192 | | - return verifyLayoutInterface->verifyMemDescLayout(layout, memDescTy, op, |
193 | | - makeErr); |
| 147 | + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, |
| 148 | + makeErr); |
194 | 149 | } |
195 | 150 |
|
196 | 151 | return success(); |
|
0 commit comments