
Commit db712de

Author: morelos
[ET-VK][Ops] aten.var.dim in reduce
Incorporated variance logic into the reduce shaders by adding the additional accumulation logic.

Differential Revision: [D75247432](https://our.internmc.facebook.com/intern/diff/D75247432/)
ghstack-source-id: 286717611
Pull Request resolved: #11198
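In both the buffer and texture shaders, variance is computed in a single pass by carrying a sum-of-squares accumulator and an element count alongside the existing sum, then combining them once per output:

    mean = Σx / N
    var  = Σx²/N − mean²          // biased estimate
    var *= N / (N − 1)            // when pc.unbiased != 0 (Bessel's correction)

An empty reduction (count of zero) writes NaN (0.0/0.0) in the buffer shader to match PyTorch behavior.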
1 parent 528b876 commit db712de

10 files changed (+369, -531 lines)

backends/vulkan/runtime/graph/ops/glsl/var_buffer.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/reduce_buffer.glsl

Lines changed: 59 additions & 16 deletions
@@ -31,15 +31,24 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int reduce_dim = 0;

+$if VARIANCE_MODE:
+  #define VARIANCE_MODE
+
#define NWORKERS 4
#define MAX_THREADS 16

-shared T shared_sum[NWORKERS];
+shared T shared_accum[NWORKERS];
+#ifdef VARIANCE_MODE
shared T shared_sum_sq[NWORKERS];
shared int shared_count[NWORKERS];
+#endif

#include "indexing_utils.h"

+#define INIT_ACCUM(first_val) ${INIT_ACCUM}
+#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
+#define POSTPROCESS(accum) ${POSTPROCESS}
+
void main() {
  const ivec4 out_idx = ivec4(
    gl_GlobalInvocationID.x,

@@ -49,9 +58,11 @@ void main() {

  const uint tid = gl_LocalInvocationID[reduce_dim];

-  shared_sum[tid] = T(0);
+  shared_accum[tid] = T(0);
+#ifdef VARIANCE_MODE
  shared_sum_sq[tid] = T(0);
  shared_count[tid] = 0;
+#endif
  barrier();

  const int R = in_sizes[reduce_dim];

@@ -65,9 +76,25 @@ void main() {
  uint len = q + (tid < rem ? 1u : 0u);
  uint base = tid * q + min(tid, rem);

-  T sum = T(0);
+  // Get the first value for initializing the accumulator if needed
+  T first_val = T(0);
+  if (R > 0) {
+    ivec4 first_idx = out_idx;
+    first_idx[reduce_dim] = 0;
+
+    if (reduce_dim == 2) {
+      first_idx[reduce_dim + 1] = 0;
+    }
+
+    first_val = in_buf[tidx_to_bufi(first_idx, in_strides)];
+  }
+
+  // Initialize accumulator
+  T accum = INIT_ACCUM(first_val);
+#ifdef VARIANCE_MODE
  T sum_sq = T(0);
  int count = 0;
+#endif

  ivec4 in_idx = out_idx;
  for (uint off = 0u; off < len; ++off) {

@@ -83,39 +110,55 @@ void main() {

    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];

-    sum += v;
+    accum = UPDATE_ACCUM(accum, v);
+
+#ifdef VARIANCE_MODE
    sum_sq += v * v;
    count += 1;
+#endif
  }

-  shared_sum[tid] = sum;
+  shared_accum[tid] = accum;
+#ifdef VARIANCE_MODE
  shared_sum_sq[tid] = sum_sq;
  shared_count[tid] = count;
+#endif
  barrier();

  if (tid == 0u) {
-    T tot_sum = T(0);
-    T tot_sum_sq = T(0);
-    int tot_count = 0;
+    T result = shared_accum[0];
+
+#ifdef VARIANCE_MODE
+    T tot_sum = shared_accum[0];
+    T tot_sum_sq = shared_sum_sq[0];
+    int tot_count = shared_count[0];
+#endif

-    for (uint i = 0; i < N; ++i) {
-      tot_sum += shared_sum[i];
+    for (uint i = 1; i < N; ++i) {
+#ifdef VARIANCE_MODE
+      tot_sum += shared_accum[i];
      tot_sum_sq += shared_sum_sq[i];
      tot_count += shared_count[i];
+#else
+      result = UPDATE_ACCUM(result, shared_accum[i]);
+#endif
    }

-    T var;
+#ifdef VARIANCE_MODE
    if (tot_count > 0) {
      T mean = tot_sum / T(tot_count);
-      var = (tot_sum_sq / T(tot_count)) - (mean * mean);
+      result = (tot_sum_sq / T(tot_count)) - (mean * mean);
      if (pc.unbiased != 0 && tot_count > 1) {
-        var *= T(tot_count) / T(tot_count - 1);
+        result *= T(tot_count) / T(tot_count - 1);
      }
-    } else{
+    } else {
      // NaN to match PyTorch behavior
-      var = T(0.0/0.0);
+      result = T(0.0/0.0);
    }
+#else
+    result = POSTPROCESS(result);
+#endif

-    out_buf[tidx_to_bufi(out_idx, out_strides)] = var;
+    out_buf[tidx_to_bufi(out_idx, out_strides)] = result;
  }
}
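
Because sums, sums of squares, and counts all merge by plain addition, each worker's partial results can be combined by the main thread at the end:

    Σx = Σ_w (Σx)_w        Σx² = Σ_w (Σx²)_w        N = Σ_w N_w

In the non-variance path the same final loop instead folds the workers' partial accumulators together with UPDATE_ACCUM, which also covers non-additive reductions such as amin/amax; seeding the accumulator with the row's first element via INIT_ACCUM(first_val) lets those variants start from a real value rather than a sentinel.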
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+reduce_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+    INIT_ACCUM: T(0)
+    UPDATE_ACCUM: accum + new_val
+    POSTPROCESS: accum
+    VARIANCE_MODE: false
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: sum_buffer
+    - NAME: mean_buffer
+      POSTPROCESS: (accum / T(in_sizes[reduce_dim]))
+    - NAME: amax_buffer
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: max(accum, new_val)
+      POSTPROCESS: accum
+    - NAME: amin_buffer
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: min(accum, new_val)
+      POSTPROCESS: accum
+    - NAME: var_buffer
+      VARIANCE_MODE: true
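
These YAML entries are substituted into the ${INIT_ACCUM}, ${UPDATE_ACCUM} and ${POSTPROCESS} placeholders of reduce_buffer.glsl at shader-codegen time. A minimal sketch of what the mean_buffer variant's macros are expected to expand to (illustrative, not the actual generated output):

    // mean_buffer: template defaults plus the POSTPROCESS override above
    #define INIT_ACCUM(first_val) T(0)
    #define UPDATE_ACCUM(accum, new_val) accum + new_val
    #define POSTPROCESS(accum) (accum / T(in_sizes[reduce_dim]))

var_buffer only flips VARIANCE_MODE, which switches the shader onto the sum / sum-of-squares / count path instead of the macro-driven one.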

backends/vulkan/runtime/graph/ops/glsl/reduce.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/reduce_texture3d.glsl

Lines changed: 111 additions & 10 deletions
@@ -23,12 +23,19 @@ ${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

+layout(push_constant) uniform PushConstants {
+  int unbiased;
+} pc;
+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

+$if VARIANCE_MODE:
+  #define VARIANCE_MODE
+
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.

@@ -39,15 +46,29 @@ layout(constant_id = 5) const int group_dim = 1;
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16

-
shared vec4 shared_vecs[MAX_NTHREADS];
+// Second accumulator for variance mode - used for sum of values, prev
+// accumulator is used for sum of squares
+shared vec4 shared_sum_sq[MAX_NTHREADS];
+shared int shared_count[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
  return tid.x + tid.y * NWORKERS;
}

+vec4 calculate_variance(vec4 sum, vec4 sum_sq, int count) {
+  vec4 mean = sum / float(count);
+  vec4 variance = (sum_sq / float(count)) - (mean * mean);
+
+  if ((pc.unbiased != 0) && (count > 1)) {
+    variance = variance * (float(count) / float(count - 1.0));
+  }
+
+  return variance;
+}
+
/*
 * The functions below compute reduction along a single dimension for a tensor.
 * The shader template generalize reduction by abstracting the initial value of

@@ -92,25 +113,48 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  scan_pos[reduce_dim] = 0;
  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

+#ifdef VARIANCE_MODE
+  vec4 sum_sq = VEC4_T(0);
+  int count = 0;
+#endif
+
  scan_pos[reduce_dim] = tid.x;
  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
  // the reduction row
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
-    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+    vec4 val = load_texel(tin, scan_pos);
+    accum = UPDATE_ACCUM(accum, val);
+#ifdef VARIANCE_MODE
+    sum_sq += val * val;
+    count += 1;
+#endif
  }
  // Write partial output to shared memory and synchronize work group
  shared_vecs[smi] = accum;
+#ifdef VARIANCE_MODE
+  shared_sum_sq[smi] = sum_sq;
+  shared_count[smi] = count;
+#endif
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs aggregate the partial outputs
  if (tid.x == 0) {
    // Iterate over the partial outputs to obtain the overall output
    int group_i = tid.y * NWORKERS;
-    accum = shared_vecs[group_i++];
-    for (int i = 1; i < NWORKERS; i++, group_i++) {
-      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+    accum = shared_vecs[group_i];
+#ifdef VARIANCE_MODE
+    sum_sq = shared_sum_sq[group_i];
+    count = shared_count[group_i];
+#endif
+    for (int i = 1; i < NWORKERS; i++) {
+      int idx = tid.y * NWORKERS + i;
+      accum = UPDATE_ACCUM(accum, shared_vecs[idx]);
+#ifdef VARIANCE_MODE
+      sum_sq += shared_sum_sq[idx];
+      count += shared_count[idx];
+#endif
    }

    // Determine if there are any padding elements in the final texel of the

@@ -121,14 +165,27 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
    const bool is_last_texel =
        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

+#ifdef VARIANCE_MODE
+    vec4 variance = calculate_variance(accum, sum_sq, count);
+#endif
+
    // Explicitly set padding elements to 0
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int i = nspill; i < 4; i++) {
+#ifdef VARIANCE_MODE
+        variance[i] = 0;
+#else
        accum[i] = 0;
+#endif
      }
    }
+
    scan_pos[reduce_dim] = tid.x;
+#ifdef VARIANCE_MODE
+    write_texel(tout, scan_pos, variance);
+#else
    write_texel(tout, scan_pos, POSTPROCESS(accum));
+#endif
  }
}

@@ -153,35 +210,78 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  scan_pos[reduce_dim] = 0;
  vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));

+#ifdef VARIANCE_MODE
+  vec4 sum_sq = VEC4_T(0);
+  int count = 0;
+#endif
+
  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
  // the reduction row
  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
-    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+    vec4 val = load_texel(tin, scan_pos);
+    accum = UPDATE_ACCUM(accum, val);
+#ifdef VARIANCE_MODE
+    sum_sq += val * val;
+    count += 4; // Each texel has 4 elements
+#endif
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
-    const vec4 intex = load_texel(tin, scan_pos);
+    const vec4 val = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; i++) {
-      accum.x = UPDATE_ACCUM(accum.x, intex[i]);
+      accum.x = UPDATE_ACCUM(accum.x, val[i]);
+#ifdef VARIANCE_MODE
+      sum_sq.x += val[i] * val[i];
+      count += 1;
+#endif
    }
  }
  // Write partial output to shared memory and synchronize work group
  shared_vecs[smi] = accum;
+#ifdef VARIANCE_MODE
+  shared_sum_sq[smi] = sum_sq;
+  shared_count[smi] = count;
+#endif
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs aggregate the partial outputs
  if (tid.x == 0) {
    // Iterate over the partial maximums to obtain the overall maximum
    int group_i = tid.y * NWORKERS;
-    accum = shared_vecs[group_i++];
+    accum = shared_vecs[group_i];
+#ifdef VARIANCE_MODE
+    sum_sq = shared_sum_sq[group_i];
+    count = shared_count[group_i];
+#endif
    for (int i = 1; i < NWORKERS; i++, group_i++) {
-      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+      int idx = tid.y * NWORKERS + i;
+      accum = UPDATE_ACCUM(accum, shared_vecs[idx]);
+#ifdef VARIANCE_MODE
+      sum_sq += shared_sum_sq[idx];
+      count += shared_count[idx];
+#endif
    }
+
+#ifdef VARIANCE_MODE
+    float total_sum = accum.x + accum.y + accum.z + accum.w;
+    float total_sum_sq = sum_sq.x + sum_sq.y + sum_sq.z + sum_sq.w;
+    int total_count = count;
+
+    float mean = total_sum / float(total_count);
+    float variance = (total_sum_sq / float(total_count)) - (mean * mean);
+
+    if ((pc.unbiased != 0) && (total_count > 1)) {
+      variance = variance * (float(total_count) / float(total_count - 1.0));
+    }
+
+    scan_pos[reduce_dim] = tid.x;
+    write_texel(tout, scan_pos, vec4(variance, 0, 0, 0));
+#else
    // Each element of the texel is itself a partial maximum; iterate over the
    // texel to find the actual maximum
    float accum_final = accum.x;

@@ -191,6 +291,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {

    scan_pos[reduce_dim] = tid.x;
    write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
+#endif
  }
}