@@ -235,19 +235,21 @@ struct KernelMergeSort {
235235 const device T* inp,
236236 device U* out,
237237 const constant int & size_sorted_axis,
238- const constant int & stride_sorted_axis,
239- const constant int & stride_segment_axis,
238+ const constant int & in_stride_sorted_axis,
239+ const constant int & out_stride_sorted_axis,
240+ const constant int & in_stride_segment_axis,
241+ const constant int & out_stride_segment_axis,
240242 threadgroup val_t * tgp_vals,
241243 threadgroup idx_t * tgp_idxs,
242244 uint3 tid [[threadgroup_position_in_grid]],
243245 uint3 lid [[thread_position_in_threadgroup]]) {
244246 // tid.y tells us the segment index
245- inp += tid.y * stride_segment_axis ;
246- out += tid.y * stride_segment_axis ;
247+ inp += tid.y * in_stride_segment_axis ;
248+ out += tid.y * out_stride_segment_axis ;
247249
248250 // Copy into threadgroup memory
249251 for (short i = lid.x ; i < N_PER_BLOCK; i += BLOCK_THREADS) {
250- tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis ]
252+ tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis ]
251253 : val_t (CompareOp::init);
252254 if (ARG_SORT) {
253255 tgp_idxs[i] = i;
@@ -264,9 +266,9 @@ struct KernelMergeSort {
264266 // Write output
265267 for (int i = lid.x ; i < size_sorted_axis; i += BLOCK_THREADS) {
266268 if (ARG_SORT) {
267- out[i * stride_sorted_axis ] = tgp_idxs[i];
269+ out[i * out_stride_sorted_axis ] = tgp_idxs[i];
268270 } else {
269- out[i * stride_sorted_axis ] = tgp_vals[i];
271+ out[i * out_stride_sorted_axis ] = tgp_vals[i];
270272 }
271273 }
272274 }
@@ -282,8 +284,10 @@ template <
282284 const device T* inp [[buffer(0 )]],
283285 device U* out [[buffer(1 )]],
284286 const constant int& size_sorted_axis [[buffer(2 )]],
285- const constant int& stride_sorted_axis [[buffer(3 )]],
286- const constant int& stride_segment_axis [[buffer(4 )]],
287+ const constant int& in_stride_sorted_axis [[buffer(3 )]],
288+ const constant int& out_stride_sorted_axis [[buffer(4 )]],
289+ const constant int& in_stride_segment_axis [[buffer(5 )]],
290+ const constant int& out_stride_segment_axis [[buffer(6 )]],
287291 uint3 tid [[threadgroup_position_in_grid]],
288292 uint3 lid [[thread_position_in_threadgroup]]) {
289293 using sort_kernel =
@@ -298,8 +302,10 @@ template <
298302 inp,
299303 out,
300304 size_sorted_axis,
301- stride_sorted_axis,
302- stride_segment_axis,
305+ in_stride_sorted_axis,
306+ out_stride_sorted_axis,
307+ in_stride_segment_axis,
308+ out_stride_segment_axis,
303309 tgp_vals,
304310 tgp_idxs,
305311 tid,
@@ -310,8 +316,10 @@ template <
310316 inp,
311317 out,
312318 size_sorted_axis,
313- stride_sorted_axis,
314- stride_segment_axis,
319+ in_stride_sorted_axis,
320+ out_stride_sorted_axis,
321+ in_stride_segment_axis,
322+ out_stride_segment_axis,
315323 tgp_vals,
316324 nullptr ,
317325 tid,
@@ -331,20 +339,23 @@ template <
331339 const device T* inp [[buffer(0 )]],
332340 device U* out [[buffer(1 )]],
333341 const constant int& size_sorted_axis [[buffer(2 )]],
334- const constant int& stride_sorted_axis [[buffer(3 )]],
335- const constant int& nc_dim [[buffer(4 )]],
336- const device int* nc_shape [[buffer(5 )]],
337- const device size_t* nc_strides [[buffer(6 )]],
342+ const constant int& in_stride_sorted_axis [[buffer(3 )]],
343+ const constant int& out_stride_sorted_axis [[buffer(4 )]],
344+ const constant int& nc_dim [[buffer(5 )]],
345+ const device int* nc_shape [[buffer(6 )]],
346+ const device size_t* in_nc_strides [[buffer(7 )]],
347+ const device size_t* out_nc_strides [[buffer(8 )]],
338348 uint3 tid [[threadgroup_position_in_grid]],
339349 uint3 lid [[thread_position_in_threadgroup]]) {
340350 using sort_kernel =
341351 KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
342352 using val_t = typename sort_kernel::val_t ;
343353 using idx_t = typename sort_kernel::idx_t ;
344354
345- auto block_idx = elem_to_loc (tid.y , nc_shape, nc_strides, nc_dim);
346- inp += block_idx;
347- out += block_idx;
355+ auto in_block_idx = elem_to_loc (tid.y , nc_shape, in_nc_strides, nc_dim);
356+ auto out_block_idx = elem_to_loc (tid.y , nc_shape, out_nc_strides, nc_dim);
357+ inp += in_block_idx;
358+ out += out_block_idx;
348359
349360 if (ARG_SORT) {
350361 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@@ -353,7 +364,9 @@ template <
353364 inp,
354365 out,
355366 size_sorted_axis,
356- stride_sorted_axis,
367+ in_stride_sorted_axis,
368+ out_stride_sorted_axis,
369+ zero_helper,
357370 zero_helper,
358371 tgp_vals,
359372 tgp_idxs,
@@ -365,7 +378,9 @@ template <
365378 inp,
366379 out,
367380 size_sorted_axis,
368- stride_sorted_axis,
381+ in_stride_sorted_axis,
382+ out_stride_sorted_axis,
383+ zero_helper,
369384 zero_helper,
370385 tgp_vals,
371386 nullptr ,
0 commit comments