3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/Dialect.h

@@ -111,7 +111,8 @@ class DialectVerifyTensorLayoutInterface
   LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType,
                                       ShapedType resultType,
                                       ShapedType indicesType);
-  LogicalResult verifyDescriptorLoadStoreOp(Operation *op, TensorDescType desc,
+  LogicalResult verifyDescriptorLoadStoreOp(Operation *op,
+                                            TensorDescInterface desc,
                                             ShapedType tensor);
 
 } // namespace triton
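The switch from the concrete `TensorDescType` to `TensorDescInterface` is what lets one verifier serve both descriptor flavors. Below is a minimal sketch (not part of the PR) of a call site that dispatches through the interface; it assumes, as the rest of this diff indicates, that both `!tt.tensordesc` and `!ttng.tensordesc_im2col` implement `TensorDescInterface` and that the interface exposes `getSignlessBlockType()`:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper (name not from the PR): verify any descriptor flavor
// against the accessed tensor through the interface, with no per-type branch.
static LogicalResult checkDescriptorElementType(Operation *op, Type descTy,
                                                ShapedType tensorTy) {
  auto desc = dyn_cast<triton::TensorDescInterface>(descTy);
  if (!desc)
    return op->emitOpError("expected a tensor descriptor type");
  RankedTensorType block = desc.getSignlessBlockType();
  if (block.getElementType() != tensorTy.getElementType())
    return op->emitOpError("descriptor and tensor element types must match");
  return success();
}
```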
@@ -89,25 +89,4 @@ def TTG_TensorMemoryScalesEncodingAttr : AttrDef<TritonNvidiaGPU_Dialect, "Tenso
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-
-def TTNG_TensorModeAttr : I32EnumAttr<
-    "TensorMode", "",
-    [
-      I32EnumAttrCase<"TILED", 0, "tiled">,
-      I32EnumAttrCase<"IM2COL", 1, "im2col">
-    ]> {
-  let cppNamespace = "::mlir::triton::nvidia_gpu";
-  let description = [{
-    Enum attribute for TMA tensor mode.
-
-    TILED: Tiled mode for regular tensor memory access.
-    IM2COL: Im2col mode for convolution-friendly tensor memory access.
-
-    See:
-    - https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode
-    - https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode
-  }];
-}
-
-
 #endif
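With the enum gone, the tiled/im2col distinction lives entirely in the type system. A one-function sketch of the replacement pattern, mirroring the `isIm2ColDescriptor` helper added later in this diff (the namespace is an assumption based on the `cppNamespace` above):

```cpp
#include "mlir/IR/Types.h"

// Sketch: code that previously branched on TensorMode::IM2COL can branch on
// the descriptor's type instead. TensorDescIm2ColType is the type this PR
// introduces for !ttng.tensordesc_im2col.
static bool usesIm2ColMode(mlir::Type descTy) {
  return mlir::isa<mlir::triton::nvidia_gpu::TensorDescIm2ColType>(descTy);
}
```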
18 changes: 8 additions & 10 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

@@ -300,10 +300,10 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
     tensor. The data copied depends on the global memory descriptor pointed to
     by `desc`.
 
-    The operation supports two tensor modes:
-    - TILED (default): Regular tiled tensor memory access
+    The tensor mode is determined by the descriptor type:
+    - tt.tensordesc: TILED mode - regular tiled tensor memory access
       - See: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode
-    - IM2COL: Im2col mode for convolution-friendly access patterns
+    - ttng.tensordesc_im2col: IM2COL mode - convolution-friendly im2col access patterns
       - In IM2COL mode, `coord` gives the coordinates in the full input tensor
       - For example, for a 4D tensor (NHWC), `coord` is [batch_idx, channel_idx, h, w]
       - In IM2COL mode, additional `offsets` must be provided (uint16 values)
@@ -317,7 +317,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
 
   let hasVerifier = 1;
   let arguments = (ins
-    Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
+    Arg<TT_AnyTensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
     Variadic<I32>:$coord,
     Variadic<I16>:$offsets,
     Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
@@ -326,8 +326,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
     UnitAttr:$multicast,
     DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
-    DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
-    DefaultValuedAttr<TTNG_TensorModeAttr, "triton::nvidia_gpu::TensorMode::TILED">:$tensorMode
+    DefaultValuedAttr<BoolAttr, "false">:$isVolatile
   );
 
   let builders = [
@@ -337,16 +336,15 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
       CArg<"bool", "false">:$multicast,
       CArg<"triton::CacheModifier", "triton::CacheModifier::NONE">:$cache,
       CArg<"triton::EvictionPolicy", "triton::EvictionPolicy::NORMAL">:$evict,
-      CArg<"bool", "false">:$isVolatile,
-      CArg<"triton::nvidia_gpu::TensorMode", "triton::nvidia_gpu::TensorMode::TILED">:$tensorMode), [{
+      CArg<"bool", "false">:$isVolatile), [{
       build($_builder, $_state, desc, coord, /*offsets=*/ValueRange{}, barrier,
-            result, pred, multicast, cache, evict, isVolatile, tensorMode);
+            result, pred, multicast, cache, evict, isVolatile);
     }]>
   ];
 
   let assemblyFormat = [{
     $desc `[` $coord `]` (`offsets` `=` `[` $offsets^ `]`)? $result `,` $barrier `,` $pred
-    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict | `tensorMode` `=` $tensorMode)
+    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
     attr-dict `:` qualified(type($desc)) `,` qualified(type($barrier)) `->` qualified(type($result))
   }];
 }
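Dropping `$tensorMode` from the arguments, builder, and assembly format means call sites select the mode purely through the type of `desc`. A hedged sketch of a post-change builder invocation (function and variable names are illustrative; trailing arguments take the `CArg` defaults declared above):

```cpp
#include "mlir/IR/Builders.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

// Illustrative only: no mode attribute is passed; whether the copy is tiled
// or im2col follows from `desc` being !tt.tensordesc or !ttng.tensordesc_im2col.
static void emitTmaCopy(mlir::OpBuilder &builder, mlir::Location loc,
                        mlir::Value desc, mlir::ValueRange coords,
                        mlir::Value barrier, mlir::Value result,
                        mlir::Value pred) {
  builder.create<mlir::triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
      loc, desc, coords, barrier, result, pred);
}
```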
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/IR/Ops.cpp

@@ -1445,7 +1445,8 @@ LogicalResult DescriptorScatterOp::verify() {
 }
 
 // -- DescriptorLoadOp --
-LogicalResult verifyDescriptorLoadStoreOp(Operation *op, TensorDescType desc,
+LogicalResult verifyDescriptorLoadStoreOp(Operation *op,
+                                          TensorDescInterface desc,
                                           ShapedType tensor) {
   RankedTensorType block = desc.getSignlessBlockType();
   if (block.getElementType() != tensor.getElementType()) {
52 changes: 29 additions & 23 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

@@ -228,15 +228,13 @@ LogicalResult ClusterWaitOp::verify() {
 }
 
 // -- TMA operation verifiers --
-static LogicalResult verifyTMAEncoding(Operation *op,
-                                       TypedValue<TensorDescType> desc,
+static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc,
                                        Attribute enc) {
   auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(enc);
   if (!nvmma)
     return op->emitOpError("TMA descriptor must have NVMMA shared layout");
-  auto descTy = desc.getType();
   auto descEnc = dyn_cast_if_present<NVMMASharedEncodingAttr>(
-      descTy.getBlockType().getEncoding());
+      desc.getBlockType().getEncoding());
   // NOTE: Cannot do descEnc != enc as the encodings may differ in rank for
   // rank-reducing loads
   if (!descEnc || descEnc.getTransposed() != nvmma.getTransposed() ||
@@ -253,7 +251,7 @@ static LogicalResult verifyTMAEncoding(Operation *op,
 }
 
 static LogicalResult verifyAsyncTMALoadOp(Operation *op,
-                                          TypedValue<TensorDescType> desc,
+                                          TensorDescInterface desc,
                                           TypedValue<MemDescType> barrier,
                                           MemDescType resultType) {
   if (failed(verifyBarrierType(op, barrier.getType())))
@@ -273,15 +271,20 @@ static LogicalResult verifyAsyncTMAStoreOp(Operation *op,
   // do not support fp4_padded operands.
   if (isFp4Padded(srcEnc))
     return op->emitOpError("does not support fp4_padded operands");
-  return verifyTMAEncoding(op, desc, srcEnc);
+  return verifyTMAEncoding(op, desc.getType(), srcEnc);
 }
 
+// Helper to determine if the descriptor type is for im2col mode
+static bool isIm2ColDescriptor(Type descType) {
+  return isa<TensorDescIm2ColType>(descType);
+}
+
 static LogicalResult verifyAsyncTMACoords(Operation *op, ValueRange coords,
-                                          TypedValue<TensorDescType> desc,
-                                          TensorMode tensorMode) {
-  unsigned blockRank = desc.getType().getBlockType().getRank();
+                                          TensorDescInterface desc,
+                                          bool isIm2Col) {
+  unsigned blockRank = desc.getBlockType().getRank();
 
-  if (tensorMode == TensorMode::IM2COL) {
+  if (isIm2Col) {
     // For IM2COL mode, coordinates are for the full tensor (3D-5D)
     // not the 2D block shape
     if (coords.size() < 3)
@@ -304,9 +307,9 @@ static LogicalResult verifyAsyncTMACoords(Operation *op, ValueRange coords,
   return success();
 }
 
-static LogicalResult verifyTMAMode(Operation *op, TensorMode tensorMode,
+static LogicalResult verifyTMAMode(Operation *op, bool isIm2Col,
                                    ValueRange coords, ValueRange offsets) {
-  if (tensorMode == TensorMode::IM2COL) {
+  if (isIm2Col) {
     if (offsets.empty())
       return op->emitOpError("IM2COL mode requires offsets to be provided");
 
@@ -329,26 +332,28 @@ static LogicalResult verifyTMAMode(Operation *op, TensorMode tensorMode,
 
 // -- AsyncTMACopyGlobalToLocalOp --
 LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
-  if (failed(
-          verifyAsyncTMACoords(*this, getCoord(), getDesc(), getTensorMode())))
+  auto descType = getDesc().getType();
+  bool isIm2Col = isIm2ColDescriptor(descType);
+  auto descInterface = cast<TensorDescInterface>(descType);
+
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), descInterface, isIm2Col)))
     return failure();
   auto resultType = getResult().getType();
-  if (failed(
-          verifyDescriptorLoadStoreOp(*this, getDesc().getType(), resultType)))
+  if (failed(verifyDescriptorLoadStoreOp(*this, descType, resultType)))
     return failure();
-  if (failed(verifyAsyncTMALoadOp(*this, getDesc(), getBarrier(),
+  if (failed(verifyAsyncTMALoadOp(*this, descInterface, getBarrier(),
                                   getResult().getType())))
     return failure();
-  if (failed(verifyTMAMode(*this, getTensorMode(), getCoord(), getOffsets())))
+  if (failed(verifyTMAMode(*this, isIm2Col, getCoord(), getOffsets())))
     return failure();
   return success();
 }
 
 // -- AsyncTMACopyLocalToGlobalOp --
 LogicalResult AsyncTMACopyLocalToGlobalOp::verify() {
   // Store ops only support TILED mode
-  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc(),
-                                  TensorMode::TILED)))
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(),
+                                  /*isIm2Col=*/false)))
     return failure();
   MemDescType srcType = getSrc().getType();
   if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType)))
@@ -359,8 +364,8 @@ LogicalResult AsyncTMACopyLocalToGlobalOp::verify() {
 // -- AsyncTMAReduceOp --
 LogicalResult AsyncTMAReduceOp::verify() {
   // Reduce ops only support TILED mode
-  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc(),
-                                  TensorMode::TILED)))
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(),
+                                  /*isIm2Col=*/false)))
     return failure();
   MemDescType srcType = getSrc().getType();
   if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType)))
@@ -371,7 +376,8 @@ LogicalResult AsyncTMAReduceOp::verify() {
 // -- AsyncTMAGatherOp --
 LogicalResult AsyncTMAGatherOp::verify() {
   auto resultType = getResult().getType();
-  if (failed(verifyAsyncTMALoadOp(*this, getDesc(), getBarrier(), resultType)))
+  if (failed(verifyAsyncTMALoadOp(*this, getDesc().getType(), getBarrier(),
+                                  resultType)))
     return failure();
   // `tile::gather4` does not support fp4_padded operands.
   if (isFp4Padded(getResult().getType().getEncoding()))
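The tail of `verifyTMAMode` is folded out of this diff view, but the expected-error strings in the tests below pin down the arity rule: N-D im2col coordinates take exactly N-2 offsets, with N at least 3. A sketch of that check under that assumption (helper name not from the PR):

```cpp
#include "mlir/IR/Operation.h"

using namespace mlir;

// Sketch of the offsets arity check implied by the test expectations below;
// the real body of verifyTMAMode is collapsed in this diff view.
static LogicalResult checkIm2ColOffsetCount(Operation *op, size_t numCoords,
                                            size_t numOffsets) {
  if (numCoords < 3)
    return op->emitOpError("IM2COL mode requires at least 3D coordinates, ")
           << "but got " << numCoords << "D";
  // One offset per spatial dimension: rank minus the batch and channel dims.
  if (numOffsets != numCoords - 2)
    return op->emitOpError("IM2COL mode with ")
           << numCoords << "D coordinates requires " << numCoords - 2
           << " offsets, but got " << numOffsets;
  return success();
}
```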
12 changes: 6 additions & 6 deletions test/TritonNvidiaGPU/invalid.mlir

@@ -149,13 +149,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_missing_offsets(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_missing_offsets(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode requires offsets to be provided}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
@@ -167,14 +167,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_wrong_offset_count(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_wrong_offset_count(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %c1_i16 = arith.constant 1 : i16
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode with 4D coordinates requires 2 offsets, but got 1}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
@@ -205,13 +205,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_2d_invalid(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_2d_invalid(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode requires at least 3D coordinates, but got 2D}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
19 changes: 9 additions & 10 deletions test/TritonNvidiaGPU/ops.mlir

@@ -105,35 +105,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @tma_load_im2col_3d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_3d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_3d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
     %off = arith.constant 1 : i16
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_im2col_4d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_4d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_4d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
    %off1 = arith.constant 1 : i16
     %off2 = arith.constant 2 : i16
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_im2col_5d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}, {{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_5d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_5d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
     %off1 = arith.constant 1 : i16
@@ -142,14 +142,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_tiled_mode
   // CHECK: ttng.async_tma_copy_global_to_local {{.*}}[{{.*}}, {{.*}}] %{{.*}}, %{{.*}}, {{.*}} : !tt.tensordesc
   // CHECK-NOT: offsets
-  // CHECK-NOT: tensorMode
   tt.func public @tma_load_tiled_mode(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32