Skip to content

Commit df82d98

Browse files
zhanglx13 and claude authored
[AMD][gfx9] Use asyncmark/wait_asyncmark for CDNA3/CDNA4 buffer_load_to_lds (#9883)
## Problem On CDNA3 (gfx942) and CDNA4 (gfx950), there are two issues with `buffer_load_to_lds` synchronization: **1. Conservative waits before `ds_read` due to sync modeling** LLVM models `buffer_load_to_lds` as a synchronous operation that writes to LDS. When it sees a subsequent `ds_read` from LDS, it assumes a potential data hazard and inserts a conservative `s_waitcnt vmcnt(0)` — even when the addresses are disjoint. The current workaround annotates `buffer_load_to_lds` and LDS loads with no-alias attributes so LLVM skips the barrier. This is fragile and relies on alias analysis heuristics. **2. Conflicting wait count management between Triton and LLVM** `buffer_load_to_lds` and regular `buffer_load` share the same hardware `vmcnt` counter. Triton manages the wait counts for `buffer_load_to_lds`, while LLVM independently manages wait counts for `buffer_load` — without accounting for the outstanding `buffer_load_to_lds` instructions on the same counter. This leads to LLVM inserting very conservative `s_waitcnt vmcnt(0)` before some `ds_read` instructions, because it doesn't know how many total outstanding loads (of both types) are pending. The no-alias workaround cannot solve this problem. ## Solution LLVM recently added `asyncmark` / `wait_asyncmark` intrinsics that properly solve both issues. When using the async-variant intrinsics (`llvm.amdgcn.raw.ptr.buffer.load.lds.async`, `llvm.amdgcn.global.load.lds.async`) alongside these markers, LLVM models LDS-bound loads as async operations rather than synchronous ones. This means: - LLVM no longer treats `buffer_load_to_lds` as an immediate LDS write, so it doesn't insert spurious waits before `ds_read` (solving issue 1) - LLVM has full visibility into both load types and can compute precise wait counts on the shared `vmcnt` counter (solving issue 2) This is a strictly better solution than the no-alias workaround for both issues. 
This PR switches CDNA3/CDNA4 to use the async intrinsic variants and asyncmark-based synchronization: - **BufferOpsEmitter**: Emit `rocdl.raw.ptr.buffer.load.async.lds` instead of `rocdl.raw.ptr.buffer.load.lds` - **LoadStoreOpToLLVM**: Emit `rocdl.global.load.async.lds` for global-to-LDS copies, `rocdl.asyncmark` for commit groups, and `rocdl.wait.asyncmark` for async waits - **UpdateAsyncWaitCount**: With asyncmark, LLVM handles vmcnt computation — Triton passes through the commit group count directly and skips the expensive `ModuleAxisInfoAnalysis` - **TargetInfo**: Add `useAsyncMarks()` feature flag to centralize the arch check (makes future arch enablement a one-line change) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 28b3589 commit df82d98

13 files changed

Lines changed: 253 additions & 221 deletions

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
1616
%ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>, #blocked>
1717
%mask = tt.splat %maskVal : i1 -> tensor<64x1xi1, #blocked>
1818

19-
// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
19+
// COMMON: rocdl.global.load.async.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
2020
// Check that store for 'other' has alias information set
2121
// COMMON: llvm.store {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
2222
%0 = ttg.async_copy_global_to_local %ptr, %arg1 mask %mask other %other : tensor<64x1x!tt.ptr<f32>, #blocked> -> <64x1xf32, #shared, #smem, mutable>
@@ -41,7 +41,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
4141
%mask = tt.splat %maskVal : i1 -> tensor<8x64xi1, #blocked>
4242
%other = arith.constant dense<1.000000e+00> : tensor<8x64xf32, #blocked>
4343

44-
// COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
44+
// COMMON: rocdl.raw.ptr.buffer.load.async.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
4545
// Check that store for 'other' has alias information set
4646
// COMMON: llvm.store {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], {{.*}}, noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
4747
%65 = amdg.buffer_load_to_local %arg1[%arg2] mask=%mask other=%other into %arg3 : <f32>[tensor<8x64xi32, #blocked>] tensor<8x64xf32, #blocked> -> <8x64xf32, #shared, #smem, mutable>
@@ -107,7 +107,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
107107
// We need the splat to allow the AxisAnalysis to work during lowering
108108
%ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>, #blocked>
109109

110-
// COMMON: rocdl.global.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
110+
// COMMON: rocdl.global.load.async.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
111111
%0 = ttg.async_copy_global_to_local %ptr, %arg1 : tensor<64x1x!tt.ptr<f32>, #blocked> -> <64x1xf32, #shared, #smem, mutable>
112112
%1 = ttg.async_commit_group tokens %0
113113

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 53 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
2-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
33

44
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
55
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
@@ -11,9 +11,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
1111
%arg2: !ttg.memdesc<32x64xf32, #shared, #smem, mutable>) {
1212
// We need the splat to allow the AxisAnalysis to work during lowering
1313
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
14-
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
15-
// CHECK-COUNT-8: rocdl.global.load.lds
16-
// CHECK-NOT: rocdl.global.load.lds
14+
// Each thread needs to load 8 elements and we load 1 (sizePerThread) per load.
15+
// CDNA3/CDNA4 use the async variant so LLVM tracks via asyncmark.
16+
// CHECK-COUNT-8: rocdl.global.load.async.lds
17+
// CHECK-NOT: rocdl.global.load.async.lds
1718
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable>
1819
tt.return
1920
}
@@ -31,9 +32,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
3132
%arg2: !ttg.memdesc<32x64xf32, #shared, #smem, mutable>) {
3233
// We need the splat to allow the AxisAnalysis to work during lowering
3334
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
34-
// Each thread needs to load 8 elements and we load 1 () per global.load.lds
35-
// CHECK-COUNT-8: rocdl.global.load.lds
36-
// CHECK-NOT: rocdl.global.load.lds
35+
// Each thread needs to load 8 elements and we load 1 () per load
36+
// CHECK-COUNT-8: rocdl.global.load.async.lds
37+
// CHECK-NOT: rocdl.global.load.async.lds
3738
%2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable>
3839
tt.return
3940
}
@@ -56,9 +57,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
5657
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
5758
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
5859

