@@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
293293
294294// -----
295295
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
#smem = #ttg.shared_memory
#offset_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 0]]}>
#offsets = #ttg.slice<{dim = 0, parent = #offset_parent}>
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, "ttng.two-ctas" = true, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
  // Verifies the consan (concurrency-sanitizer) instrumentation of barrier
  // transaction counts for a multicast TMA gather in a two-CTA module
  // ("ttg.num-ctas" = 2, "ttng.two-ctas" = true): the op declares an expect
  // count of 4096 bytes, and the CHECK lines require the instrumentation to
  // account for 8192 (and -8192 on completion) — i.e. the per-barrier byte
  // count doubled for the multicast across both CTAs.
  // CHECK-LABEL: @multicast_gather_two_cta_tx_count
  tt.func public @multicast_gather_two_cta_tx_count(%desc: !tt.tensordesc<1x32xf32, #shared>) {
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %x_offsets = arith.constant dense<0> : tensor<32xi32, #offsets>
    // Barrier lives past the data buffer (allocation.offset = 65536).
    %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
    %result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
    // CHECK: scf.for
    scf.for %i = %c0 to %c2 step %c1 {
      // Instrumented expect: 2x the declared 4096-byte transaction count.
      // CHECK: arith.constant 8192 : i64
      // CHECK: tt.call @__triton_consan_verify_barrier_arrive
      // CHECK: ttng.barrier_expect
      ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
      // Matching decrement when the gather's arrival is verified.
      // CHECK: arith.constant -8192 : i64
      // CHECK: tt.call @__triton_consan_verify_barrier_arrive
      // CHECK: ttng.async_tma_gather
      ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1
    }
    tt.return
  }
}
327+
328+ // -----
329+
296330#shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
297331#shared1 = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ]}>
298332#smem = #ttg.shared_memory
0 commit comments