Skip to content

Commit bc79129

Browse files
zhanglx13 and claude authored
[AMD] Generalize in-thread tree reduction to support ternary grouping for max/min (#9897)
Summary - Generalize treeReduceBinary into treeReduce parameterized by arity, enabling ternary (or higher) tree reductions when the target benefits from it - Add getReductionTreeArity(Operation*) to TargetInfoBase (default: 2) so targets can request wider grouping per combiner op - AMD override returns 3 for MaximumFOp/MinimumFOp/MaxNumFOp/MinNumFOp, generating max(max(a,b), c) groups that LLVM folds into v_maximum3_f32/v_minimum3_f32 Motivation The binary tree reduction creates an alternating pattern where every other level produces results that LLVM's DAG combiner cannot fold into ternary instructions. LLVM only matches max(max(a,b), c) → v_maximum3 when the inner max has a single use, but the balanced binary tree creates intermediate results consumed by the next level that alternate between foldable and unfoldable. With arity=3, every group maps directly to a ternary instruction, reducing max/min instruction count by ~23% (344 → 264 for a 256×256 f32 reduction). NVIDIA has no max3 equivalent, so the default arity=2 preserves existing behavior. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ddecfce commit bc79129

5 files changed

Lines changed: 99 additions & 11 deletions

File tree

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ class TargetInfoBase {
106106
virtual bool supportLdStMatrixB8() const { return false; }
107107
virtual bool supportBitwidth16Elementwise() const { return false; }
108108
virtual bool supportBitwidth32Elementwise() const { return false; }
109+
110+
// Returns the preferred arity of the in-thread reduction tree for the given
111+
// combiner operation. The default is 2 (binary tree). Targets that have
112+
// native ternary instructions (e.g. AMD v_maximum3/v_minimum3) can return 3
113+
// to generate a ternary reduction tree that maps directly to hardware.
114+
virtual unsigned getReductionTreeArity(Operation *combinerOp) const {
115+
return 2;
116+
}
109117
virtual bool isCuda() const { return false; }
110118

111119
// Returns the shared memory partition size in bytes. A value of 0 means

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,28 @@ struct ReduceOpConversion
112112
private:
113113
const TargetInfoBase &targetInfo;
114114

115-
SmallVector<Value>
116-
treeReduceBinary(Location loc, ConversionPatternRewriter &rewriter,
117-
Region &combineOp,
118-
SmallVector<SmallVector<Value>> values) const {
119-
// The number of elements is always a power of two
120-
assert(llvm::isPowerOf2_64(values.size()) && !values.empty());
115+
// Reduce values using a tree of the given arity. Arity=3 generates
116+
// combine(combine(a, b), c) groups that LLVM folds into ternary
117+
// instructions (e.g. v_maximum3_f32 on AMD).
118+
SmallVector<Value> treeReduce(Location loc,
119+
ConversionPatternRewriter &rewriter,
120+
Region &combineOp,
121+
SmallVector<SmallVector<Value>> values,
122+
unsigned arity) const {
123+
assert(!values.empty() && arity >= 2);
121124
while (values.size() > 1) {
122125
SmallVector<SmallVector<Value>> next;
123-
for (size_t i = 0; i + 1 < values.size(); i += 2) {
124-
SmallVector<Value> acc = values[i];
125-
accumulate(loc, rewriter, combineOp, acc, values[i + 1]);
126-
next.push_back(std::move(acc));
126+
for (size_t i = 0; i < values.size(); i += arity) {
127+
size_t remaining = values.size() - i;
128+
size_t groupSize = std::min(static_cast<size_t>(arity), remaining);
129+
if (groupSize == 1) {
130+
next.push_back(std::move(values[i]));
131+
} else {
132+
SmallVector<Value> acc = std::move(values[i]);
133+
for (size_t j = 1; j < groupSize; ++j)
134+
accumulate(loc, rewriter, combineOp, acc, values[i + j]);
135+
next.push_back(std::move(acc));
136+
}
127137
}
128138
values = std::move(next);
129139
}
@@ -271,6 +281,9 @@ struct ReduceOpConversion
271281
Region &combineRegion =
272282
vectorCombineRegion ? *vectorCombineRegion : op.getCombineOp();
273283

284+
Operation &combinerOp = combineRegion.front().front();
285+
unsigned arity = targetInfo.getReductionTreeArity(&combinerOp);
286+
274287
// Perform a tree reduction
275288
unsigned numOperands = accs.size();
276289
SmallVector<SmallVector<Value>> reduced(numOperands);
@@ -286,7 +299,7 @@ struct ReduceOpConversion
286299
vals.push_back(std::move(cur));
287300
}
288301
auto acc =
289-
treeReduceBinary(loc, rewriter, combineRegion, std::move(vals));
302+
treeReduce(loc, rewriter, combineRegion, std::move(vals), arity);
290303
for (unsigned opIdx = 0; opIdx < numOperands; ++opIdx) {
291304
reduced[opIdx].push_back(acc[opIdx]);
292305
}

test/Conversion/amd/reduce_tree_vectorize.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
5959
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
6060
tt.return
6161
}
62+
63+
// Ternary tree reduction for max/min: generates a chain of 3 dependent ops
64+
// per group so LLVM can fold into v_maximum3/v_minimum3/v_max3/v_min3.
65+
66+
// GFX1250-LABEL: reduce_maximum_f32_ternary
67+
// GFX1250: %[[A:.*]] = llvm.intr.maximum(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
68+
// GFX1250-NEXT: %[[B:.*]] = llvm.intr.maximum(%[[A]], %{{.*}}) : (f32, f32) -> f32
69+
// GFX1250-NEXT: %[[C:.*]] = llvm.intr.maximum(%[[B]], %{{.*}}) : (f32, f32) -> f32
70+
tt.func public @reduce_maximum_f32_ternary(%arg0: tensor<1x128xf32, #blocked_reduce>) {
71+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
72+
^bb0(%a: f32, %b: f32):
73+
%max = arith.maximumf %a, %b : f32
74+
tt.reduce.return %max : f32
75+
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
76+
tt.return
77+
}
78+
79+
// GFX1250-LABEL: reduce_minimum_f32_ternary
80+
// GFX1250: %[[A:.*]] = llvm.intr.minimum(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
81+
// GFX1250-NEXT: %[[B:.*]] = llvm.intr.minimum(%[[A]], %{{.*}}) : (f32, f32) -> f32
82+
// GFX1250-NEXT: %[[C:.*]] = llvm.intr.minimum(%[[B]], %{{.*}}) : (f32, f32) -> f32
83+
tt.func public @reduce_minimum_f32_ternary(%arg0: tensor<1x128xf32, #blocked_reduce>) {
84+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
85+
^bb0(%a: f32, %b: f32):
86+
%min = arith.minimumf %a, %b : f32
87+
tt.reduce.return %min : f32
88+
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
89+
tt.return
90+
}
91+
92+
// GFX1250-LABEL: reduce_maxnum_f32_ternary
93+
// GFX1250: %[[A:.*]] = llvm.intr.maxnum(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
94+
// GFX1250-NEXT: %[[B:.*]] = llvm.intr.maxnum(%[[A]], %{{.*}}) : (f32, f32) -> f32
95+
// GFX1250-NEXT: %[[C:.*]] = llvm.intr.maxnum(%[[B]], %{{.*}}) : (f32, f32) -> f32
96+
tt.func public @reduce_maxnum_f32_ternary(%arg0: tensor<1x128xf32, #blocked_reduce>) {
97+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
98+
^bb0(%a: f32, %b: f32):
99+
%max = arith.maxnumf %a, %b : f32
100+
tt.reduce.return %max : f32
101+
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
102+
tt.return
103+
}
104+
105+
// GFX1250-LABEL: reduce_minnum_f32_ternary
106+
// GFX1250: %[[A:.*]] = llvm.intr.minnum(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
107+
// GFX1250-NEXT: %[[B:.*]] = llvm.intr.minnum(%[[A]], %{{.*}}) : (f32, f32) -> f32
108+
// GFX1250-NEXT: %[[C:.*]] = llvm.intr.minnum(%[[B]], %{{.*}}) : (f32, f32) -> f32
109+
tt.func public @reduce_minnum_f32_ternary(%arg0: tensor<1x128xf32, #blocked_reduce>) {
110+
%0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
111+
^bb0(%a: f32, %b: f32):
112+
%min = arith.minnumf %a, %b : f32
113+
tt.reduce.return %min : f32
114+
}) : (tensor<1x128xf32, #blocked_reduce>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked_reduce}>>
115+
tt.return
116+
}
62117
}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,16 @@ bool TargetInfo::supportBitwidth32Elementwise() const {
764764
}
765765
}
766766

767+
unsigned TargetInfo::getReductionTreeArity(Operation *combinerOp) const {
768+
// AMD has native ternary max/min instructions: v_max3/v_min3 on all GFX9+,
769+
// and v_maximum3/v_minimum3 additionally on GFX950 and GFX1250.
770+
// Use a ternary reduction tree so these map 1:1 to hardware.
771+
if (isa<arith::MaximumFOp, arith::MinimumFOp, arith::MaxNumFOp,
772+
arith::MinNumFOp>(combinerOp))
773+
return 3;
774+
return 2;
775+
}
776+
767777
bool TargetInfo::supportsDirectToLDSScattering() const {
768778
switch (getISAFamily()) {
769779
case ISAFamily::GFX1250:

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
101101
bool supportBitwidth16Elementwise() const override;
102102
bool supportBitwidth32Elementwise() const override;
103103

104+
unsigned getReductionTreeArity(Operation *combinerOp) const override;
105+
104106
// Returns true if the target supports per lane addresses into LDS for
105107
// direct-to-lds loads. Some architectures (e.g. GFX9) do not support
106108
// scattering and instead have to write warp coalesced into LDS

0 commit comments

Comments (0)