Skip to content

Commit f2f473b

Browse files
[ET-VK] Creating specialized version of conv2d pw shader for X and Y stride = 1 and padding = 0. (#11190)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11137 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/97/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/97/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/96/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/97/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 3658479 commit f2f473b

File tree

3 files changed

+193
-3
lines changed

3 files changed

+193
-3
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core

#define PRECISION ${PRECISION}

// Texel (4-wide vector) type matching the tensor dtype (e.g. vec4 for float).
#define VEC4_T ${texel_type(DTYPE)}

// Size of the output tile computed by one invocation: TILE_SIZE_X x TILE_SIZE_Y texels.
#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}

// Output activation applied on store; identity (X) or clamp(X, A, B) depending
// on the shader variant selected in the codegen YAML.
#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}

layout(push_constant) uniform restrict Block {
  ivec4 out_limits;  // output texture extents (width, height, depth, unused)
  ivec2 stride;      // not read by this variant — it is only dispatched for stride == 1
  ivec2 padding;     // not read by this variant — it is only dispatched for padding == 0
  int in_group_size; // number of input channels in the convolution group
  int dummy_padding; // NOTE(review): presumably keeps this push-constant layout
                     // identical to the generic conv2d_pw shader — confirm
  float out_min;     // clamp bounds consumed by op() in the *_clamp variant
  float out_max;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Needed for the [[unroll]] attribute used in main().
#extension GL_EXT_control_flow_attributes : require
/*
 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
 * output tile for pointwise convolution is more efficient because the kernel
 * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
 *
 * Specialized variant: selected only when stride == (1, 1) and
 * padding == (0, 0), so every output (x, y) reads the input at the same
 * (x, y) — no stride/padding arithmetic is needed in the inner loop.
 */
void main() {
  // Number of output tiles along x and y: ceil-divide the output extent by the
  // tile size so partial tiles at the right/bottom border are still covered.
  // (Fix: the previous expression `out_limits + (TILE_SIZE - 1) * TILE_SIZE`
  // omitted the division, producing an inflated tile count that broke the
  // flat-index decomposition below.)
  const int out_limits_scaled[2] = {
      (out_limits.x + TILE_SIZE_X - 1) / TILE_SIZE_X,
      (out_limits.y + TILE_SIZE_Y - 1) / TILE_SIZE_Y};

  // Decompose the flat x invocation id into (tile_x, tile_y); the y invocation
  // id selects the output-channel texel.
  const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
  const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};

  // If the top left position is out of bounds, then this invocation will have
  // no work to do.
  if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
    return;
  }

  // Output position for TILE_SIZE = 2
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  // Interleaved (x, y) texel coordinates of every output position in the
  // tile. Because stride == 1 and padding == 0, these are also the input
  // coordinates used to fetch from t_in.
  int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
  for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
    for (int x = 0; x < TILE_SIZE_X; ++x) {
      pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
      pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
      i++;
    }
  }

  // Final output array where each element is a tensor value.
  // Tuple of consecutive 4 elements represents a single output texel.
  float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];

  // One bias texel per output-channel texel.
  const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);

  // Initialize the output array with the bias value
  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
    sum[i] = bias.x;
    sum[i + 1] = bias.y;
    sum[i + 2] = bias.z;
    sum[i + 3] = bias.w;
  }

  int z4 = 0;
  // Since the kernel is 1x1, we only have to loop over the depth dimension.
  // z walks input channels 4 at a time; z4 is the corresponding input texel
  // index along the texture's z axis.
  for (int z = 0; z < in_group_size; z += 4, ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
    // the z-axis.
    float kernel_values[4 * 4]; // 4 channels, 4 elements per channel

    // Load kernel values from texels to array
    [[unroll]] for (int i = 0; i < 4; ++i) {
      const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
      kernel_values[i * 4 + 0] = k_tex.x;
      kernel_values[i * 4 + 1] = k_tex.y;
      kernel_values[i * 4 + 2] = k_tex.z;
      kernel_values[i * 4 + 3] = k_tex.w;
    }

    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
      const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0);
      // Load the input texel into an array
      float tex_values[4];
      tex_values[0] = in_tex.x;
      tex_values[1] = in_tex.y;
      tex_values[2] = in_tex.z;
      tex_values[3] = in_tex.w;

      // For 2x2 tile size algorithm works as follows.
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from t_kernel are shown:
      //
      //   in_tex                 t_kernel
      //    -x->                   ---x--->
      //   +---+              +----+----+----+----+
      // ^ | w |           ^  | D0 | D1 | D2 | D3 |
      // | +---+           |  +----+----+----+----+
      // | | z |           |  | C0 | C1 | C2 | C3 |
      // z +---+           z  +----+----+----+----+
      // | | y |           |  | B0 | B1 | B2 | B3 |
      // | +---+           |  +----+----+----+----+
      //   | x |              | A0 | A1 | A2 | A3 |
      //   +---+              +----+----+----+----+
      //
      // In the t_kernel graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. To calculate the output texel, the following
      // calculation is performed:
      //
      //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
      //  | x | | D0 |     | y | | D1 |     | z | | D2 |     | w | | D3 |
      //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
      //  | x | | C0 |     | y | | C1 |     | z | | C2 |     | w | | C3 |
      //  +---+X+----+  +  +---+X+----+  +  +---+X+----+  +  +---+X+----+
      //  | x | | B0 |     | y | | B1 |     | z | | B2 |     | w | | B3 |
      //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
      //  | x | | A0 |     | y | | A1 |     | z | | A2 |     | w | | A3 |
      //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      for (int j = 0; j < 4; ++j) {
        sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
        sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
        sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
        sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
      }
    }
  }

  // Write back every tile position that falls inside the output extents,
  // applying the activation op() (identity, or clamp to [out_min, out_max]).
  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
    const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
    if (all(lessThan(pos_l, out_limits.xyz))) {
      imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
    }
  }
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen spec for the stride==1 / padding==0 specialization of the pointwise
# conv2d compute shader (conv2d_pw_s1p0.glsl).
conv2d_pw_s1p0:
  parameter_names_with_default_values:
    # Output activation substituted into op(X, A, B); X is the identity.
    OPERATOR: X
    NDIM: 3
    DTYPE: float
    # Output tile computed per invocation: TILE_SIZE_X x TILE_SIZE_Y texels.
    TILE_SIZE_X: 1
    TILE_SIZE_Y: 4
  # Emit one shader variant per listed DTYPE value.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: conv2d_pw_s1p0
    # Variant that clamps the output to [A, B] (used when the op fuses an
    # output activation with min/max bounds).
    - NAME: conv2d_pw_s1p0_clamp
      OPERATOR: clamp(X, A, B)

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ vkapi::ShaderInfo get_conv2d_shader(
127127
const Conv2dMethod method,
128128
const ValueRef weight,
129129
const bool clamp_out = false,
130-
const bool stride_equals_dilation = false) {
130+
const bool stride_equals_dilation = false,
131+
const bool stride_1_padding_0 = false) {
131132
std::string kernel_name;
132133
kernel_name.reserve(kShaderNameReserve);
133134
switch (method) {
@@ -150,7 +151,7 @@ vkapi::ShaderInfo get_conv2d_shader(
150151
if (prepack_weights) {
151152
kernel_name = "conv2d";
152153
} else {
153-
kernel_name = "conv2d_pw";
154+
kernel_name = stride_1_padding_0 ? "conv2d_pw_s1p0" : "conv2d_pw";
154155
}
155156
break;
156157
case Conv2dMethod::SlidingWindow:
@@ -382,6 +383,10 @@ void add_conv2d_node(
382383
(kernel_params.stride[0] == kernel_params.dilation[0] &&
383384
kernel_params.stride[1] == kernel_params.dilation[1]);
384385

386+
const bool stride_1_padding_0 =
387+
(kernel_params.stride[0] == 1 && kernel_params.stride[1] == 1 &&
388+
kernel_params.padding[0] == 0 && kernel_params.padding[1] == 0);
389+
385390
OutputParams out_params = {out_min_val, out_max_val};
386391

387392
check_conv2d_params(kernel_params, transposed_val);
@@ -393,7 +398,8 @@ void add_conv2d_node(
393398
method,
394399
weight_data,
395400
clamp_out,
396-
stride_equals_dilation);
401+
stride_equals_dilation,
402+
stride_1_padding_0);
397403

398404
utils::uvec3 wg_size = create_conv2d_global_wg_size(
399405
graph, method, out, weight_data, stride_equals_dilation);

0 commit comments

Comments
 (0)