Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
30fe42f
add getTMABlockShape for im2col mode
bingyizh233 Jan 27, 2026
8836ee6
Fix and add lit test
bingyizh233 Jan 27, 2026
ecc4977
Save changes
bingyizh233 Jan 28, 2026
f55b15d
Add more constraint and error check
bingyizh233 Jan 28, 2026
fd7680b
Add lit test
bingyizh233 Jan 28, 2026
1649e9a
Clean up
bingyizh233 Jan 28, 2026
91ebdb1
Clean up
bingyizh233 Jan 30, 2026
acb5dda
Clean up
bingyizh233 Jan 30, 2026
4f3edab
Clean up
bingyizh233 Jan 30, 2026
57e69d8
Clean up
bingyizh233 Jan 30, 2026
9a36eba
Clean up
bingyizh233 Jan 30, 2026
0225bd0
Add TMA msg order test
bingyizh233 Jan 31, 2026
51314a9
Add test for TMA message ordering consistency between im2col mode and…
bingyizh233 Jan 31, 2026
1ac7484
Add TMA layout test for checking the compatibility between im2col mod…
bingyizh233 Jan 31, 2026
30393c3
Remove default tiled mode
bingyizh233 Jan 31, 2026
e7cd39c
Merge remote-tracking branch 'origin/TMA-im2col-lowering' into TMA-im…
bingyizh233 Jan 31, 2026
5a33701
Update layout convert test
bingyizh233 Feb 2, 2026
a978df1
Clean up
bingyizh233 Feb 2, 2026
1ae042e
Support im2col mode for getTensorDescMetadata
bingyizh233 Feb 2, 2026
db9a801
Clean up
bingyizh233 Feb 2, 2026
d0440fe
Clean up
bingyizh233 Feb 2, 2026
c27e4d0
Clean up
bingyizh233 Feb 2, 2026
408f74b
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
67198a8
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
9de58ec
Merge branch 'main' into TMA-im2col-lowering
bingyizh233 Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "llvm/Support/Casting.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace mlir::triton::nvidia_gpu {

Expand Down Expand Up @@ -51,7 +51,9 @@ inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
}

// Swizzle-mode lookup, overloaded for the two tensor descriptor types
// (TensorDescType and TensorDescIm2ColType). The result is wrapped in
// FailureOr, so callers must check for failure before reading the int.
FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty);
FailureOr<int> getTMASwizzleMode(Location loc, TensorDescIm2ColType ty);
// Element-type lookup, overloaded for the same two descriptor types; also
// reports failure through FailureOr.
FailureOr<int> getTMAElementType(Location loc, TensorDescType ty);
FailureOr<int> getTMAElementType(Location loc, TensorDescIm2ColType ty);

LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
OpBuilder &builder);
Expand Down
29 changes: 23 additions & 6 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
return updateEncodingForShape(op, sharedEnc, tensorType);
}

FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty) {
auto encoding = ty.getBlockType().getEncoding();
static FailureOr<int> getTMASwizzleModeImpl(Location loc,
RankedTensorType blockType) {
auto encoding = blockType.getEncoding();
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
if (!mmaEncoding) {
Expand Down Expand Up @@ -140,6 +141,14 @@ FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty) {
return swizzleMode;
}

// TensorDescType overload: pulls the block type out of the descriptor and
// defers to the shared getTMASwizzleModeImpl helper.
FailureOr<int> getTMASwizzleMode(Location loc, TensorDescType ty) {
  auto descBlockType = ty.getBlockType();
  return getTMASwizzleModeImpl(loc, descBlockType);
}

// TensorDescIm2ColType overload: pulls the block type out of the im2col
// descriptor and defers to the shared getTMASwizzleModeImpl helper.
FailureOr<int> getTMASwizzleMode(Location loc, TensorDescIm2ColType ty) {
  auto descBlockType = ty.getBlockType();
  return getTMASwizzleModeImpl(loc, descBlockType);
}

enum TMA_ELEMENT_TYPES {
TMA_U8 = 0,
TMA_U16 = 1,
Expand All @@ -160,15 +169,15 @@ enum TMA_ELEMENT_TYPES {
TMA_B6P2X16 = 15,
};

FailureOr<int> getTMAElementType(Location loc, TensorDescType ty) {
auto encoding = ty.getBlockType().getEncoding();
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
static FailureOr<int> getTMAElementTypeImpl(Location loc,
RankedTensorType blockType) {
auto encoding = blockType.getEncoding();
bool fp4Padded = isFp4Padded(encoding);

if (fp4Padded)
return TMA_B4X16_P64;

auto elemTy = ty.getBlockType().getElementType();
auto elemTy = blockType.getElementType();
if (elemTy.isBF16()) {
return TMA_BF16;
} else if (elemTy.isF16()) {
Expand Down Expand Up @@ -197,6 +206,14 @@ FailureOr<int> getTMAElementType(Location loc, TensorDescType ty) {
<< elemSize;
}

// TensorDescType overload: pulls the block type out of the descriptor and
// defers to the shared getTMAElementTypeImpl helper.
FailureOr<int> getTMAElementType(Location loc, TensorDescType ty) {
  auto descBlockType = ty.getBlockType();
  return getTMAElementTypeImpl(loc, descBlockType);
}

// TensorDescIm2ColType overload: pulls the block type out of the im2col
// descriptor and defers to the shared getTMAElementTypeImpl helper.
FailureOr<int> getTMAElementType(Location loc, TensorDescIm2ColType ty) {
  auto descBlockType = ty.getBlockType();
  return getTMAElementTypeImpl(loc, descBlockType);
}

LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
OpBuilder &builder) {
using namespace mlir;
Expand Down
30 changes: 20 additions & 10 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,30 +210,40 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
assert(kernelFunc);

for (auto [i, arg] : llvm::enumerate(kernelFunc.getArguments())) {
auto descTy = dyn_cast<TensorDescType>(arg.getType());
if (!descTy)
auto tiledDescTy = dyn_cast<TensorDescType>(arg.getType());
auto im2colDescTy = dyn_cast<ttng::TensorDescIm2ColType>(arg.getType());
Comment thread
peterbell10 marked this conversation as resolved.
Outdated
if (!tiledDescTy && !im2colDescTy)
continue;

auto blockType = descTy.getBlockType();
auto blockType =
tiledDescTy ? tiledDescTy.getBlockType() : im2colDescTy.getBlockType();
auto encoding = blockType.getEncoding();

py::dict metadata;
if (isa<ttg::NVMMASharedEncodingAttr>(encoding)) {
auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
auto swizzle = ttng::getTMASwizzleMode(arg.getLoc(), descTy);
auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy);
FailureOr<int> swizzle, elemType;
ttg::TMAMode tmaMode;
if (tiledDescTy) {
swizzle = ttng::getTMASwizzleMode(arg.getLoc(), tiledDescTy);
elemType = ttng::getTMAElementType(arg.getLoc(), tiledDescTy);
tmaMode = ttg::TMAMode::Tiled;
} else {
swizzle = ttng::getTMASwizzleMode(arg.getLoc(), im2colDescTy);
elemType = ttng::getTMAElementType(arg.getLoc(), im2colDescTy);
tmaMode = ttg::TMAMode::Im2Col;
}
if (failed(swizzle) || failed(elemType))
throw py::type_error("invalid TMA descriptor type");
// TensorDescType is for tiled mode (not im2col)
auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false,
ttg::TMAMode::Tiled);
auto blockSize =
ttng::getTMABlockShape(blockType, /*packedSize=*/false, tmaMode);
metadata["swizzle"] = *swizzle;
metadata["elem_size"] =
descTy.getBlockType().getElementTypeBitWidth() / 8;
metadata["elem_size"] = blockType.getElementTypeBitWidth() / 8;
metadata["elem_type"] = *elemType;
metadata["block_size"] =
std::vector<int>(blockSize.begin(), blockSize.end());
metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded();
metadata["is_im2col"] = im2colDescTy != nullptr;
} else {
auto blockShape = blockType.getShape();
metadata["block_size"] =
Expand Down