Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
30fe42f
add getTMABlockShape for im2col mode
bingyizh233 Jan 27, 2026
8836ee6
Fix and add lit test
bingyizh233 Jan 27, 2026
ecc4977
Save changes
bingyizh233 Jan 28, 2026
f55b15d
Add more constraint and error check
bingyizh233 Jan 28, 2026
fd7680b
Add lit test
bingyizh233 Jan 28, 2026
1649e9a
Clean up
bingyizh233 Jan 28, 2026
91ebdb1
Clean up
bingyizh233 Jan 30, 2026
acb5dda
Clean up
bingyizh233 Jan 30, 2026
4f3edab
Clean up
bingyizh233 Jan 30, 2026
57e69d8
Clean up
bingyizh233 Jan 30, 2026
9a36eba
Clean up
bingyizh233 Jan 30, 2026
0225bd0
Add TMA msg order test
bingyizh233 Jan 31, 2026
51314a9
Add test for TMA message ordering consistency between im2col mode and…
bingyizh233 Jan 31, 2026
1ac7484
Add TMA layout test for checking the compatibility between im2col mod…
bingyizh233 Jan 31, 2026
30393c3
Remove default tiled mode
bingyizh233 Jan 31, 2026
e7cd39c
Merge remote-tracking branch 'origin/TMA-im2col-lowering' into TMA-im…
bingyizh233 Jan 31, 2026
5a33701
Update layout convert test
bingyizh233 Feb 2, 2026
a978df1
Clean up
bingyizh233 Feb 2, 2026
1ae042e
Support im2col mode for getTensorDescMetadata
bingyizh233 Feb 2, 2026
db9a801
Clean up
bingyizh233 Feb 2, 2026
d0440fe
Clean up
bingyizh233 Feb 2, 2026
c27e4d0
Clean up
bingyizh233 Feb 2, 2026
408f74b
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
67198a8
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
9de58ec
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,21 @@ bool isInnermostContiguous(MemDescType type, unsigned numElems);
LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
ArrayRef<int64_t> dstShape);

// TMA tensor access modes. Selects how getTMABlockShape interprets the
// per-CTA shape when computing the TMA descriptor's block dimensions.
enum class TMAMode {
Tiled, // Regular tiled tensor memory access
Im2Col // Im2col mode for convolution-friendly access patterns
};

FailureOr<SmallVector<int64_t>>
getTMABlockShape(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
int swizzleBytes, bool fp4Padded, bool isTransposed,
bool packedSize, function_ref<InFlightDiagnostic()> emitError);
bool packedSize, function_ref<InFlightDiagnostic()> emitError,
TMAMode mode);
SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
int elementBitWidth, int swizzleBytes,
bool fp4Padded, bool isTransposed,
bool packedSize);
bool packedSize, TMAMode mode);

// Verify the types of operations that operate on memory.
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class NVMMASharedEncodingAttr;
class TensorOrMemDesc;
class MemDescType;
class CGAEncodingAttr;
enum class TMAMode;

// - BlockedEncodingAttrs have the following input dimensions.
//
Expand Down Expand Up @@ -61,6 +62,7 @@ LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
// swizzling.
LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
TMAMode mode,
bool disableSwizzle = false);

// Given a linear layout where the input dimensions contain a "block" dimension,
Expand Down
22 changes: 12 additions & 10 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "llvm/Support/Casting.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace mlir::triton::nvidia_gpu {

Expand All @@ -29,27 +29,29 @@ getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,

inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
ArrayRef<int64_t> shapePerCTA,
bool packedSize) {
bool packedSize,
gpu::TMAMode mode) {
auto mmaEnc = cast<gpu::NVMMASharedEncodingAttr>(encoding);
return triton::gpu::getTMABlockShape(
shapePerCTA, mmaEnc.getElementBitWidth(), mmaEnc.getSwizzlingByteWidth(),
mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize);
mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize, mode);
}

inline SmallVector<int64_t> getTMABlockShape(RankedTensorType ty,
bool packedSize) {
inline SmallVector<int64_t>
getTMABlockShape(RankedTensorType ty, bool packedSize, gpu::TMAMode mode) {
auto shapePerCTA = gpu::getShapePerCTA(ty);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
}

inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
bool packedSize) {
bool packedSize,
gpu::TMAMode mode) {
auto shapePerCTA = gpu::getShapePerCTA(ty);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize, mode);
}

FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty);
FailureOr<int> getTMAElementType(Location loc, TensorDescType ty);
FailureOr<int> getTMASwizzleMode(Location loc, triton::TensorDescInterface ty);
FailureOr<int> getTMAElementType(Location loc, triton::TensorDescInterface ty);

LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
OpBuilder &builder);
Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([ctx](TensorDescType type) -> std::optional<Type> {
return LLVM::LLVMPointerType::get(ctx, 0);
});
addConversion(
[ctx](nvidia_gpu::TensorDescIm2ColType type) -> std::optional<Type> {
return LLVM::LLVMPointerType::get(ctx, 0);
});
addConversion([&](RankedTensorType type) -> std::optional<Type> {
return convertTritonTensorType(type, targetInfo);
});
Expand Down
102 changes: 90 additions & 12 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4079,14 +4079,76 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
return dst;
}

FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
bool fp4Padded, bool isTransposed, bool packedSize,
function_ref<InFlightDiagnostic()> emitError) {
// Helper function for im2col mode block shape calculation.
// Im2col mode produces a 2D block: [pixelsPerColumn, channelsPerPixel]
// Constraints:
// - channelsPerPixel (contigDim): max 256, or swizzle byte size if enabled
// - pixelsPerColumn (otherDim): max 1024, no splitting (single TMA message)
Comment thread
peterbell10 marked this conversation as resolved.
// Doc:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
static FailureOr<SmallVector<int64_t>>
getTMABlockShapeIm2Col(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
int swizzleBytes, bool fp4Padded, bool isTransposed,
bool packedSize,
function_ref<InFlightDiagnostic()> emitError) {
assert(shapePerCTA.size() == 2 && "im2col mode requires a 2D block shape");

SmallVector<int64_t> blockShape(shapePerCTA);
int contigDim = isTransposed ? 0 : blockShape.size() - 1;
if (fp4Padded)
blockShape[contigDim] *= 2;

constexpr int64_t contigDimMax = 256;
constexpr int64_t otherDimMax = 1024;
int otherDim = (contigDim == 0) ? 1 : 0;

// Check that pixelsPerColumn doesn't exceed the hardware maximum of 1024.
// This constraint ensures a single TMA message can cover all pixels,
// avoiding the need for multiple messages along spatial dimensions (N, D,
// H, W). Supporting pixelsPerColumn > 1024 would require computing offsets
// that depend on input tensor shape and padding, which is non-trivial.
if (blockShape[otherDim] > otherDimMax) {
return emitError() << "im2col mode: pixelsPerColumn dimension "
<< blockShape[otherDim]
<< " exceeds the maximum supported value of "
<< otherDimMax;
}

// Clamp the contiguous dimension (channelsPerPixel) to max 256
blockShape[contigDim] = std::min(blockShape[contigDim], contigDimMax);

// Contiguous dim must equal the swizzle byte size if swizzle is enabled
if (swizzleBytes != 0) {
auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
if (blockShape[contigDim] < contigDimSize) {
return emitError() << "im2col mode: block shape along the contiguous "
"dimension "
<< contigDim
<< " is too small for the swizzle byte size "
<< swizzleBytes << ", got " << blockShape[contigDim]
<< " but expected at least " << contigDimSize;
}
blockShape[contigDim] = contigDimSize;
}

if (fp4Padded && packedSize) {
blockShape[contigDim] /= 2;
}
return blockShape;
}

// Tiled mode block shape calculation.
static FailureOr<SmallVector<int64_t>>
getTMABlockShapeTiled(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
int swizzleBytes, bool fp4Padded, bool isTransposed,
bool packedSize,
function_ref<InFlightDiagnostic()> emitError) {
SmallVector<int64_t> blockShape(shapePerCTA);

int contigDim = isTransposed ? 0 : blockShape.size() - 1;
if (fp4Padded)
blockShape[contigDim] *= 2;

// All dimensions must be at most 256
constexpr int64_t dimMax = 256;
for (auto &size : blockShape)
Expand All @@ -4109,16 +4171,32 @@ FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
}
return blockShape;
}

// Compute the TMA block shape for the requested access mode.
// Both mode-specific helpers share an identical signature, so simply select
// the implementation matching `mode` and forward every argument to it.
FailureOr<SmallVector<int64_t>> triton::gpu::getTMABlockShape(
    ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
    bool fp4Padded, bool isTransposed, bool packedSize,
    function_ref<InFlightDiagnostic()> emitError, TMAMode mode) {
  auto *impl = (mode == TMAMode::Im2Col) ? &getTMABlockShapeIm2Col
                                         : &getTMABlockShapeTiled;
  return (*impl)(shapePerCTA, elementBitWidth, swizzleBytes, fp4Padded,
                 isTransposed, packedSize, emitError);
}

SmallVector<int64_t> triton::gpu::getTMABlockShape(
ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
bool fp4Padded, bool isTransposed, bool packedSize) {
return *getTMABlockShape(
shapePerCTA, elementBitWidth, swizzleBytes, fp4Padded, isTransposed,
packedSize, []() -> InFlightDiagnostic {
llvm::report_fatal_error(
"Block shape is too small for the swizzle byte "
"size in NVMMA Shared Layout.");
});
bool fp4Padded, bool isTransposed, bool packedSize, TMAMode mode) {
auto emitFatalError = []() -> InFlightDiagnostic {
llvm::report_fatal_error("getTMABlockShape failed: invalid block shape "
"for TMA operation.");
};

return *getTMABlockShape(shapePerCTA, elementBitWidth, swizzleBytes,
fp4Padded, isTransposed, packedSize, emitFatalError,
mode);
}

