diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index 6369fb9994a5..db54dcd708cf 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -111,7 +111,8 @@ class DialectVerifyTensorLayoutInterface
 LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType,
                                     ShapedType resultType,
                                     ShapedType indicesType);
-LogicalResult verifyDescriptorLoadStoreOp(Operation *op, TensorDescType desc,
+LogicalResult verifyDescriptorLoadStoreOp(Operation *op,
+                                          TensorDescInterface desc,
                                           ShapedType tensor);
 
 } // namespace triton
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
index 03478c37dfed..d84a600173d4 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td
@@ -89,25 +89,4 @@ def TTG_TensorMemoryScalesEncodingAttr : AttrDef
 }
 
 
-def TTNG_TensorMode : I32EnumAttr<
-    "TensorMode", "TMA tensor mode",
-    [
-      I32EnumAttrCase<"TILED", 0, "tiled">,
-      I32EnumAttrCase<"IM2COL", 1, "im2col">
-    ]> {
-  let cppNamespace = "::mlir::triton::nvidia_gpu";
-
-  let description = [{
-    Enum attribute for TMA tensor mode.
-
-    TILED: Tiled mode for regular tensor memory access.
-    IM2COL: Im2col mode for convolution-friendly tensor memory access.
-
-    See:
-    - https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode
-    - https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode
-  }];
-}
-
-
 #endif
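With the enum gone, mode selection moves entirely into the type system: `!tt.tensordesc` keeps TILED semantics and the new `!ttng.tensordesc_im2col` carries IM2COL. The verifiers changed below only ever call two accessors on the shared interface. A minimal sketch of that contract, assuming `TensorDescInterface` is ODS-generated in namespace `mlir::triton` and exposes exactly the accessors used at the call sites in this patch:

```cpp
// Sketch only; the real interface is TableGen-generated and not shown in
// this diff. It records what the updated verifiers assume both descriptor
// types can answer.
#include "mlir/IR/BuiltinTypes.h"

namespace sketch {

struct TensorDescInterfaceContract {
  // Block shape plus layout encoding as written in the descriptor type,
  // e.g. tensor<64x128xf16, #nvmma_128>; used by verifyTMAEncoding and
  // verifyAsyncTMACoords.
  virtual mlir::RankedTensorType getBlockType() = 0;
  // Same shape with signed/unsigned integers folded to signless; used by
  // verifyDescriptorLoadStoreOp for element-type comparison.
  virtual mlir::RankedTensorType getSignlessBlockType() = 0;
  virtual ~TensorDescInterfaceContract() = default;
};

} // namespace sketch
```

Routing both descriptor types through one interface keeps the shared verifiers type-agnostic, which is what lets the `tensorMode` attribute disappear below.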
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index c691bade2b9f..653f27071637 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -300,10 +300,10 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
   tensor. The data copied depends on the global memory descriptor pointed to
   by `desc`.
 
-  The operation supports two tensor modes:
-  - TILED (default): Regular tiled tensor memory access
+  The tensor mode is determined by the descriptor type:
+  - tt.tensordesc: TILED mode - regular tiled tensor memory access
     - See: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode
-  - IM2COL: Im2col mode for convolution-friendly access patterns
+  - ttng.tensordesc_im2col: IM2COL mode for convolution-friendly access patterns
   - In IM2COL mode, 'coord' is the coordinates in the input tensor
     - For example, for a 4D tensor (NHWC), 'coord' is [batch_idx, channel_idx, h, w]
   - In IM2COL mode, additional `offsets` must be provided (uint16 values)
@@ -317,7 +317,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
   let hasVerifier = 1;
 
   let arguments = (ins
-    Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
+    Arg<TT_TensorDescInterface, "", [MemRead<GlobalMemory>]>:$desc,
     Variadic<I32>:$coord,
     Variadic<I16>:$offsets,
     Arg<TTG_MemDescType, "", [MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$barrier,
@@ -326,8 +326,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
     UnitAttr:$multicast,
     DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
-    DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
-    DefaultValuedAttr<TTNG_TensorMode, "triton::nvidia_gpu::TensorMode::TILED">:$tensorMode
+    DefaultValuedAttr<BoolAttr, "false">:$isVolatile
   );
 
   let builders = [
@@ -337,16 +336,15 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local",
       CArg<"bool", "false">:$multicast,
       CArg<"triton::CacheModifier", "triton::CacheModifier::NONE">:$cache,
       CArg<"triton::EvictionPolicy", "triton::EvictionPolicy::NORMAL">:$evict,
-      CArg<"bool", "false">:$isVolatile,
-      CArg<"triton::nvidia_gpu::TensorMode", "triton::nvidia_gpu::TensorMode::TILED">:$tensorMode), [{
+      CArg<"bool", "false">:$isVolatile), [{
       build($_builder, $_state, desc, coord, /*offsets=*/ValueRange{}, barrier,
-            result, pred, multicast, cache, evict, isVolatile, tensorMode);
+            result, pred, multicast, cache, evict, isVolatile);
     }]>
   ];
   let assemblyFormat = [{
     $desc `[` $coord `]` (`offsets` `=` `[` $offsets^ `]`)?
     $result `,` $barrier `,` $pred
-    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict | `tensorMode` `=` $tensorMode)
+    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
     attr-dict `:` qualified(type($desc)) `,` qualified(type($barrier)) `->` qualified(type($result))
   }];
 }
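Callers lose the trailing `tensorMode` builder argument. A hedged sketch of a post-patch C++ creation site (the builder, location, and SSA values are placeholders, and the include path is assumed; the argument order mirrors the shortened builder above):

```cpp
#include "mlir/IR/Builders.h"
// Assumed umbrella header for the TritonNvidiaGPU dialect.
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

using namespace mlir;

// Create the async copy without any mode argument: whether the transfer is
// tiled or im2col is carried entirely by the type of `desc`.
static void emitTmaCopy(OpBuilder &b, Location loc, Value desc,
                        ValueRange coords, Value barrier, Value result,
                        Value pred) {
  b.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
      loc, desc, coords, barrier, result, pred,
      /*multicast=*/false, triton::CacheModifier::NONE,
      triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);
}
```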
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 5f9df02006da..85e366c146e4 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -1445,7 +1445,8 @@ LogicalResult DescriptorScatterOp::verify() {
 }
 
 // -- DescriptorLoadOp --
-LogicalResult verifyDescriptorLoadStoreOp(Operation *op, TensorDescType desc,
+LogicalResult verifyDescriptorLoadStoreOp(Operation *op,
+                                          TensorDescInterface desc,
                                           ShapedType tensor) {
   RankedTensorType block = desc.getSignlessBlockType();
   if (block.getElementType() != tensor.getElementType()) {
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index afd83a78df01..1e7fcea3bc38 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -228,15 +228,13 @@ LogicalResult ClusterWaitOp::verify() {
 }
 
 // -- TMA operation verifiers --
-static LogicalResult verifyTMAEncoding(Operation *op,
-                                       TypedValue<TensorDescType> desc,
+static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc,
                                        Attribute enc) {
   auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(enc);
   if (!nvmma)
     return op->emitOpError("TMA descriptor must have NVMMA shared layout");
-  auto descTy = desc.getType();
   auto descEnc = dyn_cast_if_present<NVMMASharedEncodingAttr>(
-      descTy.getBlockType().getEncoding());
+      desc.getBlockType().getEncoding());
   // NOTE: Cannot do descEnc != enc as the encodings may differ in rank for
   // rank-reducing loads
   if (!descEnc || descEnc.getTransposed() != nvmma.getTransposed() ||
@@ -253,7 +251,7 @@
 }
 
 static LogicalResult verifyAsyncTMALoadOp(Operation *op,
-                                          TypedValue<TensorDescType> desc,
+                                          TensorDescInterface desc,
                                           TypedValue<MemDescType> barrier,
                                           MemDescType resultType) {
   if (failed(verifyBarrierType(op, barrier.getType())))
@@ -273,15 +271,20 @@ static LogicalResult verifyAsyncTMAStoreOp(Operation *op,
   // do not support fp4_padded operands.
   if (isFp4Padded(srcEnc))
     return op->emitOpError("does not support fp4_padded operands");
-  return verifyTMAEncoding(op, desc, srcEnc);
+  return verifyTMAEncoding(op, desc.getType(), srcEnc);
+}
+
+// Helper to determine if the descriptor type is for im2col mode
+static bool isIm2ColDescriptor(Type descType) {
+  return isa<TensorDescIm2ColType>(descType);
 }
 
 static LogicalResult verifyAsyncTMACoords(Operation *op, ValueRange coords,
-                                          TypedValue<TensorDescType> desc,
-                                          TensorMode tensorMode) {
-  unsigned blockRank = desc.getType().getBlockType().getRank();
+                                          TensorDescInterface desc,
+                                          bool isIm2Col) {
+  unsigned blockRank = desc.getBlockType().getRank();
 
-  if (tensorMode == TensorMode::IM2COL) {
+  if (isIm2Col) {
     // For IM2COL mode, coordinates are for the full tensor (3D-5D)
     // not the 2D block shape
     if (coords.size() < 3)
@@ -304,9 +307,9 @@
   return success();
 }
 
-static LogicalResult verifyTMAMode(Operation *op, TensorMode tensorMode,
+static LogicalResult verifyTMAMode(Operation *op, bool isIm2Col,
                                    ValueRange coords, ValueRange offsets) {
-  if (tensorMode == TensorMode::IM2COL) {
+  if (isIm2Col) {
     if (offsets.empty())
       return op->emitOpError("IM2COL mode requires offsets to be provided");
 
@@ -329,17 +332,19 @@
 
 // -- AsyncTMACopyGlobalToLocalOp --
 LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
-  if (failed(
-          verifyAsyncTMACoords(*this, getCoord(), getDesc(), getTensorMode())))
+  auto descType = getDesc().getType();
+  bool isIm2Col = isIm2ColDescriptor(descType);
+  auto descInterface = cast<TensorDescInterface>(descType);
+
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), descInterface, isIm2Col)))
     return failure();
   auto resultType = getResult().getType();
-  if (failed(
-          verifyDescriptorLoadStoreOp(*this, getDesc().getType(), resultType)))
+  if (failed(verifyDescriptorLoadStoreOp(*this, descType, resultType)))
     return failure();
-  if (failed(verifyAsyncTMALoadOp(*this, getDesc(), getBarrier(),
+  if (failed(verifyAsyncTMALoadOp(*this, descInterface, getBarrier(),
                                   getResult().getType())))
     return failure();
-  if (failed(verifyTMAMode(*this, getTensorMode(), getCoord(), getOffsets())))
+  if (failed(verifyTMAMode(*this, isIm2Col, getCoord(), getOffsets())))
    return failure();
   return success();
 }
@@ -347,8 +352,8 @@ LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
 // -- AsyncTMACopyLocalToGlobalOp --
 LogicalResult AsyncTMACopyLocalToGlobalOp::verify() {
   // Store ops only support TILED mode
-  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc(),
-                                  TensorMode::TILED)))
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(),
+                                  /*isIm2Col=*/false)))
     return failure();
   MemDescType srcType = getSrc().getType();
   if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType)))
@@ -359,8 +364,8 @@ LogicalResult AsyncTMACopyLocalToGlobalOp::verify() {
 // -- AsyncTMAReduceOp --
 LogicalResult AsyncTMAReduceOp::verify() {
   // Reduce ops only support TILED mode
-  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc(),
-                                  TensorMode::TILED)))
+  if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(),
+                                  /*isIm2Col=*/false)))
     return failure();
   MemDescType srcType = getSrc().getType();
   if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType)))
@@ -371,7 +376,8 @@ LogicalResult AsyncTMAReduceOp::verify() {
 // -- AsyncTMAGatherOp --
 LogicalResult AsyncTMAGatherOp::verify() {
   auto resultType = getResult().getType();
-  if (failed(verifyAsyncTMALoadOp(*this, getDesc(), getBarrier(), resultType)))
+  if (failed(verifyAsyncTMALoadOp(*this, getDesc().getType(), getBarrier(),
+                                  resultType)))
     return failure();
   // `tile::gather4` does not support fp4_padded operands.
   if (isFp4Padded(getResult().getType().getEncoding()))
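Stores, reductions, and gathers stay tiled-only, so a pass that may receive either descriptor type should filter im2col descriptors before building those ops rather than relying on the verifier to fire. A hedged sketch of such a guard (the rewrite-pattern plumbing is hypothetical, and `TensorDescIm2ColType` is the C++ type name inferred from `isIm2ColDescriptor` above):

```cpp
#include "mlir/IR/PatternMatch.h"
// Assumed header providing triton::nvidia_gpu::TensorDescIm2ColType.
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical pre-check in a rewrite pattern that emits TMA stores: only
// the tiled descriptor type may feed AsyncTMACopyLocalToGlobalOp.
static LogicalResult checkStorableDesc(PatternRewriter &rewriter,
                                       Operation *op, Value desc) {
  if (isa<triton::nvidia_gpu::TensorDescIm2ColType>(desc.getType()))
    return rewriter.notifyMatchFailure(
        op, "im2col descriptors are load-only; stores need !tt.tensordesc");
  return success();
}
```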
diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir
index ad8a4441cc2f..557740519d70 100644
--- a/test/TritonNvidiaGPU/invalid.mlir
+++ b/test/TritonNvidiaGPU/invalid.mlir
@@ -149,13 +149,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_missing_offsets(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_missing_offsets(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode requires offsets to be provided}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
@@ -167,14 +167,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_wrong_offset_count(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_wrong_offset_count(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %c1_i16 = arith.constant 1 : i16
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode with 4D coordinates requires 2 offsets, but got 1}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] offsets = [%c1_i16] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
@@ -205,13 +205,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @tma_im2col_2d_invalid(%arg0: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  tt.func public @tma_im2col_2d_invalid(%arg0: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
     // expected-error @below {{IM2COL mode requires at least 3D coordinates, but got 2D}}
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 }
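The offset-count rule behind these diagnostics: im2col coordinates address the full tensor (rank 3 to 5), the first two dimensions (batch and channel) take no offset, and every remaining spatial dimension takes exactly one uint16 offset, so the verifier expects `#offsets == #coords - 2`. A tiny compile-time restatement of that arithmetic:

```cpp
#include <cstddef>

// coordRank is the number of im2col coordinates (3..5).
constexpr std::size_t expectedIm2ColOffsets(std::size_t coordRank) {
  return coordRank - 2; // batch and channel contribute no offsets
}
static_assert(expectedIm2ColOffsets(3) == 1); // e.g. NWC: offset for w
static_assert(expectedIm2ColOffsets(4) == 2); // e.g. NHWC: offsets for h, w
static_assert(expectedIm2ColOffsets(5) == 3); // e.g. NDHWC: offsets for d, h, w
```

The positive 3D/4D/5D tests that follow exercise exactly these counts.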
diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir
index 4c496bbb7c53..e3c2f07ad6bc 100644
--- a/test/TritonNvidiaGPU/ops.mlir
+++ b/test/TritonNvidiaGPU/ops.mlir
@@ -105,21 +105,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @tma_load_im2col_3d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_3d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_3d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
     %off = arith.constant 1 : i16
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0] offsets = [%off] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_im2col_4d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_4d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_4d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
     %off1 = arith.constant 1 : i16
@@ -127,13 +127,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0] offsets = [%off1, %off2] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_im2col_5d
-  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}, {{.*}}] {{.*}} tensorMode = im2col
-  tt.func public @tma_load_im2col_5d(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
+  // CHECK: ttng.async_tma_copy_global_to_local {{.*}} offsets = [{{.*}}, {{.*}}, {{.*}}] {{.*}} : !ttng.tensordesc_im2col
+  tt.func public @tma_load_im2col_5d(%desc: !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32
     %off1 = arith.constant 1 : i16
@@ -142,14 +142,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared3, #smem, mutable>
     ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared3, #smem, mutable>
-    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true tensorMode = im2col : !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
+    ttng.async_tma_copy_global_to_local %desc[%c0, %c0, %c0, %c0, %c0] offsets = [%off1, %off2, %off3] %buf, %bar, %true : !ttng.tensordesc_im2col<tensor<64x128xf16, #nvmma_128>>, !ttg.memdesc<1xi64, #shared3, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_128, #smem, mutable>
     tt.return
   }
 
   // CHECK-LABEL: @tma_load_tiled_mode
   // CHECK: ttng.async_tma_copy_global_to_local {{.*}}[{{.*}}, {{.*}}] %{{.*}}, %{{.*}}, {{.*}} : !tt.tensordesc
   // CHECK-NOT: offsets
-  // CHECK-NOT: tensorMode
   tt.func public @tma_load_tiled_mode(%desc: !tt.tensordesc<tensor<64x128xf16, #nvmma_128>>) {
     %true = arith.constant true
     %c0 = arith.constant 0 : i32