@@ -28,10 +28,8 @@ template <
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;

-  // Case 1:
-  // reduction_stride is small, reduction_size is small and non_col_reductions
-  // is small. Each thread computes reduction_stride outputs.
-  if (reduction_size * non_col_reductions < 64) {
+  // Case 1: Small row, small column
+  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
     U totals[31];
     for (int i = 0; i < 31; i++) {
       totals[i] = Op::init;
@@ -71,10 +69,55 @@ template <
     }
   }

-  // Case 2:
-  // Reduction stride is small but everything else can be big. We loop both
-  // across reduction size and non_col_reductions. Each simdgroup produces
-  // N_READS outputs.
+  // Case 2: Long row, small column
+  else if (reduction_size * non_col_reductions < 32) {
+    U totals[N_READS];
+    for (int i = 0; i < N_READS; i++) {
+      totals[i] = Op::init;
+    }
+
+    short size = reduction_size;
+    size_t offset = size_t(tid.x) * N_READS;
+    bool safe = offset + N_READS <= reduction_stride;
+    short extra = reduction_stride - offset;
+
+    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
+    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
+
+    for (uint r = 0; r < non_col_reductions; r++) {
+      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+      if (safe) {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < N_READS; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      } else {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < extra; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      }
+
+      loop.next(reduce_shape, reduce_strides);
+    }
+    out += out_idx * reduction_stride + offset;
+    if (safe) {
+      for (short i = 0; i < N_READS; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (short i = 0; i < extra; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+
+  // Case 3: Long row, medium column
   else {
     threadgroup U shared_vals[1024];
     U totals[N_READS];
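The Case 2 branch added above assigns each thread N_READS adjacent output columns starting at `offset`, and falls back to an `extra`-wide tail when a full block would run past `reduction_stride`. Below is a minimal single-threaded C++ sketch of that access pattern, assuming a flat row-major buffer of `n_rows * reduction_stride` elements and a plain sum in place of `Op`; the strided `loop.location` indexing and the separate safe/unsafe loops of the real kernel are collapsed into a single `width` for brevity.

```cpp
#include <cstddef>
#include <vector>

constexpr int N_READS = 4; // illustrative value; the kernel's is a template/compile-time constant

// One "thread": reduce every row into N_READS adjacent column totals.
void reduce_columns_block(
    const std::vector<float>& in,
    size_t n_rows,
    size_t reduction_stride,
    size_t offset,
    std::vector<float>& out) {
  float totals[N_READS] = {};                        // Op::init for a sum reduce
  bool safe = offset + N_READS <= reduction_stride;  // full block stays in bounds
  size_t width = safe ? N_READS : reduction_stride - offset;  // the "extra" tail

  for (size_t r = 0; r < n_rows; r++) {
    const float* row = in.data() + r * reduction_stride;
    for (size_t j = 0; j < width; j++) {
      totals[j] += row[offset + j];
    }
  }
  for (size_t j = 0; j < width; j++) {
    out[offset + j] = totals[j];
  }
}
```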
@@ -147,17 +190,13 @@ template <
 /**
  * Our approach is the following simple looped approach:
  * 1. Each thread keeps running totals for BN / n_simdgroups outputs.
- * 2. Load a tile BM, BN in shared memory.
- * 3. Add the values from shared memory to the current running totals.
- *    Neighboring threads access different rows (transposed acces).
- * 4. Move ahead to the next tile until the M axis is exhausted.
- * 5. Move ahead to the next non column reduction
- * 6. Simd reduce the running totals
+ * 2. Load a tile BM, BN in registers and accumulate in the running totals.
+ * 3. Move ahead by BM steps until the column axis and the non column
+ *    reductions are exhausted.
+ * 4. If BM == 32 then transpose in SM and simd reduce the running totals.
+ *    Otherwise write in shared memory and BN threads accumulate the running
+ *    totals with a loop.
- * 7. Write them to the output
+ * 5. Write them to the output.
- *
- * The kernel becomes verbose because we support all kinds of OOB checks. For
- * instance if we choose that reduction_stride must be larger than BN then we
- * can get rid of half the kernel.
  */
 template <
     typename T,
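The comment block in the last hunk describes the looped strategy of the larger kernel. The following is a rough, single-threaded C++ reference for steps 1 through 3 and 5 only, under stated assumptions: a dense `n_rows * n_cols` row-major input, a sum reduction, and a single BN-wide slice of outputs per call; `BM`, `BN`, `col_start`, `n_rows`, and `n_cols` are illustration names, and the simdgroup/shared-memory combination of step 4 is replaced by the serial accumulation loop.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

template <int BM = 32, int BN = 32>
void col_reduce_looped_ref(
    const std::vector<float>& in,
    size_t n_rows,
    size_t n_cols,
    size_t col_start,  // first column of the BN-wide output slice
    std::vector<float>& out) {
  float totals[BN] = {};  // step 1: running totals (Op::init for a sum)

  for (size_t r0 = 0; r0 < n_rows; r0 += BM) {  // step 3: advance by BM rows
    size_t r_end = std::min(n_rows, r0 + BM);
    for (size_t r = r0; r < r_end; r++) {       // step 2: accumulate the tile
      const float* row = in.data() + r * n_cols + col_start;
      for (int j = 0; j < BN && col_start + j < n_cols; j++) {
        totals[j] += row[j];
      }
    }
  }
  for (int j = 0; j < BN && col_start + j < n_cols; j++) {  // step 5: write out
    out[col_start + j] = totals[j];
  }
}
```

In the actual kernel the BN columns of a tile are split across the threads of a threadgroup, so each thread only carries BN / n_simdgroups of these totals before the final combination step.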