59-
// Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
60-
// CHECK-COUNT-4: rocdl.global.load.lds
61-
// CHECK-NOT: rocdl.global.load.lds
60+
// Each thread needs to load 8 elements and we load 2 (sizePerThread) per load
61+
// CHECK-COUNT-4: rocdl.global.load.async.lds
62+
// CHECK-NOT: rocdl.global.load.async.lds
6263
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
6364
tt.return
6465
}
@@ -69,61 +70,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
6970
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
7071
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
7172
#smem = #ttg.shared_memory
72-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
73-
// GFX950-LABEL: async_copy_vectorized_8xf16
74-
tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
75-
%arg1: i32 {tt.divisibility = 16 : i32},
76-
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
77-
// We need the index calculation so AxisAnalysis sees that we can vectorize the load
78-
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
79-
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
80-
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
81-
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
82-
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
83-
84-
// Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
85-
// GFX950: rocdl.global.load.lds
86-
// GFX950-next: llvm.return
87-
88-
// GFX942 does not support vectorization > 4bytes
89-
// expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
90-
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
91-
tt.return
92-
}
93-
}
94-
95-
// -----
96-
97-
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
98-
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
99-
#smem = #ttg.shared_memory
100-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
73+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
10174
// CHECK-LABEL: async_wait
75+
// GFX950-LABEL: async_wait
10276
tt.func public @async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
10377
%arg1: i32 {tt.divisibility = 16 : i32},
10478
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
105-
// The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on
106-
// CHECK: rocdl.s.waitcnt -49168
107-
// CHECK: rocdl.s.waitcnt 49279
108-
// CHECK: rocdl.s.barrier
109-
amdg.async_wait {num_inst = 0 : i32}
110-
// CHECK: rocdl.s.waitcnt -49167
111-
// CHECK: rocdl.s.waitcnt 49279
112-
// CHECK: rocdl.s.barrier
113-
amdg.async_wait {num_inst = 1 : i32}
114-
// CHECK: rocdl.s.waitcnt -2
115-
// CHECK: rocdl.s.waitcnt 49279
116-
// CHECK: rocdl.s.barrier
117-
amdg.async_wait {num_inst = 62 : i32}
118-
// CHECK: rocdl.s.waitcnt -1
119-
// CHECK: rocdl.s.waitcnt 49279
120-
// CHECK: rocdl.s.barrier
121-
amdg.async_wait {num_inst = 63 : i32}
122-
// Check that we clamp values > 63
123-
// CHECK: rocdl.s.waitcnt -1
124-
// CHECK: rocdl.s.waitcnt 49279
125-
// CHECK: rocdl.s.barrier
126-
amdg.async_wait {num_inst = 64 : i32}
79+
// CDNA3/CDNA4 lower ttg.async_wait directly to wait_asyncmark.
80+
// The commit group count is passed through without clamping since
81+
// LLVM will compute the final waitcnt.
82+
// CHECK: rocdl.wait.asyncmark 0
83+
// GFX950: rocdl.wait.asyncmark 0
84+
ttg.async_wait {num = 0 : i32}
85+
// CHECK: rocdl.wait.asyncmark 1
86+
// GFX950: rocdl.wait.asyncmark 1
87+
ttg.async_wait {num = 1 : i32}
88+
// CHECK: rocdl.wait.asyncmark 62
89+
// GFX950: rocdl.wait.asyncmark 62
90+
ttg.async_wait {num = 62 : i32}
91+
// CHECK: rocdl.wait.asyncmark 63
92+
// GFX950: rocdl.wait.asyncmark 63
93+
ttg.async_wait {num = 63 : i32}
94+
// No clamping — LLVM handles it based on instruction count
95+
// CHECK: rocdl.wait.asyncmark 64
96+
// GFX950: rocdl.wait.asyncmark 64
97+
ttg.async_wait {num = 64 : i32}
12798
tt.return
12899
}
129100
}
@@ -133,13 +104,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
133104
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
134105
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
135106
#smem = #ttg.shared_memory
136-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
107+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
137108
// CHECK-LABEL: async_commit_group
109+
// GFX950-LABEL: async_commit_group
138110
tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
139111
%arg1: i32 {tt.divisibility = 16 : i32},
140112
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
113+
// CDNA3/CDNA4 emit asyncmark for async group tracking
114+
// CHECK: rocdl.asyncmark
141115
// CHECK: llvm.mlir.constant(0 : i32) : i32
142116
// CHECK-NEXT: llvm.return
117+
// GFX950: rocdl.asyncmark
118+
// GFX950: llvm.mlir.constant(0 : i32) : i32
119+
// GFX950-NEXT: llvm.return
143120
ttg.async_commit_group
144121
tt.return
145122
}
@@ -179,25 +156,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
179156
// Note that mask/other alignment is 1 so we need 4 conditionals
180157

