Skip to content

Commit bb795b9

Browse files
committed
add lit test that checks multicast gather completion
1 parent 56f6760 commit bb795b9

1 file changed

Lines changed: 34 additions & 0 deletions

File tree

test/TritonGPU/consan.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
293293

294294
// -----
295295

296+
// Layout aliases for the @multicast_gather_two_cta_tx_count test case.
// #shared: NVMMA shared layout (128-byte swizzle, 32-bit elements, not transposed);
// used by the tensor descriptor type and by the gather destination buffer.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}>
297+
// #shared1: 1-D swizzled shared layout; used by the barrier memdesc.
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}>
298+
// #smem: shared-memory space attribute attached to both memdescs.
#smem = #ttg.shared_memory
299+
// #offset_parent: 2-D blocked layout that #offsets is sliced from.
#offset_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 0]]}>
300+
// #offsets: slice of #offset_parent along dim 0; layout of the 32-element gather offsets tensor.
#offsets = #ttg.slice<{dim = 0, parent = #offset_parent}>
301+
// Test module configured with two CTAs ("ttg.num-ctas" = 2, "ttng.two-ctas" = true),
// a single warp of 32 threads, and target cuda:100.
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, "ttng.two-ctas" = true, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
302+
  // Issues a barrier_expect followed by a multicast TMA gather inside a two-iteration
  // loop. The FileCheck directives below pin the i64 constants (8192, then -8192) that
  // must precede the two @__triton_consan_verify_barrier_arrive calls in the output.
  // NOTE(review): 8192 is presumably the 4096 passed to barrier_expect scaled by the
  // two CTAs - confirm against the ConSan instrumentation.
  // CHECK-LABEL: @multicast_gather_two_cta_tx_count
303+
  tt.func public @multicast_gather_two_cta_tx_count(%desc: !tt.tensordesc<1x32xf32, #shared>) {
304+
    // Predicate operand passed to both barrier_expect and async_tma_gather.
    %true = arith.constant true
305+
    // Second index operand of the gather.
    %c0_i32 = arith.constant 0 : i32
306+
    // Loop bounds and step: the loop below runs for %i = 0 and %i = 1.
    %c0 = arith.constant 0 : index
307+
    %c1 = arith.constant 1 : index
308+
    %c2 = arith.constant 2 : index
309+
    // All 32 gather offsets are zero.
    %x_offsets = arith.constant dense<0> : tensor<32xi32, #offsets>
310+
    // Barrier storage: a single i64 at allocation.offset 65536.
    %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
311+
    // Gather destination: 32x32 f32 (32 * 32 * 4 = 4096 bytes) at allocation.offset 0.
    %result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
312+
    // NOTE(review): the literal 1 is presumably the barrier's arrival count - confirm.
    ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
313+
    // CHECK: scf.for
314+
    scf.for %i = %c0 to %c2 step %c1 {
315+
      // Expected before barrier_expect: a constant 8192 and a verify-arrive call.
      // CHECK: arith.constant 8192 : i64
316+
      // CHECK: tt.call @__triton_consan_verify_barrier_arrive
317+
      // CHECK: ttng.barrier_expect
318+
      ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
319+
      // Expected before the gather: a constant -8192 and a second verify-arrive call.
      // CHECK: arith.constant -8192 : i64
320+
      // CHECK: tt.call @__triton_consan_verify_barrier_arrive
321+
      // CHECK: ttng.async_tma_gather
322+
      // Gather into %result, signalling %bar, with the multicast attribute set.
      ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1
323+
    }
324+
    tt.return
325+
  }
326+
}
327+
328+
// -----
329+
296330
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
297331
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
298332
#smem = #ttg.shared_memory

0 commit comments

Comments
 (0)