66
77using namespace metal ;
88
9+ // /////////////////////////////////////////////////////////////////////////////
10+ // Small column reduce kernel
11+ // /////////////////////////////////////////////////////////////////////////////
12+
13+ template <typename T, typename U, typename Op>
14+ [[kernel]] void col_reduce_small (
15+ const device T *in [[buffer(0 )]],
16+ device U *out [[buffer(1 )]],
17+ const constant size_t& reduction_size [[buffer(2 )]],
18+ const constant size_t& reduction_stride [[buffer(3 )]],
19+ const constant size_t& out_size [[buffer(4 )]],
20+ const constant int* shape [[buffer(5 )]],
21+ const constant size_t* strides [[buffer(6 )]],
22+ const constant int& ndim [[buffer(7 )]],
23+ const constant size_t& non_col_reductions [[buffer(8 )]],
24+ const constant int* non_col_shapes [[buffer(9 )]],
25+ const constant size_t* non_col_strides [[buffer(10 )]],
26+ const constant int& non_col_ndim [[buffer(11 )]],
27+ uint tid [[thread_position_in_grid]]) {
28+
29+ // Appease the compiler
30+ (void )out_size;
31+
32+ Op op;
33+ U total_val = Op::init;
34+
35+ auto out_idx = tid;
36+
37+ in += elem_to_loc (
38+ out_idx,
39+ shape + non_col_ndim,
40+ strides + non_col_ndim,
41+ ndim - non_col_ndim);
42+
43+ for (uint i = 0 ; i < non_col_reductions; i++) {
44+ size_t in_idx = elem_to_loc (i, non_col_shapes, non_col_strides, non_col_ndim);
45+
46+ for (uint j = 0 ; j < reduction_size; j++, in_idx += reduction_stride) {
47+ U val = static_cast <U>(in[in_idx]);
48+ total_val = op (total_val, val);
49+ }
50+ }
51+
52+ out[out_idx] = total_val;
53+ }
54+
55+ #define instantiate_col_reduce_small (name, itype, otype, op ) \
56+ template [[host_name(" col_reduce_small_" #name)]] \
57+ [[kernel]] void col_reduce_small<itype, otype, op>( \
58+ const device itype *in [[buffer(0 )]], \
59+ device otype *out [[buffer(1 )]], \
60+ const constant size_t & reduction_size [[buffer(2 )]], \
61+ const constant size_t & reduction_stride [[buffer(3 )]], \
62+ const constant size_t & out_size [[buffer(4 )]], \
63+ const constant int * shape [[buffer(5 )]], \
64+ const constant size_t * strides [[buffer(6 )]], \
65+ const constant int & ndim [[buffer(7 )]], \
66+ const constant size_t & non_col_reductions [[buffer(8 )]], \
67+ const constant int * non_col_shapes [[buffer(9 )]], \
68+ const constant size_t * non_col_strides [[buffer(10 )]], \
69+ const constant int & non_col_ndim [[buffer(11 )]], \
70+ uint tid [[thread_position_in_grid]]);
71+
972// /////////////////////////////////////////////////////////////////////////////
1073// Column reduce helper
1174// /////////////////////////////////////////////////////////////////////////////
@@ -171,14 +234,20 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
171234// /////////////////////////////////////////////////////////////////////////////
172235
173236#define instantiate_same_col_reduce_helper (name, tname, type, op ) \
237+ instantiate_col_reduce_small (name ##tname, type, type, op<type>) \
174238 instantiate_col_reduce_general(name ##tname, type, type, op<type>)
175239
176240#define instantiate_same_col_reduce_na_helper (name, tname, type, op ) \
241+ instantiate_col_reduce_small (name ##tname, type, type, op<type>) \
177242 instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
178243
179244instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
180245instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
181246
182247instantiate_col_reduce_general(sumbool_, bool , uint32_t , Sum<uint32_t >)
183248instantiate_reduce_from_types(instantiate_col_reduce_general, and , bool , And)
184- instantiate_reduce_from_types(instantiate_col_reduce_general, or , bool , Or)
249+ instantiate_reduce_from_types(instantiate_col_reduce_general, or , bool , Or)
250+
251+ instantiate_col_reduce_small(sumbool_, bool , uint32_t , Sum<uint32_t >)
252+ instantiate_reduce_from_types(instantiate_col_reduce_small, and , bool , And)
253+ instantiate_reduce_from_types(instantiate_col_reduce_small, or , bool , Or)
0 commit comments