Skip to content

Commit 248431e

Browse files
authored
Reductions update (#1351)
1 parent 76f275b commit 248431e

File tree

8 files changed

+453
-197
lines changed

8 files changed

+453
-197
lines changed

benchmarks/python/comparative/bench_mlx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
144144
mx.eval(ys)
145145

146146

147+
def sum_and_add(axis, x, y):
148+
z = x.sum(axis=axis, keepdims=True)
149+
for i in range(50):
150+
z = (z + y).sum(axis=axis, keepdims=True)
151+
mx.eval(z)
152+
153+
147154
def softmax(axis, x):
148155
ys = []
149156
for i in range(100):
@@ -505,5 +512,8 @@ def selu(x):
505512
elif args.benchmark == "selu":
506513
print(bench(selu, x))
507514

515+
elif args.benchmark == "sum_and_add":
516+
print(bench(sum_and_add, axis, *xs))
517+
508518
else:
509519
raise ValueError("Unknown benchmark")

mlx/backend/metal/jit_kernels.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
319319
MTL::ComputePipelineState* get_reduce_init_kernel(
320320
metal::Device& d,
321321
const std::string& kernel_name,
322+
const std::string& func_name,
323+
const std::string& op_name,
322324
const array& out) {
323325
auto lib = d.get_library(kernel_name, [&]() {
324326
std::ostringstream kernel_source;
325-
std::string op_type = op_name(out);
326-
op_type[0] = std::toupper(op_name(out)[0]);
327+
std::string op_type = op_name;
328+
op_type[0] = std::toupper(op_name[0]);
327329
auto out_type = get_type_string(out.dtype());
328330
std::string op = op_type + "<" + out_type + ">";
329331
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
330332
kernel_source << get_template_definition(
331-
kernel_name, "init_reduce", out_type, op);
333+
kernel_name, func_name, out_type, op);
332334
return kernel_source.str();
333335
});
334336
return d.get_kernel(kernel_name, lib);

mlx/backend/metal/kernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
7979
MTL::ComputePipelineState* get_reduce_init_kernel(
8080
metal::Device& d,
8181
const std::string& kernel_name,
82+
const std::string& func_name,
83+
const std::string& op_name,
8284
const array& out);
8385

8486
MTL::ComputePipelineState* get_reduce_kernel(

mlx/backend/metal/kernels/reduce.metal

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,27 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
113113
// special case bool with larger output type
114114
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
115115

116-
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
117-
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
118-
col_reduce_small, \
116+
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
117+
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
118+
col_reduce_small, \
119+
itype, otype, op, dim) \
120+
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
121+
col_reduce_longcolumn, \
119122
itype, otype, op, dim)
120123

121124
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
122125
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
123126
col_reduce_looped, \
124127
itype, otype, op, dim, bm, bn)
125128

129+
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
130+
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
131+
col_reduce_2pass, \
132+
itype, otype, op, dim, bm, bn)
133+
126134
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
127-
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
128-
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
135+
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
136+
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
129137

130138
#define instantiate_col_reduce_general(name, itype, otype, op) \
131139
instantiate_col_reduce_small(name, itype, otype, op, 0) \

0 commit comments

Comments
 (0)