Skip to content

Commit 7f91436

Browse files
authored
Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
1 parent: ebd7135 · commit: 7f91436

File tree

4 files changed

+34
-25
lines changed

4 files changed

+34
-25
lines changed

mlx/backend/metal/kernels/sort.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -522,13 +522,13 @@ template <
522522
bool ARG_SORT,
523523
short BLOCK_THREADS,
524524
short N_PER_THREAD>
525-
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
526-
mb_block_partition(
525+
[[kernel]] void mb_block_partition(
527526
device idx_t* block_partitions [[buffer(0)]],
528527
const device val_t* dev_vals [[buffer(1)]],
529528
const device idx_t* dev_idxs [[buffer(2)]],
530529
const constant int& size_sorted_axis [[buffer(3)]],
531530
const constant int& merge_tiles [[buffer(4)]],
531+
const constant int& n_blocks [[buffer(5)]],
532532
uint3 tid [[threadgroup_position_in_grid]],
533533
uint3 lid [[thread_position_in_threadgroup]],
534534
uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -543,23 +543,29 @@ mb_block_partition(
543543
dev_vals += tid.y * size_sorted_axis;
544544
dev_idxs += tid.y * size_sorted_axis;
545545

546-
// Find location in merge step
547-
int merge_group = lid.x / merge_tiles;
548-
int merge_lane = lid.x % merge_tiles;
546+
for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
547+
// Find location in merge step
548+
int merge_group = i / merge_tiles;
549+
int merge_lane = i % merge_tiles;
549550

550-
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
551-
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
551+
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
552+
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
552553

553-
int A_st = min(size_sorted_axis, sort_st);
554-
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
555-
int B_st = A_ed;
556-
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
554+
int A_st = min(size_sorted_axis, sort_st);
555+
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
556+
int B_st = A_ed;
557+
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
557558

558-
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
559-
int partition = sort_kernel::merge_partition(
560-
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
559+
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
560+
int partition = sort_kernel::merge_partition(
561+
dev_vals + A_st,
562+
dev_vals + B_st,
563+
A_ed - A_st,
564+
B_ed - B_st,
565+
partition_at);
561566

562-
block_partitions[lid.x] = A_st + partition;
567+
block_partitions[i] = A_st + partition;
568+
}
563569
}
564570

565571
template <

mlx/backend/metal/sort.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ void multi_block_sort(
177177
array dev_vals_out = dev_vals_1;
178178
array dev_idxs_out = dev_idxs_1;
179179

180+
int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
181+
180182
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
181183
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
182184
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -199,8 +201,9 @@ void multi_block_sort(
199201
compute_encoder.set_input_array(dev_idxs_in, 2);
200202
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
201203
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
204+
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
202205

203-
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
206+
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
204207
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
205208

206209
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

mlx/ops.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
17851785
throw std::invalid_argument(msg.str());
17861786
}
17871787

1788-
// TODO: Fix GPU kernel
1789-
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
1790-
std::ostringstream msg;
1791-
msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements,"
1792-
<< " got array with sort axis size " << a.shape(axis) << "."
1793-
<< " Please place this operation on the CPU instead.";
1794-
throw std::runtime_error(msg.str());
1795-
}
1796-
17971788
return array(
17981789
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
17991790
}

python/tests/test_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,15 @@ def test_sort(self):
18401840
self.assertTrue(np.array_equal(c_np, c_mx))
18411841
self.assertEqual(b_mx.dtype, c_mx.dtype)
18421842

1843+
# Test very large array
1844+
if mx.default_device() == mx.gpu:
1845+
a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)
1846+
a_mx = mx.array(a_np)
1847+
1848+
b_np = np.sort(a_np)
1849+
b_mx = mx.sort(a_mx)
1850+
self.assertTrue(np.array_equal(b_np, b_mx))
1851+
18431852
def test_partition(self):
18441853
shape = (3, 4, 5)
18451854
for dtype in ("int32", "float32"):

0 commit comments

Comments (0)