@@ -479,9 +479,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
479479 ttng.warp_group_dot %shmem , %shmem , %acc : !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable > * !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable > -> tensor <128 x128 xf16 , #mma >
480480 // CHECK: ttng.warp_group_dot
481481
482+ // CHECK: tt.call @__triton_consan_verify_write_visibility
483+ // CHECK: tt.call @__triton_consan_check_outstanding_commits
484+ // CHECK: tt.call @__triton_consan_stage_access_for_commit
485+ // CHECK: tt.call @__triton_consan_commit_accesses
486+ ttng.async_tma_scatter %arg0 [%x_offsets , %c0_i32 ] %0 : !tt.tensordesc <1 x32 xf32 , #shared >, tensor <32 xi32 >, i32 , !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
487+ tt.return
488+ }
489+ }
490+
491+ // -----
492+
493+ #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
494+ #smem = #ttg.shared_memory
495+ #blocked = #ttg.blocked <{sizePerThread = [1 , 128 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
496+ #mma = #ttg.nvidia_mma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 32 , 16 ]}>
497+
498+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 65536 : i32 , ttg.target = " cuda:90" , ttg.tensor_memory_size = 0 : i32 , " ttg.threads-per-warp" = 32 : i32 , " ttg.total-num-warps" = 1 : i32 } {
499+ // CHECK-LABEL: @async_tma_reduce
500+ tt.func public @async_tma_reduce (%arg0: !tt.tensordesc <32 x32 xf32 , #shared >, %ptr: tensor <128 x128 x!tt.ptr <f16 >, #blocked >, %acc: tensor <128 x128 xf16 , #mma >) {
501+ %c0_i32 = arith.constant 0 : i32
502+ %0 = ttg.local_alloc {allocation.offset = 0 : i32 } : () -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
503+ %shmem = ttg.local_alloc {allocation.offset = 4096 : i32 } : () -> !ttg.memdesc <128 x128 xf16 , #shared , #smem , mutable >
504+ ttg.async_copy_global_to_local %ptr , %shmem : tensor <128 x128 x!tt.ptr <f16 >, #blocked > -> <128 x128 xf16 , #shared , #smem , mutable >
505+
482506 // CHECK: tt.call @__triton_consan_verify_write_visibility
483507 // CHECK: tt.call @__triton_consan_check_outstanding_commits
484- ttng.async_tma_scatter %arg0 [%x_offsets , %c0_i32 ] %0 : !tt.tensordesc <1 x32 xf32 , #shared >, tensor <32 xi32 >, i32 , !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
508+ // CHECK: tt.call @__triton_consan_stage_access_for_commit
509+ // CHECK: tt.call @__triton_consan_commit_accesses
510+ ttng.async_tma_reduce add , %arg0 [%c0_i32 , %c0_i32 ] %0 : !tt.tensordesc <32 x32 xf32 , #shared >, !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
485511 tt.return
486512 }
487513}
0 commit comments