11// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
2- // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s
2+ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
33
44#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 64 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
55#shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
@@ -11,9 +11,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
1111 %arg2: !ttg.memdesc <32 x64 xf32 , #shared , #smem , mutable >) {
1212 // We need the splat to allow the AxisAnalysis to work during lowering
1313 %1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <32 x64 x!tt.ptr <f32 >, #blocked >
14- // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
15- // CHECK-COUNT-8: rocdl.global.load.lds
16- // CHECK-NOT: rocdl.global.load.lds
14+ // Each thread needs to load 8 elements and we load 1 (sizePerThread) per load.
15+ // CDNA3/CDNA4 use the async variant so LLVM tracks via asyncmark.
16+ // CHECK-COUNT-8: rocdl.global.load.async.lds
17+ // CHECK-NOT: rocdl.global.load.async.lds
1718 %2 = ttg.async_copy_global_to_local %1 , %arg2 : tensor <32 x64 x!tt.ptr <f32 >, #blocked > -> <32 x64 xf32 , #shared , #smem , mutable >
1819 tt.return
1920 }
@@ -31,9 +32,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
3132 %arg2: !ttg.memdesc <32 x64 xf32 , #shared , #smem , mutable >) {
3233 // We need the splat to allow the AxisAnalysis to work during lowering
3334 %1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <32 x64 x!tt.ptr <f32 >, #blocked >
34- // Each thread needs to load 8 elements and we load 1 () per global. load.lds
35- // CHECK-COUNT-8: rocdl.global.load.lds
36- // CHECK-NOT: rocdl.global.load.lds
35+ // Each thread needs to load 8 elements and we load 1 () per load
36+ // CHECK-COUNT-8: rocdl.global.load.async. lds
37+ // CHECK-NOT: rocdl.global.load.async. lds
3738 %2 = ttg.async_copy_global_to_local %1 , %arg2 : tensor <32 x64 x!tt.ptr <f32 >, #blocked > -> <32 x64 xf32 , #shared , #smem , mutable >
3839 tt.return
3940 }
@@ -56,9 +57,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
5657 %4 = tt.splat %arg0 : !tt.ptr <f16 > -> tensor <32 x64 x!tt.ptr <f16 >, #blocked >
5758 %5 = tt.addptr %4 , %3 : tensor <32 x64 x!tt.ptr <f16 >, #blocked >, tensor <32 x64 xi32 , #blocked >
5859
59- // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global. load.lds
60- // CHECK-COUNT-4: rocdl.global.load.lds
61- // CHECK-NOT: rocdl.global.load.lds
60+ // Each thread needs to load 8 elements and we load 2 (sizePerThread) per load
61+ // CHECK-COUNT-4: rocdl.global.load.async.lds
62+ // CHECK-NOT: rocdl.global.load.async.lds
6263 %6 = ttg.async_copy_global_to_local %5 , %arg2 : tensor <32 x64 x!tt.ptr <f16 >, #blocked > -> <32 x64 xf16 , #shared , #smem , mutable >
6364 tt.return
6465 }
@@ -69,61 +70,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
6970#blocked = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
7071#shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
7172#smem = #ttg.shared_memory
72- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx950" , " ttg.threads-per-warp" = 64 : i32 } {
73- // GFX950-LABEL: async_copy_vectorized_8xf16
74- tt.func public @async_copy_vectorized_8xf16 (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
75- %arg1: i32 {tt.divisibility = 16 : i32 },
76- %arg2: !ttg.memdesc <32 x64 xf16 , #shared , #smem , mutable >) {
77- // We need the index calculation so AxisAnalysis sees that we can vectorize the load
78- %1 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
79- %2 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x64 xi32 , #blocked >
80- %3 = tt.broadcast %2 : tensor <1 x64 xi32 , #blocked > -> tensor <32 x64 xi32 , #blocked >
81- %4 = tt.splat %arg0 : !tt.ptr <f16 > -> tensor <32 x64 x!tt.ptr <f16 >, #blocked >
82- %5 = tt.addptr %4 , %3 : tensor <32 x64 x!tt.ptr <f16 >, #blocked >, tensor <32 x64 xi32 , #blocked >
83-
84- // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
85- // GFX950: rocdl.global.load.lds
86- // GFX950-next: llvm.return
87-
88- // GFX942 does not support vectorization > 4bytes
89- // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
90- %6 = ttg.async_copy_global_to_local %5 , %arg2 : tensor <32 x64 x!tt.ptr <f16 >, #blocked > -> <32 x64 xf16 , #shared , #smem , mutable >
91- tt.return
92- }
93- }
94-
95- // -----
96-
97- #blocked = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
98- #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
99- #smem = #ttg.shared_memory
100- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx950" , " ttg.threads-per-warp" = 64 : i32 } {
73+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
10174 // CHECK-LABEL: async_wait
75+ // GFX950-LABEL: async_wait
10276 tt.func public @async_wait (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
10377 %arg1: i32 {tt.divisibility = 16 : i32 },
10478 %arg2: !ttg.memdesc <32 x64 xf16 , #shared , #smem , mutable >) {
105- // The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on
106- // CHECK: rocdl.s.waitcnt -49168
107- // CHECK: rocdl.s.waitcnt 49279
108- // CHECK: rocdl.s.barrier
109- amdg.async_wait {num_inst = 0 : i32 }
110- // CHECK: rocdl.s.waitcnt -49167
111- // CHECK: rocdl.s.waitcnt 49279
112- // CHECK: rocdl.s.barrier
113- amdg.async_wait {num_inst = 1 : i32 }
114- // CHECK: rocdl.s.waitcnt -2
115- // CHECK: rocdl.s.waitcnt 49279
116- // CHECK: rocdl.s.barrier
117- amdg.async_wait {num_inst = 62 : i32 }
118- // CHECK: rocdl.s.waitcnt -1
119- // CHECK: rocdl.s.waitcnt 49279
120- // CHECK: rocdl.s.barrier
121- amdg.async_wait {num_inst = 63 : i32 }
122- // Check that we clamp values > 63
123- // CHECK: rocdl.s.waitcnt -1
124- // CHECK: rocdl.s.waitcnt 49279
125- // CHECK: rocdl.s.barrier
126- amdg.async_wait {num_inst = 64 : i32 }
79+ // CDNA3/CDNA4 lower ttg.async_wait directly to wait_asyncmark.
80+ // The commit group count is passed through without clamping since
81+ // LLVM will compute the final waitcnt.
82+ // CHECK: rocdl.wait.asyncmark 0
83+ // GFX950: rocdl.wait.asyncmark 0
84+ ttg.async_wait {num = 0 : i32 }
85+ // CHECK: rocdl.wait.asyncmark 1
86+ // GFX950: rocdl.wait.asyncmark 1
87+ ttg.async_wait {num = 1 : i32 }
88+ // CHECK: rocdl.wait.asyncmark 62
89+ // GFX950: rocdl.wait.asyncmark 62
90+ ttg.async_wait {num = 62 : i32 }
91+ // CHECK: rocdl.wait.asyncmark 63
92+ // GFX950: rocdl.wait.asyncmark 63
93+ ttg.async_wait {num = 63 : i32 }
94+ // No clamping — LLVM handles it based on instruction count
95+ // CHECK: rocdl.wait.asyncmark 64
96+ // GFX950: rocdl.wait.asyncmark 64
97+ ttg.async_wait {num = 64 : i32 }
12798 tt.return
12899 }
129100}
@@ -133,13 +104,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
133104#blocked = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
134105#shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
135106#smem = #ttg.shared_memory
136- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx950 " , " ttg.threads-per-warp" = 64 : i32 } {
107+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx942 " , " ttg.threads-per-warp" = 64 : i32 } {
137108 // CHECK-LABEL: async_commit_group
109+ // GFX950-LABEL: async_commit_group
138110 tt.func public @async_commit_group (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
139111 %arg1: i32 {tt.divisibility = 16 : i32 },
140112 %arg2: !ttg.memdesc <32 x64 xf16 , #shared , #smem , mutable >) {
113+ // CDNA3/CDNA4 emit asyncmark for async group tracking
114+ // CHECK: rocdl.asyncmark
141115 // CHECK: llvm.mlir.constant(0 : i32) : i32
142116 // CHECK-NEXT: llvm.return
117+ // GFX950: rocdl.asyncmark
118+ // GFX950: llvm.mlir.constant(0 : i32) : i32
119+ // GFX950-NEXT: llvm.return
143120 ttg.async_commit_group
144121 tt.return
145122 }
@@ -179,25 +156,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
179156 // Note that mask/other alignment is 1 so we need 4 conditionals
180157
181158 // CHECK: llvm.cond_br
182- // CHECK: rocdl.global.load.lds
159 // CHECK: rocdl.global.load.async.lds
183160 // CHECK-NEXT: llvm.br
184161 // CHECK: llvm.cond_br
185162 // CHECK: llvm.store
186163
187164 // CHECK: llvm.cond_br
188- // CHECK: rocdl.global.load.lds
165 // CHECK: rocdl.global.load.async.lds
189166 // CHECK-NEXT: llvm.br
190167 // CHECK: llvm.cond_br
191168 // CHECK: llvm.store
192169
193170 // CHECK: llvm.cond_br
194- // CHECK: rocdl.global.load.lds
171 // CHECK: rocdl.global.load.async.lds
195172 // CHECK-NEXT: llvm.br
196173 // CHECK: llvm.cond_br
197174 // CHECK: llvm.store
198175
199176 // CHECK: llvm.cond_br
200- // CHECK: rocdl.global.load.lds
177 // CHECK: rocdl.global.load.async.lds
201178 // CHECK-NEXT: llvm.br
202179 // CHECK: llvm.cond_br
203180 // CHECK: llvm.store
@@ -243,31 +220,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
243220 // CHECK: rocdl.ds_bpermute
244221 // CHECK: rocdl.ballot
245222 // CHECK: llvm.cond_br
246- // CHECK: rocdl.global.load.lds
223 // CHECK: rocdl.global.load.async.lds
247224 // CHECK-NEXT: llvm.br
248225 // CHECK: llvm.cond_br
249226 // CHECK: llvm.store
250227
251228 // CHECK: rocdl.ds_bpermute
252229 // CHECK: rocdl.ballot
253230 // CHECK: llvm.cond_br
254- // CHECK: rocdl.global.load.lds
231 // CHECK: rocdl.global.load.async.lds
255232 // CHECK-NEXT: llvm.br
256233 // CHECK: llvm.cond_br
257234 // CHECK: llvm.store
258235
259236 // CHECK: rocdl.ds_bpermute
260237 // CHECK: rocdl.ballot
261238 // CHECK: llvm.cond_br
262- // CHECK: rocdl.global.load.lds
239 // CHECK: rocdl.global.load.async.lds
263240 // CHECK-NEXT: llvm.br
264241 // CHECK: llvm.cond_br
265242 // CHECK: llvm.store
266243
267244 // CHECK: rocdl.ds_bpermute
268245 // CHECK: rocdl.ballot
269246 // CHECK: llvm.cond_br
270- // CHECK: rocdl.global.load.lds
247 // CHECK: rocdl.global.load.async.lds
271248 // CHECK-NEXT: llvm.br
272249 // CHECK: llvm.cond_br
273250 // CHECK: llvm.store
@@ -292,13 +269,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.sha
292269 // Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds
293270
294271 // CHECK: llvm.getelementptr
295- // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 0
272 // CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 0
296273 %2 = ttg.async_copy_global_to_local %1 , %arg2 cacheModifier = ca : tensor <32 x32 x!tt.ptr <f32 >, #blocked > -> <32 x32 xf32 , #shared , #smem , mutable >
297274 // CHECK: llvm.getelementptr
298- // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 3
275 // CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 3
299276 %3 = ttg.async_copy_global_to_local %1 , %arg2 cacheModifier = cg : tensor <32 x32 x!tt.ptr <f32 >, #blocked > -> <32 x32 xf32 , #shared , #smem , mutable >
300277 // CHECK: llvm.getelementptr
301- // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 17
278 // CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 17
302279 %4 = ttg.async_copy_global_to_local %1 , %arg2 cacheModifier = cv : tensor <32 x32 x!tt.ptr <f32 >, #blocked > -> <32 x32 xf32 , #shared , #smem , mutable >
303280 tt.return
304281 }
@@ -313,7 +290,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
313290 // CHECK-LABEL: async_copy_contiguity_hint
314291 tt.func @async_copy_contiguity_hint (%v: tensor <256 x!tt.ptr <f16 >, #blocked >, %smem: !ttg.memdesc <256 xf16 , #shared1D , #smem , mutable >) {
315292 // Check we load 4 bytes at a time
316- // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4
293 // CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4
317294 %0 = ttg.async_copy_global_to_local %v , %smem {contiguity = 2 : i32 } : tensor <256 x!tt.ptr <f16 >, #blocked > -> !ttg.memdesc <256 xf16 , #shared1D , #smem , mutable >
318295 tt.return
319296 }
@@ -332,7 +309,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
332309 %1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <32 x64 x!tt.ptr <f32 >, #blocked >
333310 %2 = ttg.memdesc_subslice %arg2 [0 , 0 ] : !ttg.memdesc <32 x128 xf32 , #shared , #smem , mutable > -> !ttg.memdesc <32 x64 xf32 , #shared , #smem , mutable , 32 x128 >
334311 // We slice in the fastest dim but each warp loads one row, therefore we can write coalesced into LDS
335- // CHECK: rocdl.global.load.lds
312 // CHECK: rocdl.global.load.async.lds
336313 %3 = ttg.async_copy_global_to_local %1 , %2 : tensor <32 x64 x!tt.ptr <f32 >, #blocked > -> <32 x64 xf32 , #shared , #smem , mutable , 32 x128 >
337314 tt.return
338315 }
@@ -351,7 +328,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
351328 %1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <32 x32 x!tt.ptr <f32 >, #blocked >
352329 %2 = ttg.memdesc_subslice %arg2 [0 , 0 ] : !ttg.memdesc <64 x32 xf32 , #shared , #smem , mutable > -> !ttg.memdesc <32 x32 xf32 , #shared , #smem , mutable , 64 x32 >
353330 // We slice into the slowest dim which does not break coalesced writes into LDS
354- // CHECK: rocdl.global.load.lds
331 // CHECK: rocdl.global.load.async.lds
355332 %3 = ttg.async_copy_global_to_local %1 , %2 : tensor <32 x32 x!tt.ptr <f32 >, #blocked > -> <32 x32 xf32 , #shared , #smem , mutable , 64 x32 >
356333 tt.return
357334 }
0 commit comments