Skip to content

Commit b2bc905

Browse files
committed
[AMD] Enable supportBitwidth{16|32}Elementwise in TargetInfo
This helps to optimize reduction code generation.
1 parent 917dbde commit b2bc905

3 files changed

Lines changed: 70 additions & 0 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 -cse | FileCheck %s --check-prefix=GFX942
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 -cse | FileCheck %s --check-prefix=GFX950
3+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 -cse | FileCheck %s --check-prefix=GFX1250
4+
5+
#blocked_reduce = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 1], order = [1, 0]}>
6+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
7+
// GFX942-LABEL: reduce_f16
8+
// GFX942: llvm.fadd {{.*}} : vector<2xf16>
9+
// GFX950-LABEL: reduce_f16
10+
// GFX950: llvm.fadd {{.*}} : vector<2xf16>
11+
tt.func public @reduce_f16(%arg0: tensor<1x256xf16, #blocked_reduce>) {
12+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
13+
^bb0(%a: f16, %b: f16):
14+
%sum = arith.addf %a, %b : f16
15+
tt.reduce.return %sum : f16
16+
}) : (tensor<1x256xf16, #blocked_reduce>) -> tensor<1xf16, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
17+
tt.return
18+
}
19+
20+
// GFX942-LABEL: reduce_f32
21+
// GFX942-NOT: llvm.fadd {{.*}} : vector<2xf32>
22+
// GFX942: llvm.return
23+
// GFX950-LABEL: reduce_f32
24+
// GFX950-NOT: llvm.fadd {{.*}} : vector<2xf32>
25+
// GFX950: llvm.return
26+
tt.func public @reduce_f32(%arg0: tensor<1x256xf32, #blocked_reduce>) {
27+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
28+
^bb0(%a: f32, %b: f32):
29+
%sum = arith.addf %a, %b : f32
30+
tt.reduce.return %sum : f32
31+
}) : (tensor<1x256xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
32+
tt.return
33+
}
34+
}
35+
36+
// -----
37+
38+
#blocked_reduce = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}>
39+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
40+
// GFX1250-LABEL: reduce_f16_tree_vectorize
41+
// GFX1250: llvm.fadd {{.*}} : vector<2xf16>
42+
tt.func public @reduce_f16_tree_vectorize(%arg0: tensor<1x128xf16, #blocked_reduce>) {
43+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
44+
^bb0(%a: f16, %b: f16):
45+
%sum = arith.addf %a, %b : f16
46+
tt.reduce.return %sum : f16
47+
}) : (tensor<1x128xf16, #blocked_reduce>) -> tensor<1xf16, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
48+
tt.return
49+
}
50+
51+
// GFX1250-LABEL: reduce_f32_tree_vectorize
52+
// GFX1250: llvm.fadd {{.*}} : vector<2xf32>
53+
tt.func public @reduce_f32_tree_vectorize(%arg0: tensor<1x128xf32, #blocked_reduce>) {
54+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
55+
^bb0(%a: f32, %b: f32):
56+
%sum = arith.addf %a, %b : f32
57+
tt.reduce.return %sum : f32
58+
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
59+
tt.return
60+
}
61+
}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,12 @@ bool TargetInfo::supportVectorizedAtomics() const {
675675
return true;
676676
}
677677

678+
bool TargetInfo::supportBitwidth16Elementwise() const { return true; }
679+
680+
bool TargetInfo::supportBitwidth32Elementwise() const {
681+
return getISAFamily() == ISAFamily::GFX1250;
682+
}
683+
678684
bool TargetInfo::supportsDirectToLDSScattering() const {
679685
switch (getISAFamily()) {
680686
case ISAFamily::GFX1250:

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
9898

9999
bool supportVectorizedAtomics() const override;
100100

101+
bool supportBitwidth16Elementwise() const override;
102+
bool supportBitwidth32Elementwise() const override;
103+
101104
// Returns true if the target supports per lane addresses into LDS for
102105
// direct-to-lds loads. Some architectures (e.g. GFX9) do not support
103106
// scattering and instead have to write warp coalesced into LDS

0 commit comments

Comments
 (0)