@@ -395,6 +395,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
395395
396396// -----
397397
// Test case: a shared-memory descriptor chosen at runtime by `arith.select`.
// Two 32x32xf32 buffers are allocated at fixed offsets (57344 and 61440) and a
// load goes through the selected descriptor. The expected remarks below list
// both buffers for that load, i.e. either operand of the select may be accessed.
// NOTE(review): the second number in each pair (4096) is presumably the buffer
// size in bytes (32 * 32 * 4) -- confirm against the analysis' printing code.
398+ #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
399+ #smem = #ttg.shared_memory
400+ #blocked = #ttg.blocked <{sizePerThread = [1 , 32 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 1 ], order = [0 , 1 ]}>
401+
402+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , ttg.shared = 65544 : i32 , ttg.target = " cuda:90" , ttg.tensor_memory_size = 0 : i32 , " ttg.threads-per-warp" = 32 : i32 , " ttg.total-num-warps" = 1 : i32 } {
403+ tt.func public @select_shared_memory_regions (%cond: i1 ) {
// Two distinct allocations with explicit, non-overlapping `allocation.offset`s.
404+ %alloc_a = ttg.local_alloc {allocation.offset = 57344 : i32 } : () -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
405+ %alloc_b = ttg.local_alloc {allocation.offset = 61440 : i32 } : () -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
// The descriptor actually loaded from depends on %cond, so it is not known
// statically; the remark on the load is expected to name both allocations.
406+ %selected = arith.select %cond , %alloc_a , %alloc_b : !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable >
407+ // expected-remark @below {{Buffers: [57344, 4096], [61440, 4096]}}
408+ ttg.local_load %selected : !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable > -> tensor <32 x32 xf32 , #blocked >
409+ tt.return
410+ }
411+
// Empty function tagged with `test.print_all_used_regions`; the remark expected
// on it lists the shared regions of the whole module (both buffers above).
412+ // expected-remark @below {{All Shared Regions: [57344, 4096], [61440, 4096]}}
413+ tt.func private @print_all_regions () attributes {test.print_all_used_regions } {
414+ tt.return
415+ }
416+ }
417+
418+ // -----
419+
// Test case: the tensor-memory counterpart of a select between two buffers.
// Two 128x128xf32 tensor-memory allocations are placed at column offsets 0 and
// 128 (row offset 0 for both) and a `ttng.tmem_load` reads through the
// `arith.select` result. The expected remarks list both allocations.
// NOTE(review): each pair is presumably [column offset, column count] --
// confirm against the analysis' printing code.
// NOTE(review): the module attributes declare a cuda:90 target and
// `ttg.tensor_memory_size = 0` while the body allocates tensor memory; confirm
// this is intentional for an analysis-only test.
420+ #tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , colStride = 1 >
421+
422+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , ttg.shared = 65544 : i32 , ttg.target = " cuda:90" , ttg.tensor_memory_size = 0 : i32 , " ttg.threads-per-warp" = 32 : i32 , " ttg.total-num-warps" = 1 : i32 } {
423+ tt.func public @select_tensor_memory_regions (%cond: i1 ) {
// Two allocations that differ only in `tensor_memory_col_offset` (0 vs 128).
424+ %tm0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32 , tensor_memory_row_offset = 0 : i32 } : () -> !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
425+ %tm1 = ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32 , tensor_memory_row_offset = 0 : i32 } : () -> !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
// The loaded descriptor is picked by %cond, so the remark on the load is
// expected to name both allocations.
426+ %selected = arith.select %cond , %tm0 , %tm1 : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
427+ // expected-remark @below {{Buffers: [0, 128], [128, 128]}}
428+ ttng.tmem_load %selected : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 >
429+ tt.return
430+ }
431+
// Empty function tagged with `test.print_all_used_regions`; the remark expected
// on it lists the tensor-memory regions of the whole module (both allocations).
432+ // expected-remark @below {{All Tensor Regions: [0, 128], [128, 128]}}
433+ tt.func private @print_all_regions () attributes {test.print_all_used_regions } {
434+ tt.return
435+ }
436+ }
437+
438+ // -----
439+
398440#shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 32 }>
399441#shared1 = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ]}>
400442#smem = #ttg.shared_memory
0 commit comments