Commit 639808d

[Nvidia] Enable TMA im2col mode - LLVM lowering (#9322)
# Summary

This is the fifth PR in a series that enables TMA im2col mode (in addition to the existing tiled mode) for NVIDIA GPUs. The goal of the series is to support TMA im2col mode in Gluon DSL.

- First PR: #9202
- Second PR: #9225
- Third PR: #9303
- Fourth PR: #9305
- -> Fifth PR: #9322

PTX ISA documentation for TMA im2col mode: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode

TMA tensor descriptor documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

# Summary of Changes

Added LLVM lowering logic for `AsyncTMACopyGlobalToLocalOpConversion` to support im2col mode.

## Im2col Mode Constraints

### pixelsPerColumn (non-contiguous dimension)

- **Maximum size**: 1024 elements
- **Corresponds to**: Spatial dimensions (N, D, H, W)
- **Block shape**: Restricted to match `shapePerCTA` (no splitting)
- **Rationale**: Avoids generating multiple TMA messages along spatial dimensions, eliminating complex offset calculations that would depend on input tensor shape and padding
- **Note**: 1024 is sufficient for most practical use cases

### channelsPerPixel (contiguous dimension)

- **Maximum size**: 256 elements, or the swizzle byte size if swizzling is enabled
- **Multiple messages**: Supported when the channel dimension exceeds the block size
- **Offset application**: Only coord[0] (the channel coordinate in PTX order) receives non-zero offsets

## Key Implementation Details

1. **Offset application**: For im2col mode, only the channel dimension receives non-zero offsets; spatial dimension offsets are always 0 (verified by assertion)
2. **Im2col offsets reversal**: Spatial offsets (e.g., `off_w`, `off_h`) are reversed to match PTX/CUDA innermost-to-outermost ordering, consistent with coordinate handling
3. **Alignment with tiled mode**: These constraints align with the tiled mode behavior used for GEMM operations

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
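The block-shape constraints described above can be illustrated with a short standalone Python model. This is a hypothetical sketch, not code from this PR: the function name, signature, and error messages are invented, and it only mirrors the prose constraints (pixelsPerColumn capped at 1024 with no splitting, channelsPerPixel clamped to 256 or pinned to the swizzle width).

```python
# Hypothetical model of the im2col block-shape constraints (illustrative only).
def im2col_block_shape(shape_per_cta, elem_bits, swizzle_bytes):
    """Return [pixelsPerColumn, channelsPerPixel] after applying constraints."""
    assert len(shape_per_cta) == 2, "im2col blocks are 2D"
    pixels, channels = shape_per_cta

    # pixelsPerColumn must fit in a single TMA message: hard cap of 1024,
    # no splitting along the spatial dimensions (N, D, H, W).
    if pixels > 1024:
        raise ValueError(f"pixelsPerColumn {pixels} exceeds 1024")

    # channelsPerPixel is clamped to 256 elements; with swizzling it is
    # pinned to the swizzle width in elements (multiple messages cover the
    # rest of the channel dimension).
    channels = min(channels, 256)
    if swizzle_bytes:
        swizzle_elems = (8 * swizzle_bytes) // elem_bits
        if channels < swizzle_elems:
            raise ValueError("block too small for the swizzle byte size")
        channels = swizzle_elems
    return [pixels, channels]

# fp16 (16-bit) block with a 128-byte swizzle: channels pinned to 1024/16 = 64.
print(im2col_block_shape([256, 128], 16, 128))  # [256, 64]
```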
1 parent b5e3800 commit 639808d

12 files changed

Lines changed: 316 additions & 48 deletions


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 9 additions & 2 deletions
```diff
@@ -294,14 +294,21 @@ bool isInnermostContiguous(MemDescType type, unsigned numElems);
 LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
                                       ArrayRef<int64_t> dstShape);
 
+// TMA tensor access modes
+enum class TMAMode {
+  Tiled,  // Regular tiled tensor memory access
+  Im2Col  // Im2col mode for convolution-friendly access patterns
+};
+
 FailureOr<SmallVector<int64_t>>
 getTMABlockShape(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
                  int swizzleBytes, bool fp4Padded, bool isTransposed,
-                 bool packedSize, function_ref<InFlightDiagnostic()> emitError);
+                 bool packedSize, function_ref<InFlightDiagnostic()> emitError,
+                 TMAMode mode);
 SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
                                       int elementBitWidth, int swizzleBytes,
                                       bool fp4Padded, bool isTransposed,
-                                      bool packedSize);
+                                      bool packedSize, TMAMode mode);
 
 // Verify the types of operations that operate on memory.
 LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
```

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@ class NVMMASharedEncodingAttr;
 class TensorOrMemDesc;
 class MemDescType;
 class CGAEncodingAttr;
+enum class TMAMode;
 
 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -61,6 +62,7 @@ LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
 // swizzling.
 LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
                                        NVMMASharedEncodingAttr shared,
+                                       TMAMode mode,
                                        bool disableSwizzle = false);
 
 // Given a linear layout where the input dimensions contain a "block" dimension,
```

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 12 additions & 10 deletions
```diff
@@ -4,7 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
-#include "llvm/Support/Casting.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
 namespace mlir::triton::nvidia_gpu {
 
@@ -29,27 +29,29 @@ getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
 
 inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
                                              ArrayRef<int64_t> shapePerCTA,
-                                             bool packedSize) {
+                                             bool packedSize,
+                                             gpu::TMAMode mode) {
   auto mmaEnc = cast<gpu::NVMMASharedEncodingAttr>(encoding);
   return triton::gpu::getTMABlockShape(
       shapePerCTA, mmaEnc.getElementBitWidth(), mmaEnc.getSwizzlingByteWidth(),
-      mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize);
+      mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize, mode);
 }
 
-inline SmallVector<int64_t> getTMABlockShape(RankedTensorType ty,
-                                             bool packedSize) {
+inline SmallVector<int64_t>
+getTMABlockShape(RankedTensorType ty, bool packedSize, gpu::TMAMode mode) {
   auto shapePerCTA = gpu::getShapePerCTA(ty);
-  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
 }
 
 inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
-                                             bool packedSize) {
+                                             bool packedSize,
+                                             gpu::TMAMode mode) {
   auto shapePerCTA = gpu::getShapePerCTA(ty);
-  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
 }
 
-FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty);
-FailureOr<int> getTMAElementType(Location loc, TensorDescType ty);
+FailureOr<int> getTMASwizzleMode(Location loc, triton::TensorDescInterface ty);
+FailureOr<int> getTMAElementType(Location loc, triton::TensorDescInterface ty);
 
 LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
                             OpBuilder &builder);
```

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -25,6 +25,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([ctx](TensorDescType type) -> std::optional<Type> {
     return LLVM::LLVMPointerType::get(ctx, 0);
   });
+  addConversion(
+      [ctx](nvidia_gpu::TensorDescIm2ColType type) -> std::optional<Type> {
+        return LLVM::LLVMPointerType::get(ctx, 0);
+      });
   addConversion([&](RankedTensorType type) -> std::optional<Type> {
     return convertTritonTensorType(type, targetInfo);
   });
```

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 90 additions & 12 deletions
```diff
@@ -4169,14 +4169,76 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
   return dst;
 }
 
-FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
-    ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
-    bool fp4Padded, bool isTransposed, bool packedSize,
-    function_ref<InFlightDiagnostic()> emitError) {
+// Helper function for im2col mode block shape calculation.
+// Im2col mode produces a 2D block: [pixelsPerColumn, channelsPerPixel]
+// Constraints:
+// - channelsPerPixel (contigDim): max 256, or swizzle byte size if enabled
+// - pixelsPerColumn (otherDim): max 1024, no splitting (single TMA message)
+// Doc:
+// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
+static FailureOr<SmallVector<int64_t>>
+getTMABlockShapeIm2Col(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
+                       int swizzleBytes, bool fp4Padded, bool isTransposed,
+                       bool packedSize,
+                       function_ref<InFlightDiagnostic()> emitError) {
+  assert(shapePerCTA.size() == 2 && "im2col mode requires a 2D block shape");
+
   SmallVector<int64_t> blockShape(shapePerCTA);
   int contigDim = isTransposed ? 0 : blockShape.size() - 1;
   if (fp4Padded)
     blockShape[contigDim] *= 2;
+
+  constexpr int64_t contigDimMax = 256;
+  constexpr int64_t otherDimMax = 1024;
+  int otherDim = (contigDim == 0) ? 1 : 0;
+
+  // Check that pixelsPerColumn doesn't exceed the hardware maximum of 1024.
+  // This constraint ensures a single TMA message can cover all pixels,
+  // avoiding the need for multiple messages along spatial dimensions (N, D,
+  // H, W). Supporting pixelsPerColumn > 1024 would require computing offsets
+  // that depend on input tensor shape and padding, which is non-trivial.
+  if (blockShape[otherDim] > otherDimMax) {
+    return emitError() << "im2col mode: pixelsPerColumn dimension "
+                       << blockShape[otherDim]
+                       << " exceeds the maximum supported value of "
+                       << otherDimMax;
+  }
+
+  // Clamp the contiguous dimension (channelsPerPixel) to max 256
+  blockShape[contigDim] = std::min(blockShape[contigDim], contigDimMax);
+
+  // Contiguous dim must equal the swizzle byte size if swizzle is enabled
+  if (swizzleBytes != 0) {
+    auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
+    if (blockShape[contigDim] < contigDimSize) {
+      return emitError() << "im2col mode: block shape along the contiguous "
+                            "dimension "
+                         << contigDim
+                         << " is too small for the swizzle byte size "
+                         << swizzleBytes << ", got " << blockShape[contigDim]
+                         << " but expected at least " << contigDimSize;
+    }
+    blockShape[contigDim] = contigDimSize;
+  }
+
+  if (fp4Padded && packedSize) {
+    blockShape[contigDim] /= 2;
+  }
+  return blockShape;
+}
+
+// Tiled mode block shape calculation.
+static FailureOr<SmallVector<int64_t>>
+getTMABlockShapeTiled(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
+                      int swizzleBytes, bool fp4Padded, bool isTransposed,
+                      bool packedSize,
+                      function_ref<InFlightDiagnostic()> emitError) {
+  SmallVector<int64_t> blockShape(shapePerCTA);
+
+  int contigDim = isTransposed ? 0 : blockShape.size() - 1;
+  if (fp4Padded)
+    blockShape[contigDim] *= 2;
+
   // All dimensions must be at most 256
   constexpr int64_t dimMax = 256;
   for (auto &size : blockShape)
@@ -4199,16 +4261,32 @@ FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
   }
   return blockShape;
 }
+
+FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
+    ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
+    bool fp4Padded, bool isTransposed, bool packedSize,
+    function_ref<InFlightDiagnostic()> emitError, TMAMode mode) {
+  if (mode == TMAMode::Im2Col) {
+    return getTMABlockShapeIm2Col(shapePerCTA, elementBitWidth, swizzleBytes,
+                                  fp4Padded, isTransposed, packedSize,
+                                  emitError);
+  }
+  // Tiled mode
+  return getTMABlockShapeTiled(shapePerCTA, elementBitWidth, swizzleBytes,
+                               fp4Padded, isTransposed, packedSize, emitError);
+}
+
 SmallVector<int64_t> triton::gpu::getTMABlockShape(
     ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
-    bool fp4Padded, bool isTransposed, bool packedSize) {
-  return *getTMABlockShape(
-      shapePerCTA, elementBitWidth, swizzleBytes, fp4Padded, isTransposed,
-      packedSize, []() -> InFlightDiagnostic {
-        llvm::report_fatal_error(
-            "Block shape is too small for the swizzle byte "
-            "size in NVMMA Shared Layout.");
-      });
+    bool fp4Padded, bool isTransposed, bool packedSize, TMAMode mode) {
+  auto emitFatalError = []() -> InFlightDiagnostic {
+    llvm::report_fatal_error("getTMABlockShape failed: invalid block shape "
+                             "for TMA operation.");
+  };
+
+  return *getTMABlockShape(shapePerCTA, elementBitWidth, swizzleBytes,
+                           fp4Padded, isTransposed, packedSize, emitFatalError,
+                           mode);
 }
 
 SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
```
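For contrast with the im2col helper, the tiled-mode clamping can be sketched the same way. This is a hypothetical standalone Python model, not the PR's C++: it assumes tiled mode caps every dimension at 256 and, when swizzling is enabled, pins the contiguous dimension to the swizzle width in elements; all names are invented.

```python
# Hypothetical model of tiled-mode TMA block-shape clamping (illustrative only).
def tiled_block_shape(shape_per_cta, elem_bits, swizzle_bytes, transposed=False):
    block = list(shape_per_cta)
    contig = 0 if transposed else len(block) - 1

    # All dimensions must be at most 256.
    block = [min(d, 256) for d in block]

    # With swizzling, the contiguous dim becomes the swizzle width in elements.
    if swizzle_bytes:
        swizzle_elems = (8 * swizzle_bytes) // elem_bits
        if block[contig] < swizzle_elems:
            raise ValueError("block shape too small for the swizzle byte size")
        block[contig] = swizzle_elems
    return block

# fp16 tile of [128, 512] with a 64-byte swizzle: last dim -> (8*64)/16 = 32.
print(tiled_block_shape([128, 512], 16, 64))  # [128, 32]
```

Unlike the im2col model, no 1024-element cap applies here, since tiled mode may split any dimension across multiple TMA messages.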

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 6 additions & 4 deletions
```diff
@@ -195,13 +195,14 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
 
 LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
                                        NVMMASharedEncodingAttr shared,
-                                       bool disableSwizzle) {
+                                       TMAMode mode, bool disableSwizzle) {
   MLIRContext *ctx = shared.getContext();
   int rank = shape.size();
   auto shapePerCTA = getShapePerCTA(shared, shape);
   auto kOffset = S("offset");
-  auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
-                                                       /*packedSize=*/true);
+  auto tmaShape =
+      triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
+                                           /*packedSize=*/true, mode);
   if (shared.getSwizzlingByteWidth() == 0) {
     auto outDimNames = standardOutDimNames(ctx, rank);
     LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset,
@@ -1186,7 +1187,8 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
   } else if (auto shared = dyn_cast<SharedLinearEncodingAttr>(layout)) {
     result = shared.toLinearLayout(shape);
   } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
-    result = nvmmaSharedToLinearLayout(shape, shared);
+    // The shared memory layout is independent of TMA mode (Tiled vs Im2Col)
+    result = nvmmaSharedToLinearLayout(shape, shared, TMAMode::Tiled);
   } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
     result = sharedToLinearLayoutAMDRotating(shape, sbl);
   } else if (auto tensorMemoryEncoding =
```

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -198,7 +198,7 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
     if (failed(getTMABlockShape(blockShape, enc.getElementBitWidth(),
                                 enc.getSwizzlingByteWidth(), enc.getFp4Padded(),
                                 enc.getTransposed(), /*packedSize=*/false,
-                                emitError)))
+                                emitError, TMAMode::Tiled)))
       return failure();
   } else if (auto enc = dyn_cast<SharedLinearEncodingAttr>(encoding)) {
     auto blockShape = ArrayRef(allocShape).take_back(enc.getRank());
```

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 10 additions & 8 deletions
```diff
@@ -106,8 +106,9 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
   return updateEncodingForShape(op, sharedEnc, tensorType);
 }
 
-FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
+FailureOr<int> getTMASwizzleMode(Location loc, tt::TensorDescInterface ty) {
+  auto blockType = ty.getBlockType();
+  auto encoding = blockType.getEncoding();
   auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
   unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
   if (!mmaEncoding) {
@@ -160,15 +161,15 @@ enum TMA_ELEMENT_TYPES {
   TMA_B6P2X16 = 15,
 };
 
-FailureOr<int> getTMAElementType(Location loc, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
+FailureOr<int> getTMAElementType(Location loc, tt::TensorDescInterface ty) {
+  auto blockType = ty.getBlockType();
+  auto encoding = blockType.getEncoding();
   bool fp4Padded = isFp4Padded(encoding);
 
   if (fp4Padded)
     return TMA_B4X16_P64;
 
-  auto elemTy = ty.getBlockType().getElementType();
+  auto elemTy = blockType.getElementType();
   if (elemTy.isBF16()) {
     return TMA_BF16;
   } else if (elemTy.isF16()) {
@@ -216,8 +217,9 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
 
   int paddingScale = fp4Padded ? 2 : 1;
   auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape());
-  auto blockShape =
-      getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
+  // MakeTensorDescOp creates tiled descriptors (not im2col)
+  auto blockShape = getTMABlockShape(encoding, shapePerCTA,
+                                     /*packedSize=*/false, gpu::TMAMode::Tiled);
   auto contigDimSize = blockShape.back();
 
   llvm::SmallVector<Value> boxDim;
```

python/src/ir.cc

Lines changed: 7 additions & 4 deletions
```diff
@@ -210,10 +210,11 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
   assert(kernelFunc);
 
   for (auto [i, arg] : llvm::enumerate(kernelFunc.getArguments())) {
-    auto descTy = dyn_cast<TensorDescType>(arg.getType());
+    auto descTy = dyn_cast<TensorDescInterface>(arg.getType());
    if (!descTy)
      continue;
 
+    bool isIm2Col = isa<ttng::TensorDescIm2ColType>(arg.getType());
     auto blockType = descTy.getBlockType();
     auto encoding = blockType.getEncoding();
 
@@ -224,14 +225,16 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
       auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy);
       if (failed(swizzle) || failed(elemType))
         throw py::type_error("invalid TMA descriptor type");
-      auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false);
+      auto tmaMode = isIm2Col ? ttg::TMAMode::Im2Col : ttg::TMAMode::Tiled;
+      auto blockSize =
+          ttng::getTMABlockShape(blockType, /*packedSize=*/false, tmaMode);
       metadata["swizzle"] = *swizzle;
-      metadata["elem_size"] =
-          descTy.getBlockType().getElementTypeBitWidth() / 8;
+      metadata["elem_size"] = blockType.getElementTypeBitWidth() / 8;
       metadata["elem_type"] = *elemType;
       metadata["block_size"] =
          std::vector<int>(blockSize.begin(), blockSize.end());
       metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded();
+      metadata["is_im2col"] = isIm2Col;
     } else {
       auto blockShape = blockType.getShape();
       metadata["block_size"] =
```
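On the Python side, the per-argument metadata for an im2col descriptor might then look roughly like this. Only the key names mirror the C++ above; every value here is invented for illustration, and the launcher logic is a hypothetical sketch, not code from this PR.

```python
# Hypothetical example of the descriptor metadata with the new "is_im2col" key.
metadata = {
    "swizzle": 3,             # swizzle mode enum value (illustrative)
    "elem_size": 2,           # bytes per element, e.g. fp16
    "elem_type": 6,           # TMA element-type enum value (illustrative)
    "block_size": [256, 64],  # [pixelsPerColumn, channelsPerPixel]
    "fp4_padded": False,
    "is_im2col": True,        # new flag distinguishing im2col descriptors
}

# A launcher could branch on the flag to pick the descriptor-creation path.
mode = "im2col" if metadata["is_im2col"] else "tiled"
print(f"creating {mode} TMA descriptor, block {metadata['block_size']}")
# creating im2col TMA descriptor, block [256, 64]
```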
