Skip to content

Commit 2615660

Browse files
authored
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
1 parent 5b0af4c commit 2615660

File tree

7 files changed

+219
-259
lines changed

7 files changed

+219
-259
lines changed

mlx/backend/common/sort.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
113113
axis = axis < 0 ? axis + in.ndim() : axis;
114114
size_t n_rows = in.size() / in.shape(axis);
115115

116-
auto remaining_shape = in.shape();
116+
auto remaining_shape = out.shape();
117117
remaining_shape.erase(remaining_shape.begin() + axis);
118118

119-
auto remaining_strides = in.strides();
119+
auto remaining_strides = out.strides();
120120
remaining_strides.erase(remaining_strides.begin() + axis);
121121

122-
size_t axis_stride = in.strides()[axis];
123-
int axis_size = in.shape(axis);
122+
size_t axis_stride = out.strides()[axis];
123+
int axis_size = out.shape(axis);
124124

125125
// Perform sorting in place
126126
for (int i = 0; i < n_rows; i++) {
@@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
143143
axis = axis < 0 ? axis + in.ndim() : axis;
144144
size_t n_rows = in.size() / in.shape(axis);
145145

146-
auto remaining_shape = in.shape();
147-
remaining_shape.erase(remaining_shape.begin() + axis);
146+
auto in_remaining_shape = in.shape();
147+
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
148148

149-
auto remaining_strides = in.strides();
150-
remaining_strides.erase(remaining_strides.begin() + axis);
149+
auto in_remaining_strides = in.strides();
150+
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
151151

152-
size_t axis_stride = in.strides()[axis];
152+
auto out_remaining_shape = out.shape();
153+
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
154+
155+
auto out_remaining_strides = out.strides();
156+
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
157+
158+
size_t in_stride = in.strides()[axis];
159+
size_t out_stride = out.strides()[axis];
153160
int axis_size = in.shape(axis);
154161

155162
// Perform sorting
156163
for (int i = 0; i < n_rows; i++) {
157-
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
158-
const T* data_ptr = in.data<T>() + loc;
159-
IdxT* idx_ptr = out.data<IdxT>() + loc;
164+
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
165+
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
166+
const T* data_ptr = in.data<T>() + in_loc;
167+
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
160168

161-
StridedIterator st_(idx_ptr, axis_stride, 0);
162-
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
169+
StridedIterator st_(idx_ptr, out_stride, 0);
170+
StridedIterator ed_(idx_ptr, out_stride, axis_size);
163171

164172
// Initialize with iota
165173
std::iota(st_, ed_, IdxT(0));
166174

167175
// Sort according to vals
168-
StridedIterator st(idx_ptr, axis_stride, 0);
169-
StridedIterator ed(idx_ptr, axis_stride, axis_size);
176+
StridedIterator st(idx_ptr, out_stride, 0);
177+
StridedIterator ed(idx_ptr, out_stride, axis_size);
170178

171-
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
172-
auto v1 = data_ptr[a * axis_stride];
173-
auto v2 = data_ptr[b * axis_stride];
179+
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
180+
auto v1 = data_ptr[a * in_stride];
181+
auto v2 = data_ptr[b * in_stride];
174182
return v1 < v2 || (v1 == v2 && a < b);
175183
});
176184
}

mlx/backend/metal/jit/sort.h

Lines changed: 0 additions & 81 deletions
This file was deleted.

