@@ -52,15 +52,20 @@ void single_block_sort(
5252 contiguous &= check_strides (out, out_stride_sorted_axis);
5353
5454 // Prepare kernel name
55- std::ostringstream kname;
56- kname << (contiguous ? " c" : " nc" );
57- if (argsort) {
58- kname << " arg" ;
59- }
60-
61- kname << " _block_sort_" << type_to_name (in) << " _" << type_to_name (out)
62- << " _bn" << bn << " _tn" << tn;
63- auto kernel = get_sort_kernel (d, kname.str (), in, out, bn, tn);
55+ std::string kname; // kernel name: contiguity tag, optional arg tag, block-sort tag, in/out type names, then bn and tn
56+ concatenate (
57+ kname,
58+ contiguous ? " c" : " nc" ,
59+ argsort ? " arg" : " " ,
60+ " _block_sort_" ,
61+ type_to_name (in),
62+ " _" ,
63+ type_to_name (out),
64+ " _bn" ,
65+ std::to_string (bn), // stringify explicitly, like the other kernel-name builders in this file
66+ " _tn" ,
67+ std::to_string (tn)); // NOTE(review): if concatenate appends via +=, a raw int would be added as one char — confirm
68+ auto kernel = get_sort_kernel (d, kname, in, out, bn, tn);
6469
6570 // Prepare command encoder
6671 auto & compute_encoder = d.get_command_encoder (s.index );
@@ -164,11 +169,18 @@ void multi_block_sort(
164169
165170 // Do blockwise sort
166171 {
167- std::ostringstream kname;
168- kname << " sort_mbsort_" << type_to_name (dev_vals_0) << " _"
169- << type_to_name (dev_idxs_0) << " _bn" << bn << " _tn" << tn;
170- auto kernel =
171- get_mb_sort_kernel (d, kname.str (), dev_vals_0, dev_idxs_0, bn, tn);
172+ std::string kname;
173+ concatenate (
174+ kname,
175+ " sort_mbsort_" ,
176+ type_to_name (dev_vals_0),
177+ " _" ,
178+ type_to_name (dev_idxs_0),
179+ " _bn" ,
180+ std::to_string (bn),
181+ " _tn" ,
182+ std::to_string (tn));
183+ auto kernel = get_mb_sort_kernel (d, kname, dev_vals_0, dev_idxs_0, bn, tn);
172184 compute_encoder.set_compute_pipeline_state (kernel);
173185
174186 compute_encoder.set_input_array (in, 0 );
@@ -204,12 +216,20 @@ void multi_block_sort(
204216
205217 // Do partition
206218 {
207- std::ostringstream kname;
208- kname << " partition_mbsort_" << type_to_name (dev_vals_in) << " _"
209- << type_to_name (dev_idxs_in) << " _bn" << bn << " _tn" << tn;
219+ std::string kname;
220+ concatenate (
221+ kname,
222+ " partition_mbsort_" ,
223+ type_to_name (dev_vals_in),
224+ " _" ,
225+ type_to_name (dev_idxs_in),
226+ " _bn" ,
227+ std::to_string (bn),
228+ " _tn" ,
229+ std::to_string (tn));
210230
211231 auto kernel =
212- get_mb_sort_kernel (d, kname. str () , dev_vals_0, dev_idxs_0, bn, tn);
232+ get_mb_sort_kernel (d, kname, dev_vals_0, dev_idxs_0, bn, tn);
213233 compute_encoder.set_compute_pipeline_state (kernel);
214234
215235 compute_encoder.set_output_array (block_partitions, 0 );
@@ -227,12 +247,20 @@ void multi_block_sort(
227247
228248 // Do merge
229249 {
230- std::ostringstream kname;
231- kname << " merge_mbsort_" << type_to_name (dev_vals_in) << " _"
232- << type_to_name (dev_idxs_in) << " _bn" << bn << " _tn" << tn;
250+ std::string kname;
251+ concatenate (
252+ kname,
253+ " merge_mbsort_" ,
254+ type_to_name (dev_vals_in),
255+ " _" ,
256+ type_to_name (dev_idxs_in),
257+ " _bn" ,
258+ std::to_string (bn),
259+ " _tn" ,
260+ std::to_string (tn));
233261
234262 auto kernel =
235- get_mb_sort_kernel (d, kname. str () , dev_vals_0, dev_idxs_0, bn, tn);
263+ get_mb_sort_kernel (d, kname, dev_vals_0, dev_idxs_0, bn, tn);
236264 compute_encoder.set_compute_pipeline_state (kernel);
237265
238266 compute_encoder.set_input_array (block_partitions, 0 );
@@ -313,14 +341,6 @@ void gpu_merge_sort(
313341 }
314342}
315343
316- // /////////////////////////////////////////////////////////////////////////////
317- // Radix Select for Partition Operations
318- //
319- // Uses radix-based selection for partition operations:
320- // - Small arrays (<=2048): Single-pass kernel with threadgroup memory
321- // - Large arrays (>2048): Streaming multi-pass kernel
322- // /////////////////////////////////////////////////////////////////////////////
323-
324344void gpu_radix_partition_small (
325345 const Stream& s,
326346 metal::Device& d,
@@ -340,13 +360,22 @@ void gpu_radix_partition_small(
340360 constexpr int bn = 256 ;
341361 constexpr int tn = 8 ;
342362
343- std::ostringstream kname;
344- kname << (contiguous ? " c" : " nc" );
345- kname << (arg_partition ? " arg_" : " _" );
346- kname << " radix_select_" << type_to_name (in) << " _" << type_to_name (out)
347- << " _bn" << bn << " _tn" << tn;
348-
349- auto kernel = get_radix_select_kernel (d, kname.str (), in, out, bn, tn);
363+ std::string kname; // kernel name: contiguity tag, arg/non-arg tag, radix-select tag, in/out type names, then bn and tn
364+ concatenate (
365+ kname, // accumulator only; it must not be repeated as a piece to append
367+ contiguous ? " c" : " nc" ,
368+ arg_partition ? " arg_" : " _" ,
369+ " radix_select_" ,
370+ type_to_name (in),
371+ " _" ,
372+ type_to_name (out),
373+ " _bn" ,
374+ std::to_string (bn),
375+ " _tn" ,
376+ std::to_string (tn));
377+
378+ auto kernel = get_radix_select_kernel (d, kname, in, out, bn, tn);
350379
351380 auto & compute_encoder = d.get_command_encoder (s.index );
352381 compute_encoder.set_compute_pipeline_state (kernel);
@@ -421,11 +450,19 @@ void gpu_radix_partition_large(
421450 auto & compute_encoder = d.get_command_encoder (s.index );
422451
423452 // Use the streaming kernel that processes all passes in one dispatch
424- std::ostringstream kname;
425- kname << " radix_select_large_" << type_to_name (in) << " _" << type_to_name (out)
426- << " _" << (arg_partition ? " true" : " false" ) << " _bn" << bn;
427-
428- auto kernel = d.get_kernel (kname.str ());
453+ std::string kname;
454+ concatenate (
455+ kname,
456+ " radix_select_large_" ,
457+ type_to_name (in),
458+ " _" ,
459+ type_to_name (out),
460+ " _" ,
461+ arg_partition ? " true" : " false" ,
462+ " _bn" ,
463+ std::to_string (bn));
464+
465+ auto kernel = d.get_kernel (kname);
429466 compute_encoder.set_compute_pipeline_state (kernel);
430467
431468 compute_encoder.set_input_array (in, 0 );
@@ -475,12 +512,19 @@ void gpu_radix_partition_large_nc(
475512 auto & compute_encoder = d.get_command_encoder (s.index );
476513
477514 // Use the non-contiguous streaming kernel
478- std::ostringstream kname;
479- kname << " radix_select_large_nc_" << type_to_name (in) << " _"
480- << type_to_name (out) << " _" << (arg_partition ? " true" : " false" )
481- << " _bn" << bn;
482-
483- auto kernel = d.get_kernel (kname.str ());
515+ std::string kname; // kernel name: non-contiguous large radix-select tag, in/out type names, arg_partition flag, then bn
516+ concatenate (
517+ kname,
518+ " radix_select_large_nc_" ,
519+ type_to_name (in),
520+ " _" ,
521+ type_to_name (out),
522+ " _" ,
523+ arg_partition ? " true" : " false" ,
524+ " _bn" ,
525+ std::to_string (bn)); // stringify explicitly, matching gpu_radix_partition_large; NOTE(review): a raw int may be appended as a char — confirm
526+
527+ auto kernel = d.get_kernel (kname);
484528 compute_encoder.set_compute_pipeline_state (kernel);
485529
486530 compute_encoder.set_input_array (in, 0 );
@@ -524,11 +568,6 @@ void gpu_radix_partition(
524568 int axis = axis_ < 0 ? axis_ + in.ndim () : axis_;
525569 int size_sorted_axis = in.shape (axis);
526570
527- // Normalize kth
528- if (kth < 0 ) {
529- kth += size_sorted_axis;
530- }
531-
532571 // For very small arrays, fall back to full sort
533572 constexpr int RADIX_SELECT_THRESHOLD = 64 ;
534573 if (size_sorted_axis <= RADIX_SELECT_THRESHOLD) {
0 commit comments