
Commit 1fa0d20

consistently handle all -inf in softmax (#1470)
1 parent 3274c6a

3 files changed: +13 −6 lines

mlx/backend/accelerate/softmax.cpp

Lines changed: 5 additions & 3 deletions
@@ -33,8 +33,8 @@ namespace {
  * Note: The implementation below is a general fast exp. There could be faster
  * implementations for numbers strictly < 0.
  */
-inline simd_float16 simd_fast_exp(simd_float16 x) {
-  x *= 1.442695; // multiply with log_2(e)
+inline simd_float16 simd_fast_exp(simd_float16 x_init) {
+  auto x = x_init * 1.442695; // multiply with log_2(e)
   simd_float16 ipart, fpart;
   simd_int16 epart;
   x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
   // bitshifting
   epart = (simd_int(ipart) + 127) << 23;
 
-  return (*(simd_float16*)&epart) * x;
+  // Avoid suppressing NaNs
+  simd_int16 eq = (x_init == x_init);
+  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
 }
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
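
The idea in this hunk: `x == x` produces all-ones lanes for ordinary values and all-zeros lanes for NaN, so `simd_bitselect(x_init, result, eq)` keeps the fast-exp result everywhere except NaN lanes, where the original input passes through unchanged. Without this, the `simd_clamp` step could turn a NaN input into a finite output. A scalar NumPy analogue of the select (a sketch for illustration, not MLX code; `fast_exp` here stands in for any approximation with such a clamp):

import numpy as np

def nan_preserving_exp(x, fast_exp):
    # fast_exp stands in for an approximation whose internal clamp could
    # turn NaN inputs into finite outputs, as simd_fast_exp's
    # simd_clamp(x, -80, 80) step can.
    y = fast_exp(x)
    # x == x is False exactly where x is NaN, mirroring
    # `simd_int16 eq = (x_init == x_init)` above.
    return np.where(x == x, y, x)

x = np.array([0.0, 1.0, np.nan], dtype=np.float32)
print(nan_preserving_exp(x, np.exp))  # [1. 2.7182817 nan]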

mlx/backend/metal/kernels/softmax.h

Lines changed: 3 additions & 3 deletions
@@ -32,12 +32,12 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
-                                                : Limits<AccT>::finite_min;
+      ld[i] =
+          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
-    local_max[simd_lane_id] = Limits<AccT>::finite_min;
+    local_max[simd_lane_id] = Limits<AccT>::min;
     local_normalizer[simd_lane_id] = 0;
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
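
Switching the padding value and the max identity from `Limits<AccT>::finite_min` to `Limits<AccT>::min` changes what an all `-inf` row reduces to. Assuming `Limits<AccT>::min` is negative infinity for floating-point `AccT` (which is what makes the change work), the row max becomes `-inf`, `x - max` becomes `-inf - (-inf) = NaN`, and the NaN propagates through `exp` and the normalizer, so every output lane is NaN regardless of row length. A NumPy sketch of the resulting arithmetic:

import numpy as np

row = np.full(8, -np.inf, dtype=np.float32)
m = row.max()        # -inf: an all -inf row is no longer masked by a finite floor
e = np.exp(row - m)  # -inf - (-inf) = nan, and exp(nan) = nan
print(e / e.sum())   # every lane is nan, matching the new test below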

python/tests/test_ops.py

Lines changed: 5 additions & 0 deletions
@@ -1567,6 +1567,11 @@ def np_softmax(x, axis):
         out = mx.softmax(a, axis=-1, precise=True)
         self.assertTrue(mx.allclose(out_expect, out))
 
+        # All Infs give NaNs
+        for n in [127, 128, 129]:
+            x = mx.full((n,), vals=-float("inf"))
+            self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)
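
The sizes 127, 128, and 129 presumably straddle the kernel's per-threadgroup read boundary, so both the vectorized path and the edge path are exercised. A quick interactive check of the new behavior (requires mlx; mirrors the test above):

import mlx.core as mx

x = mx.full((128,), vals=-float("inf"))
print(mx.softmax(x))  # all NaN after this change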