mlx/backend/metal/jit_kernels.cpp

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "mlx/backend/metal/jit/reduce.h"
99
#include "mlx/backend/metal/jit/scan.h"
1010
#include "mlx/backend/metal/jit/softmax.h"
11-
#include "mlx/backend/metal/jit/sort.h"
1211
#include "mlx/backend/metal/jit/steel_conv.h"
1312
#include "mlx/backend/metal/jit/steel_gemm.h"
1413
#include "mlx/backend/metal/kernels.h"
@@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
251250
auto lib = d.get_library(lib_name);
252251
if (lib == nullptr) {
253252
std::ostringstream kernel_source;
254-
kernel_source << metal::utils() << metal::sort()
255-
<< fmt::format(
256-
block_sort_kernels,
257-
lib_name,
258-
get_type_string(in.dtype()),
259-
get_type_string(out.dtype()),
260-
bn,
261-
tn);
253+
auto in_type = get_type_string(in.dtype());
254+
auto out_type = get_type_string(out.dtype());
255+
kernel_source << metal::utils() << metal::sort();
256+
for (bool is_argsort : {true, false}) {
257+
std::string bool_string = is_argsort ? "true" : "false";
258+
std::string func_string = is_argsort ? "carg_" : "c_";
259+
kernel_source << get_template_definition(
260+
func_string + lib_name,
261+
"block_sort",
262+
in_type,
263+
out_type,
264+
bool_string,
265+
bn,
266+
tn);
267+
kernel_source << get_template_definition(
268+
"n" + func_string + lib_name,
269+
"block_sort_nc",
270+
in_type,
271+
out_type,
272+
bool_string,
273+
bn,
274+
tn);
275+
}
262276
lib = d.get_library(lib_name, kernel_source.str());
263277
}
264278
return d.get_kernel(kernel_name, lib);
@@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
275289
auto lib = d.get_library(lib_name);
276290
if (lib == nullptr) {
277291
std::ostringstream kernel_source;
278-
kernel_source << metal::utils() << metal::sort()
279-
<< fmt::format(
280-
multiblock_sort_kernels,
281-
lib_name,
282-
get_type_string(in.dtype()),
283-
get_type_string(idx.dtype()),
284-
bn,
285-
tn);
292+
kernel_source << metal::utils() << metal::sort();
293+
std::vector<std::pair<std::string, std::string>> kernel_types = {
294+
{"sort_", "mb_block_sort"},
295+
{"partition_", "mb_block_partition"},
296+
{"merge_", "mb_block_merge"}};
297+
for (auto [name, func] : kernel_types) {
298+
kernel_source << get_template_definition(
299+
name + lib_name,
300+
func,
301+
get_type_string(in.dtype()),
302+
get_type_string(idx.dtype()),
303+
"true",
304+
bn,
305+
tn);
306+
}
286307
lib = d.get_library(lib_name, kernel_source.str());
287308
}
288309
return d.get_kernel(kernel_name, lib);

