Skip to content

Commit 84214b6

Browse files
authored
[Gluon][Dialect] Tighten verifiers, add more helpful error messages (triton-lang#8981)
`DotOpMMASmemLoader::build` is now fallible if it cannot find an SMEM atom to implement the layout. Since the logic is fairly complex, it perhaps doesn't make sense as a verifier. This is piped through the lowering of `WarpGroupDotOp`, `TMEMCopyOp`, and the MMAv5 ops, which are now fallible too. * Add verifiers to `tma` functions in Gluon, especially `async_gather` and `async_scatter`, which will happily trigger an illegal instruction if runtime invariants are not satisfied * Verify `NVMMASharedEncodingAttr` is valid and fix all the unit tests * Misc other cleanups Some of these should eventually get moved up to Python for better UX, but it's sufficient for the error messages to be more actionable.
1 parent f171598 commit 84214b6

File tree

29 files changed

+239
-135
lines changed

29 files changed

+239
-135
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,15 @@ bool isInnermostContiguous(MemDescType type, unsigned numElems);
294294
LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
295295
ArrayRef<int64_t> dstShape);
296296

297+
FailureOr<SmallVector<int64_t>>
298+
getTMABlockShape(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
299+
int swizzleBytes, bool fp4Padded, bool isTransposed,
300+
bool packedSize, function_ref<InFlightDiagnostic()> emitError);
301+
SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
302+
int elementBitWidth, int swizzleBytes,
303+
bool fp4Padded, bool isTransposed,
304+
bool packedSize);
305+
297306
// Verify the types of operations that operate on memory.
298307
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
299308
ShapedType dstTy);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_share
457457
int getVec() const;
458458
}];
459459
let hasCustomAssemblyFormat = 1;
460+
let genVerifyDecl = 1;
460461
}
461462

462463
def AMDRotatingSharedEncodingAttr :

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,13 @@ triton::gpu::SharedEncodingTrait
3131
getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
3232
Value desc);
3333

34-
SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
35-
int elementBitWidth, int swizzleBytes,
36-
bool fp4Padded, bool transposed,
37-
bool packedSize);
38-
3934
inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
4035
ArrayRef<int64_t> shapePerCTA,
4136
bool packedSize) {
4237
auto mmaEnc = cast<gpu::NVMMASharedEncodingAttr>(encoding);
43-
return getTMABlockShape(shapePerCTA, mmaEnc.getElementBitWidth(),
44-
mmaEnc.getSwizzlingByteWidth(), mmaEnc.getFp4Padded(),
45-
mmaEnc.getTransposed(), packedSize);
38+
return triton::gpu::getTMABlockShape(
39+
shapePerCTA, mmaEnc.getElementBitWidth(), mmaEnc.getSwizzlingByteWidth(),
40+
mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize);
4641
}
4742

