Skip to content

Commit 74bec5d

Browse files
Alex-Wengg and claude committed
[GlobalOpt] Accept arith.maxnumf in the softmax matcher (#24466)
The stabilizing max of a softmax can be spelled with either arith.maximumf (NaN-propagating, as emitted by e.g. StableHLO frontends) or arith.maxnumf (NaN-ignoring). The latter is what linalg.softmax itself decomposes to (SoftmaxOp::decomposeOperation and the iree-codegen-decompose-softmax pass both use arith.maxnumf), so matching only arith.maximumf made the matcher narrower than the op it raises to and missed that form.

Accept both ops for the max reduction. This is strictly safer: maxnumf is the form linalg.softmax decomposes to, so raising it introduces no NaN behavior change.

Add a positive lit test for the maxnumf spelling.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com>
1 parent 8c7f8c5 commit 74bec5d

2 files changed

Lines changed: 56 additions & 4 deletions

File tree

compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -639,12 +639,21 @@ static FailureOr<Value> matchSoftmax(linalg::LinalgOp rootOp) {
639639
return failure();
640640
}
641641

642-
// max = reduce_max(src), reducing the same source the subtraction reads.
642+
// max = reduce_max(src), reducing the same source the subtraction reads. The
643+
// init must be -inf or the lowest finite value. Accept both arith.maximumf
644+
// (NaN-propagating, as emitted by e.g. StableHLO frontends) and arith.maxnumf
645+
// (NaN-ignoring, which is what linalg.softmax itself decomposes to), since
646+
// both denote the stabilizing max of a softmax.
647+
auto isNegInfOrLowest = [](APFloat f) {
648+
return (f.isLargest() || f.isInfinity()) && f.isNegative();
649+
};
643650
Value source = subSource->get();
644651
Value reducedValue =
645-
matchInnermostReduction<arith::MaximumFOp>(maxValue, [](APFloat f) {
646-
return (f.isLargest() || f.isInfinity()) && f.isNegative();
647-
});
652+
matchInnermostReduction<arith::MaximumFOp>(maxValue, isNegInfOrLowest);
653+
if (!reducedValue) {
654+
reducedValue =
655+
matchInnermostReduction<arith::MaxNumFOp>(maxValue, isNegInfOrLowest);
656+
}
648657
if (reducedValue != source) {
649658
return failure();
650659
}

compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,49 @@ util.func public @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x
168168

169169
// -----
170170

171+
// The stabilizing max may use arith.maxnumf (NaN-ignoring) instead of
172+
// arith.maximumf -- this is the form linalg.softmax itself decomposes to.
173+
// CHECK-LABEL: @softmax_maxnumf
174+
// CHECK-SAME: %[[ARG:.+]]: tensor<2x4xf32>
175+
// CHECK: %[[S:.+]] = linalg.softmax dimension(1) ins(%[[ARG]] : tensor<2x4xf32>)
176+
// CHECK: util.return %[[S]]
177+
util.func public @softmax_maxnumf(%src : tensor<2x4xf32>) -> (tensor<2x4xf32>) {
178+
%cst0 = arith.constant 0.000000e+00 : f32
179+
%cstlow = arith.constant -3.40282347E+38 : f32
180+
%e1 = tensor.empty() : tensor<2xf32>
181+
%e2 = tensor.empty() : tensor<2x4xf32>
182+
%fillmax = linalg.fill ins(%cstlow : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32>
183+
%max = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%src : tensor<2x4xf32>) outs(%fillmax : tensor<2xf32>) {
184+
^bb0(%a: f32, %b: f32):
185+
%m = arith.maxnumf %a, %b : f32
186+
linalg.yield %m : f32
187+
} -> tensor<2xf32>
188+
%sub = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%src, %max : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) {
189+
^bb0(%a: f32, %b: f32, %c: f32):
190+
%s = arith.subf %a, %b : f32
191+
linalg.yield %s : f32
192+
} -> tensor<2x4xf32>
193+
%exp = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%sub : tensor<2x4xf32>) outs(%e2 : tensor<2x4xf32>) {
194+
^bb0(%a: f32, %b: f32):
195+
%e = math.exp %a : f32
196+
linalg.yield %e : f32
197+
} -> tensor<2x4xf32>
198+
%fillsum = linalg.fill ins(%cst0 : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32>
199+
%sum = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%exp : tensor<2x4xf32>) outs(%fillsum : tensor<2xf32>) {
200+
^bb0(%a: f32, %b: f32):
201+
%s = arith.addf %a, %b : f32
202+
linalg.yield %s : f32
203+
} -> tensor<2xf32>
204+
%div = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%exp, %sum : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) {
205+
^bb0(%a: f32, %b: f32, %c: f32):
206+
%d = arith.divf %a, %b : f32
207+
linalg.yield %d : f32
208+
} -> tensor<2x4xf32>
209+
util.return %div : tensor<2x4xf32>
210+
}
211+
212+
// -----
213+
171214
// Negative test: the max reduction is initialized with 0.0 instead of -inf, so
172215
// this is not a numerically-stabilized softmax and must not be raised.
173216
// CHECK-LABEL: @not_softmax_wrong_max_init

0 commit comments

Comments
 (0)