@@ -522,13 +522,13 @@ template <
522522 bool ARG_SORT,
523523 short BLOCK_THREADS,
524524 short N_PER_THREAD>
525- [[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
526- mb_block_partition (
525+ [[kernel]] void mb_block_partition (
527526 device idx_t * block_partitions [[buffer(0 )]],
528527 const device val_t* dev_vals [[buffer(1 )]],
529528 const device idx_t* dev_idxs [[buffer(2 )]],
530529 const constant int& size_sorted_axis [[buffer(3 )]],
531530 const constant int& merge_tiles [[buffer(4 )]],
531+ const constant int& n_blocks [[buffer(5 )]],
532532 uint3 tid [[threadgroup_position_in_grid]],
533533 uint3 lid [[thread_position_in_threadgroup]],
534534 uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -543,23 +543,29 @@ mb_block_partition(
543543 dev_vals += tid.y * size_sorted_axis;
544544 dev_idxs += tid.y * size_sorted_axis;
545545
546- // Find location in merge step
547- int merge_group = lid.x / merge_tiles;
548- int merge_lane = lid.x % merge_tiles;
546+ for (int i = lid.x ; i <= n_blocks; i += tgp_dims.x ) {
547+ // Find location in merge step
548+ int merge_group = i / merge_tiles;
549+ int merge_lane = i % merge_tiles;
549550
550- int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
551- int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
551+ int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
552+ int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
552553
553- int A_st = min (size_sorted_axis, sort_st);
554- int A_ed = min (size_sorted_axis, sort_st + sort_sz / 2 );
555- int B_st = A_ed;
556- int B_ed = min (size_sorted_axis, B_st + sort_sz / 2 );
554+ int A_st = min (size_sorted_axis, sort_st);
555+ int A_ed = min (size_sorted_axis, sort_st + sort_sz / 2 );
556+ int B_st = A_ed;
557+ int B_ed = min (size_sorted_axis, B_st + sort_sz / 2 );
557558
558- int partition_at = min (B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
559- int partition = sort_kernel::merge_partition (
560- dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
559+ int partition_at = min (B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
560+ int partition = sort_kernel::merge_partition (
561+ dev_vals + A_st,
562+ dev_vals + B_st,
563+ A_ed - A_st,
564+ B_ed - B_st,
565+ partition_at);
561566
562- block_partitions[lid.x ] = A_st + partition;
567+ block_partitions[i] = A_st + partition;
568+ }
563569}
564570
565571template <
0 commit comments