mlx/backend/metal/kernels/sort.h

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -235,19 +235,21 @@ struct KernelMergeSort {
235235
const device T* inp,
236236
device U* out,
237237
const constant int& size_sorted_axis,
238-
const constant int& stride_sorted_axis,
239-
const constant int& stride_segment_axis,
238+
const constant int& in_stride_sorted_axis,
239+
const constant int& out_stride_sorted_axis,
240+
const constant int& in_stride_segment_axis,
241+
const constant int& out_stride_segment_axis,
240242
threadgroup val_t* tgp_vals,
241243
threadgroup idx_t* tgp_idxs,
242244
uint3 tid [[threadgroup_position_in_grid]],
243245
uint3 lid [[thread_position_in_threadgroup]]) {
244246
// tid.y tells us the segment index
245-
inp += tid.y * stride_segment_axis;
246-
out += tid.y * stride_segment_axis;
247+
inp += tid.y * in_stride_segment_axis;
248+
out += tid.y * out_stride_segment_axis;
247249

248250
// Copy into threadgroup memory
249251
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
250-
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
252+
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
251253
: val_t(CompareOp::init);
252254
if (ARG_SORT) {
253255
tgp_idxs[i] = i;
@@ -264,9 +266,9 @@ struct KernelMergeSort {
264266
// Write output
265267
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
266268
if (ARG_SORT) {
267-
out[i * stride_sorted_axis] = tgp_idxs[i];
269+
out[i * out_stride_sorted_axis] = tgp_idxs[i];
268270
} else {
269-
out[i * stride_sorted_axis] = tgp_vals[i];
271+
out[i * out_stride_sorted_axis] = tgp_vals[i];
270272
}
271273
}
272274
}
@@ -282,8 +284,10 @@ template <
282284
const device T* inp [[buffer(0)]],
283285
device U* out [[buffer(1)]],
284286
const constant int& size_sorted_axis [[buffer(2)]],
285-
const constant int& stride_sorted_axis [[buffer(3)]],
286-
const constant int& stride_segment_axis [[buffer(4)]],
287+
const constant int& in_stride_sorted_axis [[buffer(3)]],
288+
const constant int& out_stride_sorted_axis [[buffer(4)]],
289+
const constant int& in_stride_segment_axis [[buffer(5)]],
290+
const constant int& out_stride_segment_axis [[buffer(6)]],
287291
uint3 tid [[threadgroup_position_in_grid]],
288292
uint3 lid [[thread_position_in_threadgroup]]) {
289293
using sort_kernel =
@@ -298,8 +302,10 @@ template <
298302
inp,
299303
out,
300304
size_sorted_axis,
301-
stride_sorted_axis,
302-
stride_segment_axis,
305+
in_stride_sorted_axis,
306+
out_stride_sorted_axis,
307+
in_stride_segment_axis,
308+
out_stride_segment_axis,
303309
tgp_vals,
304310
tgp_idxs,
305311
tid,
@@ -310,8 +316,10 @@ template <
310316
inp,
311317
out,
312318
size_sorted_axis,
313-
stride_sorted_axis,
314-
stride_segment_axis,
319+
in_stride_sorted_axis,
320+
out_stride_sorted_axis,
321+
in_stride_segment_axis,
322+
out_stride_segment_axis,
315323
tgp_vals,
316324
nullptr,
317325
tid,
@@ -331,20 +339,23 @@ template <
331339
const device T* inp [[buffer(0)]],
332340
device U* out [[buffer(1)]],
333341
const constant int& size_sorted_axis [[buffer(2)]],
334-
const constant int& stride_sorted_axis [[buffer(3)]],
335-
const constant int& nc_dim [[buffer(4)]],
336-
const device int* nc_shape [[buffer(5)]],
337-
const device size_t* nc_strides [[buffer(6)]],
342+
const constant int& in_stride_sorted_axis [[buffer(3)]],
343+
const constant int& out_stride_sorted_axis [[buffer(4)]],
344+
const constant int& nc_dim [[buffer(5)]],
345+
const device int* nc_shape [[buffer(6)]],
346+
const device size_t* in_nc_strides [[buffer(7)]],
347+
const device size_t* out_nc_strides [[buffer(8)]],
338348
uint3 tid [[threadgroup_position_in_grid]],
339349
uint3 lid [[thread_position_in_threadgroup]]) {
340350
using sort_kernel =
341351
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
342352
using val_t = typename sort_kernel::val_t;
343353
using idx_t = typename sort_kernel::idx_t;
344354

345-
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
346-
inp += block_idx;
347-
out += block_idx;
355+
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
356+
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
357+
inp += in_block_idx;
358+
out += out_block_idx;
348359

349360
if (ARG_SORT) {
350361
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@@ -353,7 +364,9 @@ template <
353364
inp,
354365
out,
355366
size_sorted_axis,
356-
stride_sorted_axis,
367+
in_stride_sorted_axis,
368+
out_stride_sorted_axis,
369+
zero_helper,
357370
zero_helper,
358371
tgp_vals,
359372
tgp_idxs,
@@ -365,7 +378,9 @@ template <
365378
inp,
366379
out,
367380
size_sorted_axis,
368-
stride_sorted_axis,
381+
in_stride_sorted_axis,
382+
out_stride_sorted_axis,
383+
zero_helper,
369384
zero_helper,
370385
tgp_vals,
371386
nullptr,

0 commit comments

Comments
 (0)