Skip to content

Commit 76b6cec

Browse files
authored
Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management * Add seed to tests
1 parent 9f0df51 commit 76b6cec

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

mlx/backend/metal/sort.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,22 +222,24 @@ void multi_block_sort(
222222
// Copy outputs with appropriate strides
223223
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
224224

225-
if (axis == strided_out_arr.ndim() - 1) {
225+
if (axis == in.ndim() - 1) {
226226
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
227227
} else {
228-
std::vector<int> strided_out_shape = strided_out_arr.shape();
229-
std::vector<size_t> strided_out_str = strided_out_arr.strides();
230-
228+
std::vector<int> strided_out_shape = in.shape();
231229
int out_axis_shape = strided_out_shape[axis];
232-
int out_axis_str = strided_out_str[axis];
233230

234231
strided_out_shape.erase(strided_out_shape.begin() + axis);
235-
strided_out_str.erase(strided_out_str.begin() + axis);
236-
237232
strided_out_shape.push_back(out_axis_shape);
238-
strided_out_str.push_back(out_axis_str);
239233

240-
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
234+
std::vector<size_t> strided_out_str(in.ndim(), 1);
235+
for (int i = in.ndim() - 2; i >= 0; --i) {
236+
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
237+
}
238+
239+
strided_out_str.erase(strided_out_str.end() - 1);
240+
strided_out_str.insert(strided_out_str.begin() + axis, 1);
241+
242+
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
241243
strided_out_slice.copy_shared_buffer(
242244
strided_out_arr,
243245
strided_out_str,

python/tests/test_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,6 +1754,9 @@ def test_sort(self):
17541754
self.assertTrue(np.array_equal(d_np, d_mx))
17551755
self.assertEqual(c_mx.dtype, mx.uint32)
17561756

1757+
# Set random seed
1758+
np.random.seed(0)
1759+
17571760
# Test multi-block sort
17581761
a_np = np.random.normal(size=(32769,)).astype(np.float32)
17591762
a_mx = mx.array(a_np)
@@ -1764,6 +1767,25 @@ def test_sort(self):
17641767
self.assertTrue(np.array_equal(b_np, b_mx))
17651768
self.assertEqual(b_mx.dtype, a_mx.dtype)
17661769

1770+
# Test multi-dim multi-block sort
1771+
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
1772+
a_mx = mx.array(a_np)
1773+
1774+
b_np = np.sort(a_np, axis=-1)
1775+
b_mx = mx.sort(a_mx, axis=-1)
1776+
1777+
self.assertTrue(np.array_equal(b_np, b_mx))
1778+
self.assertEqual(b_mx.dtype, a_mx.dtype)
1779+
1780+
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
1781+
a_mx = mx.array(a_np)
1782+
1783+
b_np = np.sort(a_np, axis=1)
1784+
b_mx = mx.sort(a_mx, axis=1)
1785+
1786+
self.assertTrue(np.array_equal(b_np, b_mx))
1787+
self.assertEqual(b_mx.dtype, a_mx.dtype)
1788+
17671789
def test_partition(self):
17681790
shape = (3, 4, 5)
17691791
for dtype in ("int32", "float32"):

0 commit comments

Comments
 (0)