@@ -31,15 +31,24 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
31
31
32
32
// Specialization constant (constant_id = 3, default 0) naming the dimension being reduced.
// main() uses it to pick the component of gl_LocalInvocationID that serves as the
// worker id and to look up the reduction length in in_sizes.
layout (constant_id = 3 ) const int reduce_dim = 0 ;
33
33
34
// The `$if` line is a template directive, not GLSL.
// NOTE(review): presumably resolved by the shader code generator before compilation — confirm.
// When VARIANCE_MODE is defined, the `#ifdef VARIANCE_MODE` sections in this shader are
// compiled in: the sum-of-squares / element-count bookkeeping and the variance formula.
+ $if VARIANCE_MODE:
35
+ #define VARIANCE_MODE
36
+
34
37
// Length of the `shared` partial-result arrays declared in this shader.
#define NWORKERS 4
35
38
// NOTE(review): not referenced in the visible part of this shader — confirm where it is used.
#define MAX_THREADS 16
36
39
37
- shared T shared_sum[NWORKERS];
40
// Partial accumulator slots, indexed by tid in main(): each invocation stores its own
// partial result here, and after a barrier() the tid == 0 invocation combines the slots.
// NOTE(review): assumes tid is always < NWORKERS — confirm against the dispatch local size.
+ shared T shared_accum[NWORKERS];
41
// Variance mode only: per-invocation sum of squared inputs and count of elements read,
// combined by tid == 0 to derive the mean and variance.
+ #ifdef VARIANCE_MODE
38
42
shared T shared_sum_sq[NWORKERS];
39
43
shared int shared_count[NWORKERS];
44
+ #endif
40
45
41
46
#include "indexing_utils.h"
42
47
48
// Reduction-specific hooks. Each macro body is a ${...} template placeholder, so the
// actual expressions are not visible here.
// NOTE(review): presumably substituted per shader variant by the code generator — confirm.
//
// INIT_ACCUM(first_val): starting value of each invocation's accumulator. main() passes
//   the input element at index 0 along reduce_dim (or T(0) when the reduction is empty).
+ #define INIT_ACCUM(first_val) ${INIT_ACCUM}
49
// UPDATE_ACCUM(accum, new_val): folds one value into an accumulator. main() applies it
//   to every input element and, outside variance mode, also to merge the per-invocation
//   partial results. In variance mode the partial results are merged with `+=` instead.
+ #define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
50
// POSTPROCESS(accum): final transform applied to the combined result before it is
//   written to out_buf. It is only applied when VARIANCE_MODE is not defined.
+ #define POSTPROCESS(accum) ${POSTPROCESS}
51
+
43
52
void main() {
44
53
const ivec4 out_idx = ivec4 (
45
54
gl_GlobalInvocationID.x,
@@ -49,9 +58,11 @@ void main() {
49
58
50
59
const uint tid = gl_LocalInvocationID[reduce_dim];
51
60
52
- shared_sum[tid] = T(0 );
61
+ shared_accum[tid] = T(0 );
62
+ #ifdef VARIANCE_MODE
53
63
shared_sum_sq[tid] = T(0 );
54
64
shared_count[tid] = 0 ;
65
+ #endif
55
66
barrier();
56
67
57
68
const int R = in_sizes[reduce_dim];
@@ -65,9 +76,25 @@ void main() {
65
76
uint len = q + (tid < rem ? 1u : 0u);
66
77
uint base = tid * q + min (tid, rem);
67
78
68
- T sum = T(0 );
79
+ // Get the first value for initializing the accumulator if needed
80
+ T first_val = T(0 );
81
+ if (R > 0 ) {
82
+ ivec4 first_idx = out_idx;
83
+ first_idx[reduce_dim] = 0 ;
84
+
85
+ if (reduce_dim == 2 ) {
86
+ first_idx[reduce_dim + 1 ] = 0 ;
87
+ }
88
+
89
+ first_val = in_buf[tidx_to_bufi(first_idx, in_strides)];
90
+ }
91
+
92
+ // Initialize accumulator
93
+ T accum = INIT_ACCUM(first_val);
94
+ #ifdef VARIANCE_MODE
69
95
T sum_sq = T(0 );
70
96
int count = 0 ;
97
+ #endif
71
98
72
99
ivec4 in_idx = out_idx;
73
100
for (uint off = 0u; off < len; ++ off) {
@@ -83,39 +110,55 @@ void main() {
83
110
84
111
T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
85
112
86
- sum += v;
113
+ accum = UPDATE_ACCUM(accum, v);
114
+
115
+ #ifdef VARIANCE_MODE
87
116
sum_sq += v * v;
88
117
count += 1 ;
118
+ #endif
89
119
}
90
120
91
- shared_sum[tid] = sum;
121
+ shared_accum[tid] = accum;
122
+ #ifdef VARIANCE_MODE
92
123
shared_sum_sq[tid] = sum_sq;
93
124
shared_count[tid] = count;
125
+ #endif
94
126
barrier();
95
127
96
128
if (tid == 0u) {
97
- T tot_sum = T(0 );
98
- T tot_sum_sq = T(0 );
99
- int tot_count = 0 ;
129
+ T result = shared_accum[0 ];
130
+
131
+ #ifdef VARIANCE_MODE
132
+ T tot_sum = shared_accum[0 ];
133
+ T tot_sum_sq = shared_sum_sq[0 ];
134
+ int tot_count = shared_count[0 ];
135
+ #endif
100
136
101
- for (uint i = 0 ; i < N; ++ i) {
102
- tot_sum += shared_sum[i];
137
+ for (uint i = 1 ; i < N; ++ i) {
138
+ #ifdef VARIANCE_MODE
139
+ tot_sum += shared_accum[i];
103
140
tot_sum_sq += shared_sum_sq[i];
104
141
tot_count += shared_count[i];
142
+ #else
143
+ result = UPDATE_ACCUM(result, shared_accum[i]);
144
+ #endif
105
145
}
106
146
107
- T var;
147
+ #ifdef VARIANCE_MODE
108
148
if (tot_count > 0 ) {
109
149
T mean = tot_sum / T(tot_count);
110
- var = (tot_sum_sq / T(tot_count)) - (mean * mean);
150
+ result = (tot_sum_sq / T(tot_count)) - (mean * mean);
111
151
if (pc.unbiased != 0 && tot_count > 1 ) {
112
- var *= T(tot_count) / T(tot_count - 1 );
152
+ result *= T(tot_count) / T(tot_count - 1 );
113
153
}
114
- } else {
154
+ } else {
115
155
// NaN to match PyTorch behavior
116
- var = T(0.0 / 0.0 );
156
+ result = T(0.0 / 0.0 );
117
157
}
158
+ #else
159
+ result = POSTPROCESS(result);
160
+ #endif
118
161
119
- out_buf[tidx_to_bufi(out_idx, out_strides)] = var ;
162
+ out_buf[tidx_to_bufi(out_idx, out_strides)] = result ;
120
163
}
121
164
}
0 commit comments