181158
// CHECK: llvm.cond_br
182-
// CHECK: rocdl.global.load.lds
159+
// CHECK: rocdl.global.load.async.lds
183160
// CHECK-NEXT: llvm.br
184161
// CHECK: llvm.cond_br
185162
// CHECK: llvm.store
186163

187164
// CHECK: llvm.cond_br
188-
// CHECK: rocdl.global.load.lds
165+
// CHECK: rocdl.global.load.async.lds
189166
// CHECK-NEXT: llvm.br
190167
// CHECK: llvm.cond_br
191168
// CHECK: llvm.store
192169

193170
// CHECK: llvm.cond_br
194-
// CHECK: rocdl.global.load.lds
171+
// CHECK: rocdl.global.load.async.lds
195172
// CHECK-NEXT: llvm.br
196173
// CHECK: llvm.cond_br
197174
// CHECK: llvm.store
198175

199176
// CHECK: llvm.cond_br
200-
// CHECK: rocdl.global.load.lds
177+
// CHECK: rocdl.global.load.async.lds
201178
// CHECK-NEXT: llvm.br
202179
// CHECK: llvm.cond_br
203180
// CHECK: llvm.store
@@ -243,31 +220,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
243220
// CHECK: rocdl.ds_bpermute
244221
// CHECK: rocdl.ballot
245222
// CHECK: llvm.cond_br
246-
// CHECK: rocdl.global.load.lds
223+
// CHECK: rocdl.global.load.async.lds
247224
// CHECK-NEXT: llvm.br
248225
// CHECK: llvm.cond_br
249226
// CHECK: llvm.store
250227

