@@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
293293
294294// -----
295295
// Test section for @multicast_gather_two_cta_tx_count.
// The module is configured with num-ctas = 2 and two-ctas = true. Inside a loop that runs from
// 0 to 2 with step 1, the function issues a barrier_expect of 4096 followed by an
// async_tma_gather carrying the {multicast} attribute into a 32x32xf32 buffer
// (32 * 32 * 4 bytes = 4096 bytes).
// The FileCheck directives in the loop expect an i64 constant 8192 ahead of the
// verify_barrier_arrive call that precedes barrier_expect, and an i64 constant -8192 ahead of
// the verify_barrier_arrive call that precedes the gather.
// NOTE(review): 8192 equals 2 x 4096, presumably the expected byte count scaled by the two
// CTAs -- confirm against the instrumentation pass.
296+ #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 , CGALayout = [[0 , 0 ]]}>
297+ #shared1 = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ], CGALayout = [[0 ]]}>
298+ #smem = #ttg.shared_memory
299+ #offset_parent = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 1 ], order = [1 , 0 ], CGALayout = [[0 , 0 ]]}>
300+ #offsets = #ttg.slice <{dim = 0 , parent = #offset_parent }>
301+ module attributes {" ttg.num-ctas" = 2 : i32 , " ttg.num-warps" = 1 : i32 , " ttng.two-ctas" = true , ttg.shared = 65544 : i32 , ttg.target = " cuda:100" , ttg.tensor_memory_size = 0 : i32 , " ttg.threads-per-warp" = 32 : i32 , " ttg.total-num-warps" = 1 : i32 } {
302+ // CHECK-LABEL: @multicast_gather_two_cta_tx_count
303+ tt.func public @multicast_gather_two_cta_tx_count (%desc: !tt.tensordesc <1 x32 xf32 , #shared >) {
304+ %true = arith.constant true
305+ %c0_i32 = arith.constant 0 : i32
306+ %c0 = arith.constant 0 : index
307+ %c1 = arith.constant 1 : index
308+ %c2 = arith.constant 2 : index
309+ %x_offsets = arith.constant dense <0 > : tensor <32 xi32 , #offsets >
// The 1xi64 barrier is allocated at offset 65536 and the 32x32xf32 result buffer at offset 0;
// the barrier is then initialised with a count of 1.
310+ %bar = ttg.local_alloc {allocation.offset = 65536 : i32 } : () -> !ttg.memdesc <1 xi64 , #shared1 , #smem , mutable >
311+ %result = ttg.local_alloc {allocation.offset = 0 : i32 } : () -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
312+ ttng.init_barrier %bar , 1 : !ttg.memdesc <1 xi64 , #shared1 , #smem , mutable >
313+ // CHECK: scf.for
314+ scf.for %i = %c0 to %c2 step %c1 {
315+ // CHECK: arith.constant 8192 : i64
316+ // CHECK: tt.call @__triton_consan_verify_barrier_arrive
317+ // CHECK: ttng.barrier_expect
318+ ttng.barrier_expect %bar , 4096 , %true : !ttg.memdesc <1 xi64 , #shared1 , #smem , mutable >
319+ // CHECK: arith.constant -8192 : i64
320+ // CHECK: tt.call @__triton_consan_verify_barrier_arrive
321+ // CHECK: ttng.async_tma_gather
322+ ttng.async_tma_gather %desc [%x_offsets , %c0_i32 ] %result , %bar , %true {multicast } : !tt.tensordesc <1 x32 xf32 , #shared >, tensor <32 xi32 , #offsets >, i32 , !ttg.memdesc <1 xi64 , #shared1 , #smem , mutable >, !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >, i1
323+ }
324+ tt.return
325+ }
326+ }
327+
328+ // -----
329+
296330#shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
297331#shared1 = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ]}>
298332#smem = #ttg.shared_memory
@@ -479,9 +513,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
479513 ttng.warp_group_dot %shmem , %shmem , %acc : !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable > * !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable > -> tensor <128 x128 xf16 , #mma >
480514 // CHECK: ttng.warp_group_dot
481515
516+ // CHECK: tt.call @__triton_consan_verify_write_visibility
517+ // CHECK: tt.call @__triton_consan_check_outstanding_commits
518+ // CHECK: tt.call @__triton_consan_stage_access_for_commit
519+ // CHECK: tt.call @__triton_consan_commit_accesses
520+ ttng.async_tma_scatter %arg0 [%x_offsets , %c0_i32 ] %0 : !tt.tensordesc <1 x32 xf32 , #shared >, tensor <32 xi32 >, i32 , !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
521+ tt.return
522+ }
523+ }
524+
525+ // -----
526+
// Test section for @async_tma_reduce.
// The function allocates a 32x32xf32 buffer at allocation offset 0 and a 128x128xf16 buffer at
// offset 4096, starts ttg.async_copy_global_to_local into the latter, then issues
// ttng.async_tma_reduce with kind `add` on descriptor %arg0 and buffer %0.
// The FileCheck directives ahead of the reduce expect, in order, calls to
// __triton_consan_verify_write_visibility, __triton_consan_check_outstanding_commits,
// __triton_consan_stage_access_for_commit and __triton_consan_commit_accesses.
// NOTE(review): %acc is not referenced anywhere in the visible function body -- it looks carried
// over from a sibling test's signature; confirm whether it is needed.
527+ #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
528+ #smem = #ttg.shared_memory
529+ #blocked = #ttg.blocked <{sizePerThread = [1 , 128 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
530+ #mma = #ttg.nvidia_mma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 32 , 16 ]}>
531+
532+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 65536 : i32 , ttg.target = " cuda:90" , ttg.tensor_memory_size = 0 : i32 , " ttg.threads-per-warp" = 32 : i32 , " ttg.total-num-warps" = 1 : i32 } {
533+ // CHECK-LABEL: @async_tma_reduce
534+ tt.func public @async_tma_reduce (%arg0: !tt.tensordesc <32 x32 xf32 , #shared >, %ptr: tensor <128 x128 x!tt.ptr <f16 >, #blocked >, %acc: tensor <128 x128 xf16 , #mma >) {
535+ %c0_i32 = arith.constant 0 : i32
536+ %0 = ttg.local_alloc {allocation.offset = 0 : i32 } : () -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
537+ %shmem = ttg.local_alloc {allocation.offset = 4096 : i32 } : () -> !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable >
538+ ttg.async_copy_global_to_local %ptr , %shmem : tensor <128 x128 x!tt.ptr <f16 >, #blocked > -> <128 x128 xf16 , #shared , #smem , mutable >
539+
482540 // CHECK: tt.call @__triton_consan_verify_write_visibility
483541 // CHECK: tt.call @__triton_consan_check_outstanding_commits
484- ttng.async_tma_scatter %arg0 [%x_offsets , %c0_i32 ] %0 : !tt.tensordesc <1 x32 xf32 , #shared >, tensor <32 xi32 >, i32 , !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
542+ // CHECK: tt.call @__triton_consan_stage_access_for_commit
543+ // CHECK: tt.call @__triton_consan_commit_accesses
544+ ttng.async_tma_reduce add , %arg0 [%c0_i32 , %c0_i32 ] %0 : !tt.tensordesc <32 x32 xf32 , #shared >, !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
485545 tt.return
486546 }
487547}
0 commit comments