@@ -222,22 +222,24 @@ void multi_block_sort(
222222 // Copy outputs with appropriate strides
223223 array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
224224
225- if (axis == strided_out_arr .ndim () - 1 ) {
225+ if (axis == in .ndim () - 1 ) {
226226 copy_gpu_inplace (strided_out_arr, out, CopyType::Vector, s);
227227 } else {
228- std::vector<int > strided_out_shape = strided_out_arr.shape ();
229- std::vector<size_t > strided_out_str = strided_out_arr.strides ();
230-
228+ std::vector<int > strided_out_shape = in.shape ();
231229 int out_axis_shape = strided_out_shape[axis];
232- int out_axis_str = strided_out_str[axis];
233230
234231 strided_out_shape.erase (strided_out_shape.begin () + axis);
235- strided_out_str.erase (strided_out_str.begin () + axis);
236-
237232 strided_out_shape.push_back (out_axis_shape);
238- strided_out_str.push_back (out_axis_str);
239233
240- array strided_out_slice (strided_out_shape, out.dtype (), nullptr , {});
234+ std::vector<size_t > strided_out_str (in.ndim (), 1 );
235+ for (int i = in.ndim () - 2 ; i >= 0 ; --i) {
236+ strided_out_str[i] = strided_out_str[i + 1 ] * strided_out_shape[i + 1 ];
237+ }
238+
239+ strided_out_str.erase (strided_out_str.end () - 1 );
240+ strided_out_str.insert (strided_out_str.begin () + axis, 1 );
241+
242+ array strided_out_slice (in.shape (), out.dtype (), nullptr , {});
241243 strided_out_slice.copy_shared_buffer (
242244 strided_out_arr,
243245 strided_out_str,
0 commit comments