Skip to content

Commit 5144eb6

Browse files
committed
nits
1 parent dfa6121 commit 5144eb6

File tree

1 file changed

+92
-53
lines changed

1 file changed

+92
-53
lines changed

mlx/backend/metal/sort.cpp

Lines changed: 92 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,20 @@ void single_block_sort(
5252
contiguous &= check_strides(out, out_stride_sorted_axis);
5353

5454
// Prepare kernel name
55-
std::ostringstream kname;
56-
kname << (contiguous ? "c" : "nc");
57-
if (argsort) {
58-
kname << "arg";
59-
}
60-
61-
kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
62-
<< "_bn" << bn << "_tn" << tn;
63-
auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);
55+
std::string kname;
56+
concatenate(
57+
kname,
58+
contiguous ? "c" : "nc",
59+
argsort ? "arg" : "",
60+
"_block_sort_",
61+
type_to_name(in),
62+
"_",
63+
type_to_name(out),
64+
"_bn",
65+
bn,
66+
"_tn",
67+
tn);
68+
auto kernel = get_sort_kernel(d, kname, in, out, bn, tn);
6469

6570
// Prepare command encoder
6671
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -164,11 +169,18 @@ void multi_block_sort(
164169

165170
// Do blockwise sort
166171
{
167-
std::ostringstream kname;
168-
kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
169-
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
170-
auto kernel =
171-
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
172+
std::string kname;
173+
concatenate(
174+
kname,
175+
"sort_mbsort_",
176+
type_to_name(dev_vals_0),
177+
"_",
178+
type_to_name(dev_idxs_0),
179+
"_bn",
180+
std::to_string(bn),
181+
"_tn",
182+
std::to_string(tn));
183+
auto kernel = get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn);
172184
compute_encoder.set_compute_pipeline_state(kernel);
173185

174186
compute_encoder.set_input_array(in, 0);
@@ -204,12 +216,20 @@ void multi_block_sort(
204216

205217
// Do partition
206218
{
207-
std::ostringstream kname;
208-
kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
209-
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
219+
std::string kname;
220+
concatenate(
221+
kname,
222+
"partition_mbsort_",
223+
type_to_name(dev_vals_in),
224+
"_",
225+
type_to_name(dev_idxs_in),
226+
"_bn",
227+
std::to_string(bn),
228+
"_tn",
229+
std::to_string(tn));
210230

211231
auto kernel =
212-
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
232+
get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn);
213233
compute_encoder.set_compute_pipeline_state(kernel);
214234

215235
compute_encoder.set_output_array(block_partitions, 0);
@@ -227,12 +247,20 @@ void multi_block_sort(
227247

228248
// Do merge
229249
{
230-
std::ostringstream kname;
231-
kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
232-
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
250+
std::string kname;
251+
concatenate(
252+
kname,
253+
"merge_mbsort_",
254+
type_to_name(dev_vals_in),
255+
"_",
256+
type_to_name(dev_idxs_in),
257+
"_bn",
258+
std::to_string(bn),
259+
"_tn",
260+
std::to_string(tn));
233261

234262
auto kernel =
235-
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
263+
get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn);
236264
compute_encoder.set_compute_pipeline_state(kernel);
237265

238266
compute_encoder.set_input_array(block_partitions, 0);
@@ -313,14 +341,6 @@ void gpu_merge_sort(
313341
}
314342
}
315343

316-
///////////////////////////////////////////////////////////////////////////////
317-
// Radix Select for Partition Operations
318-
//
319-
// Uses radix-based selection for partition operations:
320-
// - Small arrays (<=2048): Single-pass kernel with threadgroup memory
321-
// - Large arrays (>2048): Streaming multi-pass kernel
322-
///////////////////////////////////////////////////////////////////////////////
323-
324344
void gpu_radix_partition_small(
325345
const Stream& s,
326346
metal::Device& d,
@@ -340,13 +360,22 @@ void gpu_radix_partition_small(
340360
constexpr int bn = 256;
341361
constexpr int tn = 8;
342362

343-
std::ostringstream kname;
344-
kname << (contiguous ? "c" : "nc");
345-
kname << (arg_partition ? "arg_" : "_");
346-
kname << "radix_select_" << type_to_name(in) << "_" << type_to_name(out)
347-
<< "_bn" << bn << "_tn" << tn;
348-
349-
auto kernel = get_radix_select_kernel(d, kname.str(), in, out, bn, tn);
363+
std::string kname;
364+
concatenate(
365+
kname,
366+
kname,
367+
contiguous ? "c" : "nc",
368+
arg_partition ? "arg_" : "_",
369+
"radix_select_",
370+
type_to_name(in),
371+
"_",
372+
type_to_name(out),
373+
"_bn",
374+
std::to_string(bn),
375+
"_tn",
376+
std::to_string(tn));
377+
378+
auto kernel = get_radix_select_kernel(d, kname, in, out, bn, tn);
350379

351380
auto& compute_encoder = d.get_command_encoder(s.index);
352381
compute_encoder.set_compute_pipeline_state(kernel);
@@ -421,11 +450,19 @@ void gpu_radix_partition_large(
421450
auto& compute_encoder = d.get_command_encoder(s.index);
422451

423452
// Use the streaming kernel that processes all passes in one dispatch
424-
std::ostringstream kname;
425-
kname << "radix_select_large_" << type_to_name(in) << "_" << type_to_name(out)
426-
<< "_" << (arg_partition ? "true" : "false") << "_bn" << bn;
427-
428-
auto kernel = d.get_kernel(kname.str());
453+
std::string kname;
454+
concatenate(
455+
kname,
456+
"radix_select_large_",
457+
type_to_name(in),
458+
"_",
459+
type_to_name(out),
460+
"_",
461+
arg_partition ? "true" : "false",
462+
"_bn",
463+
std::to_string(bn));
464+
465+
auto kernel = d.get_kernel(kname);
429466
compute_encoder.set_compute_pipeline_state(kernel);
430467

431468
compute_encoder.set_input_array(in, 0);
@@ -475,12 +512,19 @@ void gpu_radix_partition_large_nc(
475512
auto& compute_encoder = d.get_command_encoder(s.index);
476513

477514
// Use the non-contiguous streaming kernel
478-
std::ostringstream kname;
479-
kname << "radix_select_large_nc_" << type_to_name(in) << "_"
480-
<< type_to_name(out) << "_" << (arg_partition ? "true" : "false")
481-
<< "_bn" << bn;
482-
483-
auto kernel = d.get_kernel(kname.str());
515+
std::string kname;
516+
concatenate(
517+
kname,
518+
"radix_select_large_nc_",
519+
type_to_name(in),
520+
"_",
521+
type_to_name(out),
522+
"_",
523+
arg_partition ? "true" : "false",
524+
"_bn",
525+
bn);
526+
527+
auto kernel = d.get_kernel(kname);
484528
compute_encoder.set_compute_pipeline_state(kernel);
485529

486530
compute_encoder.set_input_array(in, 0);
@@ -524,11 +568,6 @@ void gpu_radix_partition(
524568
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
525569
int size_sorted_axis = in.shape(axis);
526570

527-
// Normalize kth
528-
if (kth < 0) {
529-
kth += size_sorted_axis;
530-
}
531-
532571
// For very small arrays, fall back to full sort
533572
constexpr int RADIX_SELECT_THRESHOLD = 64;
534573
if (size_sorted_axis <= RADIX_SELECT_THRESHOLD) {

0 commit comments

Comments
 (0)