Skip to content

Commit b4f21c4

Browse files
author
morelos
committed
[ET-VK][Ops] aten.var.dim in reduce
Pull Request resolved: #11198. Incorporated the variance logic into the shared reduce shader by adding a VARIANCE_MODE code path alongside the generic accumulator. ghstack-source-id: 286745109. Differential Revision: [D75247432](https://our.internmc.facebook.com/intern/diff/D75247432/)
1 parent 5d70f27 commit b4f21c4

File tree

10 files changed

+363
-536
lines changed

10 files changed

+363
-536
lines changed

backends/vulkan/runtime/graph/ops/glsl/var_buffer.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/reduce_buffer.glsl

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,24 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3131

3232
layout(constant_id = 3) const int reduce_dim = 0;
3333

34+
$if VARIANCE_MODE:
35+
#define VARIANCE_MODE
36+
3437
#define NWORKERS 4
3538
#define MAX_THREADS 16
3639

37-
shared T shared_sum[NWORKERS];
40+
shared T shared_accum[NWORKERS];
41+
#ifdef VARIANCE_MODE
3842
shared T shared_sum_sq[NWORKERS];
3943
shared int shared_count[NWORKERS];
44+
#endif
4045

4146
#include "indexing_utils.h"
4247

48+
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
49+
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
50+
#define POSTPROCESS(accum) ${POSTPROCESS}
51+
4352
void main() {
4453
const ivec4 out_idx = ivec4(
4554
gl_GlobalInvocationID.x,
@@ -49,9 +58,11 @@ void main() {
4958

5059
const uint tid = gl_LocalInvocationID[reduce_dim];
5160

52-
shared_sum[tid] = T(0);
61+
shared_accum[tid] = T(0);
62+
#ifdef VARIANCE_MODE
5363
shared_sum_sq[tid] = T(0);
5464
shared_count[tid] = 0;
65+
#endif
5566
barrier();
5667

5768
const int R = in_sizes[reduce_dim];
@@ -65,9 +76,25 @@ void main() {
6576
uint len = q + (tid < rem ? 1u : 0u);
6677
uint base = tid * q + min(tid, rem);
6778

68-
T sum = T(0);
79+
// Get the first value for initializing the accumulator if needed
80+
T first_val = T(0);
81+
if (R > 0) {
82+
ivec4 first_idx = out_idx;
83+
first_idx[reduce_dim] = 0;
84+
85+
if (reduce_dim == 2) {
86+
first_idx[reduce_dim + 1] = 0;
87+
}
88+
89+
first_val = in_buf[tidx_to_bufi(first_idx, in_strides)];
90+
}
91+
92+
// Initialize accumulator
93+
T accum = INIT_ACCUM(first_val);
94+
#ifdef VARIANCE_MODE
6995
T sum_sq = T(0);
7096
int count = 0;
97+
#endif
7198

7299
ivec4 in_idx = out_idx;
73100
for (uint off = 0u; off < len; ++off) {
@@ -83,39 +110,55 @@ void main() {
83110

84111
T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
85112

86-
sum += v;
113+
accum = UPDATE_ACCUM(accum, v);
114+
115+
#ifdef VARIANCE_MODE
87116
sum_sq += v * v;
88117
count += 1;
118+
#endif
89119
}
90120

91-
shared_sum[tid] = sum;
121+
shared_accum[tid] = accum;
122+
#ifdef VARIANCE_MODE
92123
shared_sum_sq[tid] = sum_sq;
93124
shared_count[tid] = count;
125+
#endif
94126
barrier();
95127

96128
if (tid == 0u) {
97-
T tot_sum = T(0);
98-
T tot_sum_sq = T(0);
99-
int tot_count = 0;
129+
T result = shared_accum[0];
130+
131+
#ifdef VARIANCE_MODE
132+
T tot_sum = shared_accum[0];
133+
T tot_sum_sq = shared_sum_sq[0];
134+
int tot_count = shared_count[0];
135+
#endif
100136

101-
for (uint i = 0; i < N; ++i) {
102-
tot_sum += shared_sum[i];
137+
for (uint i = 1; i < N; ++i) {
138+
#ifdef VARIANCE_MODE
139+
tot_sum += shared_accum[i];
103140
tot_sum_sq += shared_sum_sq[i];
104141
tot_count += shared_count[i];
142+
#else
143+
result = UPDATE_ACCUM(result, shared_accum[i]);
144+
#endif
105145
}
106146

107-
T var;
147+
#ifdef VARIANCE_MODE
108148
if (tot_count > 0) {
109149
T mean = tot_sum / T(tot_count);
110-
var = (tot_sum_sq / T(tot_count)) - (mean * mean);
150+
result = (tot_sum_sq / T(tot_count)) - (mean * mean);
111151
if (pc.unbiased != 0 && tot_count > 1) {
112-
var *= T(tot_count) / T(tot_count - 1);
152+
result *= T(tot_count) / T(tot_count - 1);
113153
}
114-
} else{
154+
} else {
115155
// NaN to match PyTorch behavior
116-
var = T(0.0/0.0);
156+
result = T(0.0/0.0);
117157
}
158+
#else
159+
result = POSTPROCESS(result);
160+
#endif
118161

119-
out_buf[tidx_to_bufi(out_idx, out_strides)] = var;
162+
out_buf[tidx_to_bufi(out_idx, out_strides)] = result;
120163
}
121164
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
reduce_buffer:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: buffer
11+
INIT_ACCUM: T(0)
12+
UPDATE_ACCUM: accum + new_val
13+
POSTPROCESS: accum
14+
VARIANCE_MODE: false
15+
generate_variant_forall:
16+
DTYPE:
17+
- VALUE: half
18+
- VALUE: float
19+
shader_variants:
20+
- NAME: sum_buffer
21+
- NAME: mean_buffer
22+
POSTPROCESS: (accum / T(in_sizes[reduce_dim]))
23+
- NAME: amax_buffer
24+
INIT_ACCUM: first_val
25+
UPDATE_ACCUM: max(accum, new_val)
26+
POSTPROCESS: accum
27+
- NAME: amin_buffer
28+
INIT_ACCUM: first_val
29+
UPDATE_ACCUM: min(accum, new_val)
30+
POSTPROCESS: accum
31+
- NAME: var_buffer
32+
VARIANCE_MODE: true

0 commit comments

Comments
 (0)