4843
inline SmallVector<int64_t> getTMABlockShape(RankedTensorType ty,

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,16 @@ void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const {
21312131
printer << "}>";
21322132
}
21332133

2134+
LogicalResult
2135+
NVMMASharedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2136+
unsigned swizzlingByteWidth, bool transposed,
2137+
unsigned elementBitWidth, bool fp4Padded,
2138+
CGAEncodingAttr CGALayout) {
2139+
if (elementBitWidth == 0)
2140+
return emitError() << "elementBitWidth must be non-zero";
2141+
return success();
2142+
}
2143+
21342144
int NVMMASharedEncodingAttr::getVec() const {
21352145
if (getSwizzlingByteWidth() == 0)
21362146
return 1;
@@ -2469,8 +2479,8 @@ CGAEncodingAttr DotOperandEncodingAttr::getCGALayout() const {
24692479
LinearLayout(std::move(bases), dims, true));
24702480
}
24712481
LogicalResult DotOperandEncodingAttr::verify(
2472-
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
2473-
unsigned opIdx, Attribute parent, unsigned kWidth) {
2482+
function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned opIdx,
2483+
Attribute parent, unsigned kWidth) {
24742484
if (opIdx != 0 && opIdx != 1) {
24752485
return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: "
24762486
<< opIdx;
@@ -3963,6 +3973,48 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
39633973
return dst;
39643974
}
39653975

3976+
FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
3977+
ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
3978+
bool fp4Padded, bool isTransposed, bool packedSize,
3979+
function_ref<InFlightDiagnostic()> emitError) {
3980+
SmallVector<int64_t> blockShape(shapePerCTA);
3981+
int contigDim = isTransposed ? 0 : blockShape.size() - 1;
3982+
if (fp4Padded)
3983+
blockShape[contigDim] *= 2;
3984+
// All dimensions must be at most 256
3985+
constexpr int64_t dimMax = 256;
3986+
for (auto &size : blockShape)
3987+
size = std::min(size, dimMax);
3988+
// Last dim must equal the swizzle byte size
3989+
if (swizzleBytes != 0) {
3990+
auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
3991+
if (blockShape[contigDim] < contigDimSize) {
3992+
return emitError() << "block shape along the contiguous dimension "
3993+
<< contigDim
3994+
<< " is too small for the swizzle byte size "
3995+
<< swizzleBytes << " in an NVMMASharedLayout, got "
3996+
<< blockShape[contigDim] << " but expected at least "
3997+
<< contigDimSize;
3998+
}
3999+
blockShape[contigDim] = contigDimSize;
4000+
}
4001+
if (fp4Padded && packedSize) {
4002+
blockShape[contigDim] /= 2;
4003+
}
4004+
return blockShape;
4005+
}
4006+
SmallVector<int64_t> triton::gpu::getTMABlockShape(
4007+
ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
4008+
bool fp4Padded, bool isTransposed, bool packedSize) {
4009+
return *getTMABlockShape(
4010+
shapePerCTA, elementBitWidth, swizzleBytes, fp4Padded, isTransposed,
4011+
packedSize, []() -> InFlightDiagnostic {
4012+
llvm::report_fatal_error(
4013+
"Block shape is too small for the swizzle byte "
4014+
"size in NVMMA Shared Layout.");
4015+
});
4016+
}
4017+
39664018
SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
39674019
auto attrs = op->getAttr(kPartitionAttrName);
39684020
SmallVector<int> partitionIds;

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
192192
}
193193
}
194194

195+
if (auto enc = dyn_cast<NVMMASharedEncodingAttr>(encoding)) {
196+
SmallVector<int64_t> shapePerCTA(getShapePerCTA(enc, allocShape));
197+
auto blockShape = ArrayRef(shapePerCTA).take_back(enc.getRank());
198+
if (failed(getTMABlockShape(blockShape, enc.getElementBitWidth(),
199+
enc.getSwizzlingByteWidth(), enc.getFp4Padded(),
200+
enc.getTransposed(), /*packedSize=*/false,
201+
emitError)))
202+
return failure();
203+
}
204+
195205
return success();
196206
}
197207

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ LogicalResult TMEMCopyOp::verify() {
870870
return emitOpError("Incorrect tmem layout.");
871871
}
872872
if (tmemEnc.getBlockM() != 128) {
873-
return emitOpError("Tmem layout ahouls have M=128.");
873+
return emitOpError("Tmem layout must have blockM=128.");
874874
}
875875
if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() == 0) {
876876
return emitOpError("Source layout should be swizzled.");

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,6 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
116116
return updateEncodingForShape(op, sharedEnc, tensorType);
117117
}
118118

119-
SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
120-
int elementBitWidth, int swizzleBytes,
121-
bool fp4Padded, bool isTransposed,
122-
bool packedSize) {
123-
SmallVector<int64_t> blockShape(shapePerCTA);
124-
int contigDim = isTransposed ? 0 : blockShape.size() - 1;
125-
if (fp4Padded) {
126-
blockShape[contigDim] *= 2;
127-
}
128-
// All dimensions must be at most 256
129-
constexpr int64_t dimMax = 256;
130-
for (auto &size : blockShape) {
131-
size = std::min(size, dimMax);
132-
}
133-
// Last dim must equal the swizzle byte size
134-
if (swizzleBytes != 0) {
135-
auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
136-
if (blockShape[contigDim] < contigDimSize) {
137-
llvm::report_fatal_error("Block shape is too small for the swizzle byte "
138-
"size in NVMMA Shared Layout.");
139-
}
140-
blockShape[contigDim] = contigDimSize;
141-
}
142-
if (fp4Padded && packedSize) {
143-
blockShape[contigDim] /= 2;
144-
}
145-
return blockShape;
146-
}
147-
148119
std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty) {
149120
auto encoding = ty.getBlockType().getEncoding();
150121
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);

