Skip to content

Commit 9cbb1b0

Browse files
Maalvi14 and awni authored
Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.
* Modified sort behavior when running CPU or Metal to match NumPy/JAX
* nits

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 9bfc476 commit 9cbb1b0

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

mlx/backend/cpu/sort.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ namespace mlx::core {
1515

1616
namespace {
1717

18+
// NaN-aware comparator that places NaNs at the end
19+
template <typename T>
20+
bool nan_aware_less(T a, T b) {
21+
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
22+
if (std::isnan(a))
23+
return false;
24+
if (std::isnan(b))
25+
return true;
26+
}
27+
return a < b;
28+
}
29+
1830
template <typename T>
1931
struct StridedIterator {
2032
using iterator_category = std::random_access_iterator_tag;
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
130142
StridedIterator st(data_ptr, axis_stride, 0);
131143
StridedIterator ed(data_ptr, axis_stride, axis_size);
132144

133-
std::stable_sort(st, ed);
145+
std::stable_sort(st, ed, nan_aware_less<T>);
134146
src_it.step();
135147
}
136148
}
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
184196
// Order the indices in [st, ed) by the input value each one refers to
// (`data_ptr[idx * in_stride]`), with NaNs last.
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
  auto v1 = data_ptr[a * in_stride];
  auto v2 = data_ptr[b * in_stride];

  // Handle NaNs (place them at the end). The guard is resolved at compile
  // time and uses the same type condition as nan_aware_less, so argsort and
  // sort agree on where NaNs go for complex64_t as well as for float/double.
  if constexpr (
      std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
    if (std::isnan(v1))
      return false;
    if (std::isnan(v2))
      return true;
  }

  // Equal values are ordered by index.
  return v1 < v2 || (v1 == v2 && a < b);
});
189210
}
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
219240
StridedIterator md(data_ptr, axis_stride, kth);
220241
StridedIterator ed(data_ptr, axis_stride, axis_size);
221242

222-
std::nth_element(st, md, ed);
243+
std::nth_element(st, md, ed, nan_aware_less<T>);
223244
}
224245
}
225246

@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
276297
// Partition the indices in [st, ed) around position `md`, comparing the input
// values they refer to (`data_ptr[idx * in_stride]`), with NaNs last.
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
  auto v1 = data_ptr[a * in_stride];
  auto v2 = data_ptr[b * in_stride];

  // Handle NaNs (place them at the end). The guard is resolved at compile
  // time and uses the same type condition as nan_aware_less, so argpartition
  // and partition agree on where NaNs go for complex64_t as well as for
  // float/double.
  if constexpr (
      std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
    if (std::isnan(v1))
      return false;
    if (std::isnan(v2))
      return true;
  }

  // Equal values are ordered by index.
  return v1 < v2 || (v1 == v2 && a < b);
});
281311
}

mlx/backend/metal/kernels/sort.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
1919
b = w;
2020
}
2121

22+
template <typename T, typename = void>
23+
struct Init {
24+
static constexpr constant T v = Limits<T>::max;
25+
};
26+
2227
template <typename T>
23-
struct LessThan {
24-
static constexpr constant T init = Limits<T>::max;
28+
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
29+
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
30+
};
2531

26-
METAL_FUNC bool operator()(T a, T b) {
32+
// Comparison functor: operator() is a strict "less than" that, for
// floating-point T and complex64_t, ranks NaN after every non-NaN value.
// `init` is taken from Init<T>::v.
//
// NOTE(review): operator() treats complex64_t NaNs as greater than any
// non-NaN value, while `init` for complex64_t comes from whichever Init
// specialization matches it. If `init` is used as a padding sentinel that
// must compare >= every element, confirm this also holds for complex NaNs.
template <typename T>
struct LessThan {
  static constexpr constant T init = Init<T>::v;

  METAL_FUNC bool operator()(T a, T b) const {
    if constexpr (
        metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
      const bool a_nan = isnan(a);
      const bool b_nan = isnan(b);
      // Non-short-circuit operators are kept from the original; presumably
      // this is to avoid extra branches -- confirm before changing.
      if (a_nan | b_nan) {
        // With a NaN involved, only "non-NaN vs NaN" counts as less.
        return b_nan & (!a_nan);
      }
    }
    return a < b;
  }
};

python/tests/test_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3100,8 +3100,6 @@ def test_depends(self):
31003100
out = mx.depends(b, c)
31013101
self.assertTrue(mx.array_equal(out, b))
31023102

3103-
3104-
class TestBroadcast(mlx_tests.MLXTestCase):
31053103
def test_broadcast_shapes(self):
31063104
# Basic broadcasting
31073105
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
@@ -3140,6 +3138,12 @@ def test_broadcast_shapes(self):
31403138
with self.assertRaises(ValueError):
31413139
mx.broadcast_shapes()
31423140

3141+
def test_sort_nan(self):
    # Real-valued input: sorting must leave the NaN in the last position.
    unsorted = mx.array([3.0, mx.nan, 2.0, 0.0])
    want = mx.array([0.0, 2.0, 3.0, mx.nan])
    got = mx.sort(unsorted)
    self.assertTrue(mx.array_equal(got, want, equal_nan=True))
    # NOTE(review): the complex input below is constructed but nothing is
    # asserted on it -- either add the intended check for complex sorting
    # or drop the line.
    unsorted = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
3146+
31433147

31443148
if __name__ == "__main__":
31453149
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)