@@ -33,3 +33,127 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
3333 tt.return
3434 }
3535}
36+
37+ // -----
38+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_copy_with_swizzle
  tt.func public @async_copy_with_swizzle(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                          %arg2: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
    // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
    // CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
    // CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
    %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}
55+
56+ // -----
57+
// Broadcast to all CTAs so we should just see 15 (0b1111) as the broadcast mask since we have 4 CTAs per CGA
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_load_multicast_to_all_ctas
  tt.func public @async_load_multicast_to_all_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                   %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(15 : i32) : i32
    // CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[GROUP_MASK]]

    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}
73+
74+ // -----
75+
// 8 CTAs, 2 multicast groups of 4 CTAs each. Each group is strided by 1 so the base mask should be 0b1010101 (85) and the non free mask is -7 (~0b110)
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_load_multicast_to_half_ctas
  tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                    %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // CHECK: llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
    // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
    // CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(85 : i32) : i32
    // CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
    // CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}
95+
96+ // -----
97+
// 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000)
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_load_multicast_group_of_2_strided_by_8
  tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                               %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // Skip the first cluster id because it's emitted for address calculation
    // CHECK: llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
    // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
    // CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
    // CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
    // CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}
118+
119+ // -----
120+
// 16 CTAs split into 16 multicast groups so we should not emit cluster load since we do not share any data
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_load_multi_cta_but_not_data_sharing
  tt.func public @async_load_multi_cta_but_not_data_sharing(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                            %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
    // CHECK: llvm.amdgcn.global.load.async.to.lds.b64
    // CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}
136+
137+ // -----
138+
// Test with linear layout as src layout
// 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000)
#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = [[0, 4], [0, 8], [0, 16], [0, 0]], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
  // CHECK-LABEL: async_load_multi_cta_linear_layout
  tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr<f32>, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
                                                     %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
    // Skip the first cluster id because it's emitted for address calculation
    // CHECK: llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x
    // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32
    // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
    // CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(257 : i32) : i32
    // CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
    // CHECK: llvm.amdgcn.cluster.load.async.to.lds{{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CTA_MASK]]
    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #linear> -> <32x32xf32, #shared, #smem, mutable>
    tt.return
  }
}