python/examples/gluon/01-attention-forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
778778
o_tmem, o_bar, o_consumer = o_consumer.acquire()
779779

780780
# Shared memory subtile size is limited by the swizzle byte size.
781-
contigDimSize: gl.constexpr = o_smem.type.layout.swizzle_byte_width * 8 / o_smem.type.element_ty.primitive_bitwidth
781+
contigDimSize: gl.constexpr = o_smem.type.layout.swizzle_byte_width * 8 // o_smem.type.element_ty.primitive_bitwidth
782782
if o_smem.type.shape[1] // config.SPLIT_D_FACTOR >= contigDimSize:
783783
SPLIT_N_FACTOR: gl.constexpr = config.SPLIT_D_FACTOR
784784
else:

python/triton/experimental/gluon/language/nvidia/blackwell/tma.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@
2020
]
2121

2222

23+
def _check_gather_scatter(tensor_desc, x_offsets, smem, op_name, smem_name):
24+
# Tensor descriptor must be 2D and layout must match the shared memory layout.
25+
assert len(
26+
tensor_desc.block_shape
27+
) == 2, f"async {op_name} requires a 2D tensor descriptor, but got one with rank {len(tensor_desc.block_shape)}"
28+
assert tensor_desc.layout == smem.layout, f"tensor descriptor layout {tensor_desc.layout} does not match {smem_name} shared memory layout {smem.layout}"
29+
# Row offsets must be 1D and have at least 8 rows.
30+
assert len(
31+
x_offsets.shape
32+
) == 1, f"async {op_name} requires a 1D tensor of row offsets, but got one with rank {len(x_offsets.shape)}"
33+
assert x_offsets.shape[0] >= 8, f"async {op_name} requires at least 8 rows, but got {x_offsets.shape[0]}"
34+
# Block shape must be [1, Y] where Y >= min_cols.
35+
min_cols = 32 // tensor_desc.dtype.primitive_bitwidth * 8
36+
assert tensor_desc.block_shape[
37+
0] == 1, f"async {op_name} requires the tensor descriptor's block shape to have 1 row, but got {tensor_desc.block_shape}"
38+
assert tensor_desc.block_shape[
39+
1] >= min_cols, f"async {op_name} requires the tensor descriptor's block shape to have at least {min_cols} columns, but got {tensor_desc.block_shape[1]}"
40+
41+
2342
@builtin
2443
def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None):
2544
"""
@@ -33,6 +52,7 @@ def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _
3352
result (tensor_memory_descriptor): Result shared memory, must have NVMMASharedLayout.
3453
pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
3554
"""
55+
_check_gather_scatter(tensor_desc, x_offsets, result, "gather", "result")
3656
pred = _semantic.to_tensor(pred)
3757
y_offset = _semantic.to_tensor(y_offset)
3858
_semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle,
@@ -50,5 +70,6 @@ def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None):
5070
y_offset (int): Scalar Y offset.
5171
src (tensor_memory_descriptor): The source data, must be in NVMMASharedLayout.
5272
"""
73+
_check_gather_scatter(tensor_desc, x_offsets, src, "scatter", "source")
5374
y_offset = _semantic.to_tensor(y_offset)
5475
_semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle)

python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def layout(self):
8787

8888
@builtin
8989
def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
90+
assert tensor_desc.layout == result.layout, f"tensor descriptor layout {tensor_desc.layout} does not match result shared memory layout {result.layout}"
9091
coord = _semantic._convert_to_ir_values(coord, require_i64=False)
9192
pred = _semantic.to_tensor(pred)
9293
_semantic.builder.create_async_tma_copy_global_to_local(tensor_desc.handle, coord, barrier.handle, result.handle,
@@ -95,6 +96,7 @@ def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True,
9596

9697
@builtin
9798
def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None):
99+
assert tensor_desc.layout == src.layout, f"tensor descriptor layout {tensor_desc.layout} does not match source shared memory layout {src.layout}"
98100
coord = _semantic._convert_to_ir_values(coord, require_i64=False)
99101
_semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle)
100102

0 commit comments

Comments
 (0)