@@ -192,6 +192,31 @@ tt.func public @atomic_add_f16_cuda80(%arg0: !tt.ptr<f16> {tt.divisibility = 16
192192}
193193// -----
194194
195+ #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // COM: layout used by every tensor in the input IR (sizePerThread = [1])
196+ // CHECK: #[[$ATOMIC_F16_LAYOUT:.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
197+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // COM: f16 atomic fadd case for target hip:gfx1250; expected output layout has sizePerThread = [2]
198+   // CHECK-LABEL: @atomic_add_f16_gfx1250
199+   tt.func public @atomic_add_f16_gfx1250(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32) {
200+     %c1024_i32 = arith.constant 1024 : i32
201+     %cst = arith.constant dense<1.000000e+00> : tensor<1024xf16, #blocked>
202+     %0 = tt.get_program_id x : i32
203+     %1 = arith.muli %0, %c1024_i32 : i32
204+     %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
205+     %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
206+     %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> // COM: per-element offsets: program_id * 1024 + [0, 1024)
207+     %5 = tt.splat %arg1 : i32 -> tensor<1024xi32, #blocked>
208+     %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> // COM: mask = offset < %arg1
209+     %7 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>, #blocked>
210+     %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
211+     // CHECK: ttg.convert_layout %{{.*}} : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #[[$ATOMIC_F16_LAYOUT]]>
212+     // CHECK: tt.atomic_rmw fadd, relaxed, gpu, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1024x!tt.ptr<f16>, #[[$ATOMIC_F16_LAYOUT]]>, tensor<1024xf16, #[[$ATOMIC_F16_LAYOUT]]>, tensor<1024xi1, #[[$ATOMIC_F16_LAYOUT]]>) -> tensor<1024xf16, #[[$ATOMIC_F16_LAYOUT]]>
213+     %9 = tt.atomic_rmw fadd, relaxed, gpu, %8, %cst, %6 : (tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xf16, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf16, #blocked> // COM: op under test: masked f16 atomic add of the constant 1.0
214+     tt.return
215+   }
216+ }
217+
218+ // -----
219+
195220// COM: Reproducer for issue #5122
196221// CHECK-LABEL: @test_5122
197222module attributes {" ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 } {
0 commit comments