Skip to content

Commit 3dcf9b7

Browse files
[ET-VK] Removed shared memory usage and simplified conv2d dw op shader to improve performance. (#11270)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11178 by @trivedivivek
^ Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/102/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/102/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/101/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/102/orig
@diff-train-skip-merge
---------
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 91fbd59 commit 3dcf9b7

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ layout(push_constant) uniform restrict Block {
4747

4848
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4949

50-
// For performance improvement, reduce register usage by caching positions in shared memory.
51-
// Offset index by 1 every 16 points to avoid bank access conflict.
52-
#define offset_pos_index(index) (index + ((index) >> 4))
53-
shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)];
54-
5550
/*
5651
* Computes a depthwise convolution. Each shader invocation calculates the
5752
* output at a single output location.
@@ -77,8 +72,6 @@ void main() {
7772
return;
7873
}
7974

80-
pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;
81-
8275
// Compute the index of the top-left element of the overlay region. Negative
8376
// indices indicate that the top-left element is in a region added by padding.
8477
const ivec2 ipos = pos.xy * stride - padding;
@@ -89,13 +82,10 @@ void main() {
8982
const ivec2 end = ipos + overlay_region.xy;
9083

9184
// sum outputs
92-
VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X];
85+
VEC4_T sum[BATCH_SIZE_Y * BATCH_SIZE_X];
9386

94-
sum[0][0] = texelFetch(t_bias, ivec2(pos.z, 0), 0);
95-
for (int y = 0; y < BATCH_SIZE_Y; y++) {
96-
for (int x = 0; x < BATCH_SIZE_X; x++) {
97-
sum[y][x] = sum[0][0];
98-
}
87+
for (int i = 0; i < BATCH_SIZE_Y * BATCH_SIZE_X; i++) {
88+
sum[i] = VEC4_T(0);
9989
}
10090

10191
// array to store input texels
@@ -115,7 +105,7 @@ void main() {
115105
if (i > 0) {
116106
for (int j = 0; j < TILE_SIZE; j++) {
117107
for (int s = 0; s < BATCH_SIZE_X; s++) {
118-
sum[1][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1][s]);
108+
sum[BATCH_SIZE_X + s] = fma(in_texels[j + s], prev_kernel_line[j], sum[BATCH_SIZE_X + s]);
119109
}
120110
}
121111
}
@@ -125,19 +115,19 @@ void main() {
125115
for (int j = 0; j < TILE_SIZE; j++, kx++) {
126116
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
127117
for (int s = 0; s < BATCH_SIZE_X; s++) {
128-
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
118+
sum[s] = fma(in_texels[j + s], prev_kernel_line[j], sum[s]);
129119
}
130120
}
131121
}
132122
}
133123

134-
const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
124+
const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0);
135125
for (int y = 0; y < BATCH_SIZE_Y; y++) {
136126
for (int x = 0; x < BATCH_SIZE_X; x++) {
137-
if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits.xyz))) {
138-
continue;
127+
const ivec3 out_pos = ivec3(pos.x + x, pos.y + y, pos.z);
128+
if (all(lessThan(out_pos.xy, out_limits.xy))) {
129+
imageStore(t_out, out_pos, op(sum[y * BATCH_SIZE_X + x] + bias, out_min, out_max));
139130
}
140-
imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
141131
}
142132
}
143133
}

0 commit comments

Comments
 (0)