Skip to content

Commit dea2e9d

Browse files
authored
[AMD] Add atomic vectorization cap (#10093)
AMD GPUs only have vectorized/packed atomics for fadd with fp16/bf16 dtypes on newer targets. On other targets, LLVM lowers such atomics as a CAS loop plus arithmetic, so capping vectorization at 32 bits per thread is still the best choice.
1 parent 029b260 commit dea2e9d

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,10 +133,16 @@ getAtomicWriteElementsPerThreadCap(Operation *op) {
133133
if (elemTy.isInteger() || elemTy.isF64())
134134
return 1;
135135

136+
auto moduleOp = op->getParentOfType<ModuleOp>();
137+
138+
if (moduleOp && getAMDArch(moduleOp)) {
139+
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
140+
return std::max(1u, 32u / elemBitwidth);
141+
}
142+
136143
if (atomicRmw.getAtomicRmwOp() != RMWOp::FADD)
137144
return std::nullopt;
138145

139-
auto moduleOp = op->getParentOfType<ModuleOp>();
140146
auto targetAttr =
141147
moduleOp ? moduleOp->getAttrOfType<StringAttr>(ttg::AttrTargetName)
142148
: nullptr;

test/TritonGPU/coalesce.mlir

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,31 @@ tt.func public @atomic_add_f16_cuda80(%arg0: !tt.ptr<f16> {tt.divisibility = 16
192192
}
193193
// -----
194194

195+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
196+
// CHECK: #[[$ATOMIC_F16_LAYOUT:.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
197+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
198+
// CHECK-LABEL: @atomic_add_f16_gfx1250
199+
tt.func public @atomic_add_f16_gfx1250(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32) {
200+
%c1024_i32 = arith.constant 1024 : i32
201+
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf16, #blocked>
202+
%0 = tt.get_program_id x : i32
203+
%1 = arith.muli %0, %c1024_i32 : i32
204+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
205+
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
206+
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
207+
%5 = tt.splat %arg1 : i32 -> tensor<1024xi32, #blocked>
208+
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
209+
%7 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>, #blocked>
210+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
211+
// CHECK: ttg.convert_layout %{{.*}} : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #[[$ATOMIC_F16_LAYOUT]]>
212+
// CHECK: tt.atomic_rmw fadd, relaxed, gpu, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1024x!tt.ptr<f16>, #[[$ATOMIC_F16_LAYOUT]]>, tensor<1024xf16, #[[$ATOMIC_F16_LAYOUT]]>, tensor<1024xi1, #[[$ATOMIC_F16_LAYOUT]]>) -> tensor<1024xf16, #[[$ATOMIC_F16_LAYOUT]]>
213+
%9 = tt.atomic_rmw fadd, relaxed, gpu, %8, %cst, %6 : (tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xf16, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf16, #blocked>
214+
tt.return
215+
}
216+
}
217+
218+
// -----
219+
195220
// COM: Reproducer for issue #5122
196221
// CHECK-LABEL: @test_5122
197222
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {

0 commit comments

Comments (0)