@@ -23,12 +23,19 @@ ${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
23
23
${layout_declare_ubo(B, "ivec3 ", "tin_limits")}
24
24
${layout_declare_ubo(B, "ivec4 ", "tin_sizes")}
25
25
26
+ layout (push_constant) uniform PushConstants {
27
+ int unbiased;
28
+ } pc;
29
+
26
30
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
27
31
28
32
layout (constant_id = 3 ) const int packed_dim = 0 ;
29
33
layout (constant_id = 4 ) const int reduce_dim = 0 ;
30
34
layout (constant_id = 5 ) const int group_dim = 1 ;
31
35
36
+ $if VARIANCE_MODE:
37
+ #define VARIANCE_MODE
38
+
32
39
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
33
40
// threads that will co-operate to compute one reduction output. There may be
34
41
// multiple groups computing distinct reduction outputs within one work group.
@@ -39,15 +46,29 @@ layout(constant_id = 5) const int group_dim = 1;
39
46
// work group will write into its assigned element in the shared array.
40
47
#define MAX_NTHREADS 16
41
48
42
-
43
49
shared vec4 shared_vecs[MAX_NTHREADS];
50
+ // Second accumulator for variance mode - used for sum of values, prev
51
+ // accumulator is used for sum of squares
52
+ shared vec4 shared_sum_sq[MAX_NTHREADS];
53
+ shared int shared_count[MAX_NTHREADS];
44
54
45
55
#include "indexing_utils.h"
46
56
47
57
// Map a worker's 2D thread id within the work group to its slot in the
// shared-memory arrays: workers of the same reduction group (tid.y) occupy
// NWORKERS consecutive slots, indexed by tid.x within the group.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}
50
60
61
// Compute the per-channel variance of a reduction row from its running sums.
//
// Params:
//   sum    - elementwise sum of the accumulated input values
//   sum_sq - elementwise sum of squares of the same values
//   count  - number of accumulated elements; assumed > 0 (count == 0 would
//            divide by zero — TODO(review): confirm callers guarantee this)
//
// Uses the identity Var(X) = E[X^2] - E[X]^2, then applies Bessel's
// correction (scale by n / (n - 1)) when the `unbiased` push constant is
// set and more than one sample was accumulated.
vec4 calculate_variance(vec4 sum, vec4 sum_sq, int count) {
  vec4 mean = sum / float(count);
  vec4 variance = (sum_sq / float(count)) - (mean * mean);

  if ((pc.unbiased != 0) && (count > 1)) {
    // Subtract in integer space, then convert once; the original
    // float(count - 1.0) mixed int and float inside the constructor,
    // forcing a redundant implicit conversion.
    variance *= float(count) / float(count - 1);
  }

  return variance;
}
71
+
51
72
/*
52
73
* The functions below compute reduction along a single dimension for a tensor.
53
74
* The shader template generalize reduction by abstracting the initial value of
@@ -92,25 +113,48 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
92
113
scan_pos[reduce_dim] = 0 ;
93
114
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
94
115
116
+ #ifdef VARIANCE_MODE
117
+ vec4 sum_sq = VEC4_T(0 );
118
+ int count = 0 ;
119
+ #endif
120
+
95
121
scan_pos[reduce_dim] = tid.x;
96
122
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
97
123
// the reduction row
98
124
for (int i = tid.x; i < tin_sizes[reduce_dim];
99
125
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
100
- accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
126
+ vec4 val = load_texel(tin, scan_pos);
127
+ accum = UPDATE_ACCUM(accum, val);
128
+ #ifdef VARIANCE_MODE
129
+ sum_sq += val * val;
130
+ count += 1 ;
131
+ #endif
101
132
}
102
133
// Write partial output to shared memory and synchronize work group
103
134
shared_vecs[smi] = accum;
135
+ #ifdef VARIANCE_MODE
136
+ shared_sum_sq[smi] = sum_sq;
137
+ shared_count[smi] = count;
138
+ #endif
104
139
barrier();
105
140
106
141
// Since the reduction row is reduced to only one element, only the "main"
107
142
// thread in the group needs aggregate the partial outputs
108
143
if (tid.x == 0 ) {
109
144
// Iterate over the partial outputs to obtain the overall output
110
145
int group_i = tid.y * NWORKERS;
111
- accum = shared_vecs[group_i++ ];
112
- for (int i = 1 ; i < NWORKERS; i++ , group_i++ ) {
113
- accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
146
+ accum = shared_vecs[group_i];
147
+ #ifdef VARIANCE_MODE
148
+ sum_sq = shared_sum_sq[group_i];
149
+ count = shared_count[group_i];
150
+ #endif
151
+ for (int i = 1 ; i < NWORKERS; i++ ) {
152
+ int idx = tid.y * NWORKERS + i;
153
+ accum = UPDATE_ACCUM(accum, shared_vecs[idx]);
154
+ #ifdef VARIANCE_MODE
155
+ sum_sq += shared_sum_sq[idx];
156
+ count += shared_count[idx];
157
+ #endif
114
158
}
115
159
116
160
// Determine if there are any padding elements in the final texel of the
@@ -121,14 +165,27 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
121
165
const bool is_last_texel =
122
166
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1 );
123
167
168
+ #ifdef VARIANCE_MODE
169
+ vec4 variance = calculate_variance(accum, sum_sq, count);
170
+ #endif
171
+
124
172
// Explicitly set padding elements to 0
125
173
if (is_last_texel && nspill > 0 ) {
126
174
[[unroll]] for (int i = nspill; i < 4 ; i++ ) {
175
+ #ifdef VARIANCE_MODE
176
+ variance[i] = 0 ;
177
+ #else
127
178
accum[i] = 0 ;
179
+ #endif
128
180
}
129
181
}
182
+
130
183
scan_pos[reduce_dim] = tid.x;
184
+ #ifdef VARIANCE_MODE
185
+ write_texel(tout, scan_pos, variance);
186
+ #else
131
187
write_texel(tout, scan_pos, POSTPROCESS(accum));
188
+ #endif
132
189
}
133
190
}
134
191
@@ -153,35 +210,78 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
153
210
scan_pos[reduce_dim] = 0 ;
154
211
vec4 accum = INIT_ACCUM(vec4 (load_texel(tin, scan_pos).x));
155
212
213
+ #ifdef VARIANCE_MODE
214
+ vec4 sum_sq = VEC4_T(0 );
215
+ int count = 0 ;
216
+ #endif
217
+
156
218
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
157
219
// the reduction row
158
220
scan_pos[reduce_dim] = tid.x;
159
221
for (int i = tid.x * 4 ; i < reduce_len;
160
222
i += NWORKERS * 4 , scan_pos[reduce_dim] += NWORKERS) {
161
- accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
223
+ vec4 val = load_texel(tin, scan_pos);
224
+ accum = UPDATE_ACCUM(accum, val);
225
+ #ifdef VARIANCE_MODE
226
+ sum_sq += val * val;
227
+ count += 4 ; // Each texel has 4 elements
228
+ #endif
162
229
}
163
230
// For the last texel in the dim, if there are padding elements then each
164
231
// element of the texel needs to be processed individually such that the
165
232
// padding elements are ignored
166
233
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0 ) {
167
- const vec4 intex = load_texel(tin, scan_pos);
234
+ const vec4 val = load_texel(tin, scan_pos);
168
235
for (int i = 0 ; i < nspill; i++ ) {
169
- accum.x = UPDATE_ACCUM(accum.x, intex[i]);
236
+ accum.x = UPDATE_ACCUM(accum.x, val[i]);
237
+ #ifdef VARIANCE_MODE
238
+ sum_sq.x += val[i] * val[i];
239
+ count += 1 ;
240
+ #endif
170
241
}
171
242
}
172
243
// Write partial output to shared memory and synchronize work group
173
244
shared_vecs[smi] = accum;
245
+ #ifdef VARIANCE_MODE
246
+ shared_sum_sq[smi] = sum_sq;
247
+ shared_count[smi] = count;
248
+ #endif
174
249
barrier();
175
250
176
251
// Since the reduction row is reduced to only one element, only the "main"
177
252
// thread in the group needs aggregate the partial outputs
178
253
if (tid.x == 0 ) {
179
254
// Iterate over the partial maximums to obtain the overall maximum
180
255
int group_i = tid.y * NWORKERS;
181
- accum = shared_vecs[group_i++ ];
256
+ accum = shared_vecs[group_i];
257
+ #ifdef VARIANCE_MODE
258
+ sum_sq = shared_sum_sq[group_i];
259
+ count = shared_count[group_i];
260
+ #endif
182
261
for (int i = 1 ; i < NWORKERS; i++ , group_i++ ) {
183
- accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
262
+ int idx = tid.y * NWORKERS + i;
263
+ accum = UPDATE_ACCUM(accum, shared_vecs[idx]);
264
+ #ifdef VARIANCE_MODE
265
+ sum_sq += shared_sum_sq[idx];
266
+ count += shared_count[idx];
267
+ #endif
184
268
}
269
+
270
+ #ifdef VARIANCE_MODE
271
+ float total_sum = accum.x + accum.y + accum.z + accum.w;
272
+ float total_sum_sq = sum_sq.x + sum_sq.y + sum_sq.z + sum_sq.w;
273
+ int total_count = count;
274
+
275
+ float mean = total_sum / float (total_count);
276
+ float variance = (total_sum_sq / float (total_count)) - (mean * mean);
277
+
278
+ if ((pc.unbiased != 0 ) && (total_count > 1 )) {
279
+ variance = variance * (float (total_count) / float (total_count - 1.0 ));
280
+ }
281
+
282
+ scan_pos[reduce_dim] = tid.x;
283
+ write_texel(tout, scan_pos, vec4 (variance, 0 , 0 , 0 ));
284
+ #else
185
285
// Each element of the texel is itself a partial maximum; iterate over the
186
286
// texel to find the actual maximum
187
287
float accum_final = accum.x;
@@ -191,6 +291,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
191
291
192
292
scan_pos[reduce_dim] = tid.x;
193
293
write_texel(tout, scan_pos, POSTPROCESS(vec4 (accum_final, 0 , 0 , 0 )));
294
+ #endif
194
295
}
195
296
}
196
297
0 commit comments