Skip to content

Commit 5200778

Browse files
[ET-VK] Storing positions in uint16 to instead of int in conv2d pw shader. (#11191)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11138 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/98/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/98/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/97/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/98/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <[email protected]>
1 parent f2f473b commit 5200778

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88

99
#version 450 core
1010

11+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
12+
1113
#define PRECISION ${PRECISION}
1214

1315
#define VEC4_T ${texel_type(DTYPE)}
1416

15-
#define TILE_SIZE_X ${TILE_SIZE_X}
16-
#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+
#define TILE_SIZE_X uint16_t(${TILE_SIZE_X})
18+
#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y})
1719

1820
#define op(X, A, B) ${OPERATOR}
1921

@@ -63,11 +65,11 @@ void main() {
6365
// +--------+--------+
6466
// | pos[2] | pos[3] |
6567
// +--------+--------+
66-
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
67-
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
68-
for (int x = 0; x < TILE_SIZE_X; ++x) {
69-
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
70-
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
68+
uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
69+
for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
70+
for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
71+
pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x;
72+
pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y;
7173
i++;
7274
}
7375
}

0 commit comments

Comments
 (0)