SetVector<int> triton::gpu::getPartitionIds(Operation *op) {
Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,

LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
bool disableSwizzle) {
TMAMode mode, bool disableSwizzle) {
MLIRContext *ctx = shared.getContext();
int rank = shape.size();
auto shapePerCTA = getShapePerCTA(shared, shape);
auto kOffset = S("offset");
auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
/*packedSize=*/true);
auto tmaShape =
triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
/*packedSize=*/true, mode);
if (shared.getSwizzlingByteWidth() == 0) {
auto outDimNames = standardOutDimNames(ctx, rank);
LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset,
Expand Down Expand Up @@ -1186,7 +1187,8 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
} else if (auto shared = dyn_cast<SharedLinearEncodingAttr>(layout)) {
result = shared.toLinearLayout(shape);
} else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
result = nvmmaSharedToLinearLayout(shape, shared);
// The shared memory layout is independent of TMA mode (Tiled vs Im2Col)
result = nvmmaSharedToLinearLayout(shape, shared, TMAMode::Tiled);
} else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
result = sharedToLinearLayoutAMDRotating(shape, sbl);
} else if (auto tensorMemoryEncoding =
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
if (failed(getTMABlockShape(blockShape, enc.getElementBitWidth(),
enc.getSwizzlingByteWidth(), enc.getFp4Padded(),
enc.getTransposed(), /*packedSize=*/false,
emitError)))
emitError, TMAMode::Tiled)))
return failure();
} else if (auto enc = dyn_cast<SharedLinearEncodingAttr>(encoding)) {
auto blockShape = ArrayRef(allocShape).take_back(enc.getRank());
Expand Down
18 changes: 10 additions & 8 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
return updateEncodingForShape(op, sharedEnc, tensorType);
}

FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty) {
auto encoding = ty.getBlockType().getEncoding();
FailureOr<int> getTMASwizzleMode(Location loc, tt::TensorDescInterface ty) {
auto blockType = ty.getBlockType();
auto encoding = blockType.getEncoding();
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
if (!mmaEncoding) {
Expand Down Expand Up @@ -160,15 +161,15 @@ enum TMA_ELEMENT_TYPES {
TMA_B6P2X16 = 15,
};

FailureOr<int> getTMAElementType(Location loc, TensorDescType ty) {
auto encoding = ty.getBlockType().getEncoding();
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
FailureOr<int> getTMAElementType(Location loc, tt::TensorDescInterface ty) {
auto blockType = ty.getBlockType();
auto encoding = blockType.getEncoding();
bool fp4Padded = isFp4Padded(encoding);

if (fp4Padded)
return TMA_B4X16_P64;

auto elemTy = ty.getBlockType().getElementType();
auto elemTy = blockType.getElementType();
if (elemTy.isBF16()) {
return TMA_BF16;
} else if (elemTy.isF16()) {
Expand Down Expand Up @@ -216,8 +217,9 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,

int paddingScale = fp4Padded ? 2 : 1;
auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape());
auto blockShape =
getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
// MakeTensorDescOp creates tiled descriptors (not im2col)
auto blockShape = getTMABlockShape(encoding, shapePerCTA,
/*packedSize=*/false, gpu::TMAMode::Tiled);
auto contigDimSize = blockShape.back();

llvm::SmallVector<Value> boxDim;
Expand Down
11 changes: 7 additions & 4 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
assert(kernelFunc);

for (auto [i, arg] : llvm::enumerate(kernelFunc.getArguments())) {
auto descTy = dyn_cast<TensorDescType>(arg.getType());
auto descTy = dyn_cast<TensorDescInterface>(arg.getType());
if (!descTy)
continue;

bool isIm2Col = isa<ttng::TensorDescIm2ColType>(arg.getType());
auto blockType = descTy.getBlockType();
auto encoding = blockType.getEncoding();

Expand All @@ -224,14 +225,16 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy);
if (failed(swizzle) || failed(elemType))
throw py::type_error("invalid TMA descriptor type");
auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false);
auto tmaMode = isIm2Col ? ttg::TMAMode::Im2Col : ttg::TMAMode::Tiled;
auto blockSize =
ttng::getTMABlockShape(blockType, /*packedSize=*/false, tmaMode);
metadata["swizzle"] = *swizzle;
metadata["elem_size"] =
descTy.getBlockType().getElementTypeBitWidth() / 8;
metadata["elem_size"] = blockType.getElementTypeBitWidth() / 8;
metadata["elem_type"] = *elemType;
metadata["block_size"] =
std::vector<int>(blockSize.begin(), blockSize.end());
metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded();
metadata["is_im2col"] = isIm2Col;
} else {
auto blockShape = blockType.getShape();
metadata["block_size"] =
Expand Down
Loading
Loading