251228
// CHECK: rocdl.ds_bpermute
252229
// CHECK: rocdl.ballot
253230
// CHECK: llvm.cond_br
254-
// CHECK: rocdl.global.load.lds
231+
// CHECK: rocdl.global.load.async.lds
255232
// CHECK-NEXT: llvm.br
256233
// CHECK: llvm.cond_br
257234
// CHECK: llvm.store
258235

259236
// CHECK: rocdl.ds_bpermute
260237
// CHECK: rocdl.ballot
261238
// CHECK: llvm.cond_br
262-
// CHECK: rocdl.global.load.lds
239+
// CHECK: rocdl.global.load.async.lds
263240
// CHECK-NEXT: llvm.br
264241
// CHECK: llvm.cond_br
265242
// CHECK: llvm.store
266243

267244
// CHECK: rocdl.ds_bpermute
268245
// CHECK: rocdl.ballot
269246
// CHECK: llvm.cond_br
270-
// CHECK: rocdl.global.load.lds
247+
// CHECK: rocdl.global.load.async.lds
271248
// CHECK-NEXT: llvm.br
272249
// CHECK: llvm.cond_br
273250
// CHECK: llvm.store
@@ -292,13 +269,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.sha
292269
// Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds
293270

294271
// CHECK: llvm.getelementptr
295-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 0
272+
// CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 0
296273
%2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
297274
// CHECK: llvm.getelementptr
298-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 3
275+
// CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 3
299276
%3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
300277
// CHECK: llvm.getelementptr
301-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4, 0, 17
278+
// CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4, 0, 17
302279
%4 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
303280
tt.return
304281
}
@@ -313,7 +290,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
313290
// CHECK-LABEL: async_copy_contiguity_hint
314291
tt.func @async_copy_contiguity_hint(%v: tensor<256x!tt.ptr<f16>, #blocked>, %smem: !ttg.memdesc<256xf16, #shared1D, #smem, mutable>) {
315292
// Check we load 4 bytes at a time
316-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, 4
293+
// CHECK: rocdl.global.load.async.lds {{.*}}, {{.*}}, 4
317294
%0 = ttg.async_copy_global_to_local %v, %smem {contiguity = 2 : i32} : tensor<256x!tt.ptr<f16>, #blocked> -> !ttg.memdesc<256xf16, #shared1D, #smem, mutable>
318295
tt.return
319296
}
@@ -332,7 +309,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
332309
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
333310
%2 = ttg.memdesc_subslice %arg2 [0, 0] : !ttg.memdesc<32x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x64xf32, #shared, #smem, mutable, 32x128>
334311
// We slice in the fastest dim but each warp loads one row, therefore we can write coalesced into LDS
335-
// CHECK: rocdl.global.load.lds
312+
// CHECK: rocdl.global.load.async.lds
336313
%3 = ttg.async_copy_global_to_local %1, %2 : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable, 32x128>
337314
tt.return
338315
}
@@ -351,7 +328,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
351328
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
352329
%2 = ttg.memdesc_subslice %arg2 [0, 0] : !ttg.memdesc<64x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable, 64x32>
353330
// We slice into the slowest dim which does not break coalesced writes into LDS
354-
// CHECK: rocdl.global.load.lds
331+
// CHECK: rocdl.global.load.async.lds
355332
%3 = ttg.async_copy_global_to_local %1, %2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable, 64x32>
356333
tt.return
357334
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics
2+
3+
// GFX942 does not support vectorization > 4bytes for direct-to-LDS loads
4+
5+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
6+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
7+
#smem = #ttg.shared_memory
8+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
9+
tt.func public @async_copy_vectorized_8xf16_error(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
10+
%arg1: i32 {tt.divisibility = 16 : i32},
11+
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
12+
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
13+
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
14+
%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
15+
%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
16+
%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
17+
18+
// expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
19+
%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
20+
tt.return
21+
}
22+
}

0 commit comments

Comments (0)