Skip to content

Commit 11f8f4a

Browse files
[ET-VK] Tuning local workgroup size calculation for conv2d pw to improve performance. (pytorch#11188)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: pytorch#11135 by @trivedivivek (^ please use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/95/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/95/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/94/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/95/orig
@diff-train-skip-merge
---------
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent a1f8373 commit 11f8f4a

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -404,6 +404,21 @@ void add_conv2d_node(
404404
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
405405
}
406406

407+
utils::uvec3 local_wg_size;
408+
if (method == Conv2dMethod::Pointwise) {
409+
uint32_t local_wg_size_y = 1;
410+
if (wg_size[1] % 8 == 0) {
411+
local_wg_size_y = 8;
412+
} else if (wg_size[1] % 4 == 0) {
413+
local_wg_size_y = 4;
414+
} else if (wg_size[1] % 2 == 0) {
415+
local_wg_size_y = 2;
416+
}
417+
local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1};
418+
} else {
419+
local_wg_size = graph.create_local_wg_size(wg_size);
420+
}
421+
407422
vkapi::ParamsBindList param_buffers;
408423
std::vector<PushConstantDataInfo> push_constants;
409424
if (method == Conv2dMethod::Pointwise) {
@@ -464,7 +479,7 @@ void add_conv2d_node(
464479
graph,
465480
shader,
466481
wg_size,
467-
graph.create_local_wg_size(wg_size),
482+
local_wg_size,
468483
// Inputs and Outputs
469484
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
470485
// Shader params buffers

0 commit comments

Comments
 (0)