Skip to content

Commit 932ade9

Browse files
andreanicastro authored and facebook-github-bot committed
Where layer (#11181)
Summary: The `where` layer was missing and this diff adds it along with the bool tensor support. Reviewed By: SS-JIA Differential Revision: D74175287
1 parent d8ac866 commit 932ade9

File tree

13 files changed

+301
-8
lines changed

13 files changed

+301
-8
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
"uint": "uimage3D",
6363
"int8": "iimage3D",
6464
"uint8": "uimage3D",
65+
"bool": "uimage3D",
6566
},
6667
2: {
6768
"float": "image2D",
@@ -70,6 +71,7 @@
7071
"uint": "uimage2D",
7172
"int8": "iimage2D",
7273
"uint8": "uimage2D",
74+
"bool": "uimage2D",
7375
},
7476
},
7577
"SAMPLER_T": {
@@ -80,6 +82,7 @@
8082
"uint": "usampler3D",
8183
"int8": "isampler3D",
8284
"uint8": "usampler3D",
85+
"bool": "usampler3D",
8386
},
8487
2: {
8588
"float": "sampler2D",
@@ -88,6 +91,7 @@
8891
"uint": "usampler2D",
8992
"int8": "isampler2D",
9093
"uint8": "usampler2D",
94+
"bool": "usampler2D",
9195
},
9296
},
9397
"IMAGE_FORMAT": {
@@ -97,6 +101,7 @@
97101
"uint": "rgba32ui",
98102
"int8": "rgba8i",
99103
"uint8": "rgba8ui",
104+
"bool": "rgba8ui",
100105
},
101106
}
102107

@@ -115,7 +120,8 @@ def buffer_scalar_type(dtype: str) -> str:
115120
return "float16_t"
116121
elif dtype[-1] == "8":
117122
return dtype + "_t"
118-
123+
elif dtype == "bool":
124+
return "uint8_t"
119125
return dtype
120126

121127

@@ -135,17 +141,19 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
135141
return f"i8vec{n}"
136142
elif dtype == "uint8":
137143
return f"u8vec{n}"
144+
elif dtype == "bool":
145+
return f"u8vec{n}"
138146

139147
raise AssertionError(f"Invalid dtype: {dtype}")
140148

141149

142150
def texel_type(dtype: str) -> str:
143151
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
144-
if image_format[-1] == "f":
152+
if image_format[-1:] == "f":
145153
return "vec4"
146-
elif image_format[-2] == "ui":
154+
elif image_format[-2:] == "ui":
147155
return "uvec4"
148-
elif image_format[-1] == "i":
156+
elif image_format[-1:] == "i":
149157
return "ivec4"
150158
raise AssertionError(f"Invalid image format: {image_format}")
151159

@@ -360,7 +368,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
360368
elif dtype == "int16" or dtype == "uint16":
361369
nbit = "16bit"
362370
glsl_type = "int16"
363-
elif dtype == "int8" or dtype == "uint8":
371+
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
364372
nbit = "8bit"
365373
glsl_type = "int8"
366374

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ buffer_to_nchw:
1414
- VALUE: float
1515
- VALUE: int
1616
- VALUE: int8
17+
- VALUE: uint8
1718
shader_variants:
1819
- NAME: buffer_to_nchw

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ image_to_nchw:
1515
- VALUE: float
1616
- VALUE: int
1717
- VALUE: int8
18+
- VALUE: uint8
1819
shader_variants:
1920
- NAME: image_to_nchw_texture3d
2021
- NAME: image_to_nchw_texture2d

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ nchw_to_buffer:
1414
- VALUE: float
1515
- VALUE: int
1616
- VALUE: int8
17+
- VALUE: uint8
1718
shader_variants:
1819
- NAME: nchw_to_buffer

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ nchw_to_image:
1515
- VALUE: float
1616
- VALUE: int
1717
- VALUE: int8
18+
- VALUE: uint8
1819
shader_variants:
1920
- NAME: nchw_to_image_texture3d
2021
- NAME: nchw_to_image_texture2d
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// where.glsl
2+
3+
/*
4+
* Copyright (c) Meta Platforms, Inc. and affiliates.
5+
* All rights reserved.
6+
*
7+
* This source code is licensed under the BSD-style license found in the
8+
* LICENSE file in the root directory of this source tree.
9+
*/
10+
11+
12+
#version 450 core
13+
14+
#define PRECISION ${PRECISION}
15+
16+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
17+
#define T ${buffer_scalar_type(DTYPE)}
18+
#define COND_T ${buffer_scalar_type("bool")}
19+
20+
${define_active_storage_type(STORAGE)}
21+
${define_required_extensions(DTYPE)}
22+
${define_required_extensions("bool")}
23+
24+
layout(std430) buffer;
25+
26+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
27+
${layout_declare_tensor(B, "r", "t_condition", "bool", STORAGE)}
28+
${layout_declare_tensor(B, "r", "t_self", DTYPE, STORAGE)}
29+
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
30+
31+
32+
#include "indexing_utils.h"
33+
34+
$if STORAGE == "buffer":
35+
${layout_declare_ubo(B, "int", "out_numl")}
36+
${layout_declare_ubo(B, "ivec4", "out_strides")}
37+
${layout_declare_ubo(B, "ivec4", "cond_strides")}
38+
${layout_declare_ubo(B, "ivec4", "self_strides")}
39+
${layout_declare_ubo(B, "ivec4", "other_strides")}
40+
41+
${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
42+
${layout_declare_spec_const(C, "int", "cond_packed_dim", "DEFAULT_LAYOUT")}
43+
${layout_declare_spec_const(C, "int", "self_packed_dim", "DEFAULT_LAYOUT")}
44+
${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
45+
$else:
46+
${layout_declare_ubo(B, "ivec3", "out_limits")}
47+
48+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
49+
50+
#ifdef USING_BUFFER
51+
52+
/*
 * Buffer-storage variant of `where`.
 *
 * Each invocation handles the output element whose linear buffer index is
 * gl_GlobalInvocationID.x. It writes the matching element of t_self when the
 * matching element of t_condition is greater than zero, and the matching
 * element of t_other otherwise.
 */
void main() {
  const int out_bufi = int(gl_GlobalInvocationID.x);
  // Invocations beyond the element count of the output have nothing to write.
  if (out_bufi >= out_numl) {
    return;
  }

  // Decode the output's tensor index once and reuse it for every input. All
  // three inputs are read at the same tensor index as the output, each through
  // its own strides.
  // NOTE(review): the previous version re-derived out_bufi via
  // tidx_to_bufi(bufi_to_tidx(out_bufi, ...), out_strides) and then decoded it
  // again for each input; this assumes that round trip returns the same index
  // for an in-range out_bufi -- confirm against indexing_utils.h.
  const ivec4 tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);

  const int cond_bufi = tidx_to_bufi(tidx, cond_strides);
  const int self_bufi = tidx_to_bufi(tidx, self_strides);
  const int other_bufi = tidx_to_bufi(tidx, other_strides);

  const COND_T cond = t_condition[cond_bufi];
  const T v_self = t_self[self_bufi];
  const T v_other = t_other[other_bufi];

  if (cond > 0) {
    t_out[out_bufi] = v_self;
  } else {
    t_out[out_bufi] = v_other;
  }
}
85+
86+
#else // !USING_BUFFER
87+
88+
/*
 * Texture-storage variant of `where`.
 *
 * Each invocation handles the texel at gl_GlobalInvocationID. For each of the
 * texel's four components it takes the component from t_self when the
 * corresponding condition component is non-zero, and from t_other otherwise,
 * then writes the assembled texel to t_out.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Skip invocations that fall outside the output's limits.
  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  const vec4 cond = load_texel(t_condition, pos);
  const VEC4_T selftex = load_texel(t_self, pos);
  const VEC4_T othertex = load_texel(t_other, pos);

  VEC4_T outtex;

  for (int idx = 0; idx < 4; ++idx) {
    // Treat any non-zero condition value as true, the same test the buffer
    // variant applies (`cond > 0`). Comparing against exactly 1 would send
    // other non-zero encodings of `true` down the `other` branch.
    if (cond[idx] > 0) {
      outtex[idx] = selftex[idx];
    } else {
      outtex[idx] = othertex[idx];
    }
  }
  write_texel(t_out, pos, outtex);
}
111+
#endif // !USING_BUFFER
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
where:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
generate_variant_forall:
5+
STORAGE:
6+
- VALUE: texture3d
7+
- VALUE: buffer
8+
DTYPE:
9+
- VALUE: half
10+
- VALUE: float
11+
shader_variants:
12+
- NAME: where
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Where.cpp
2+
3+
/*
4+
* Copyright (c) Meta Platforms, Inc. and affiliates.
5+
* All rights reserved.
6+
*
7+
* This source code is licensed under the BSD-style license found in the
8+
* LICENSE file in the root directory of this source tree.
9+
*/
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
14+
15+
namespace vkcompute {
16+
17+
/*
 * Resize callback used by the `where` dispatch nodes.
 *
 * Virtually resizes the tensor at args[0].refs[0] (the written output) to the
 * current sizes of the tensor at args[1].refs[0], i.e. the first tensor of the
 * read group (`cond`, given how the nodes in this file list their inputs).
 * `extra_args` is not used.
 */
void resize_where_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr dst = graph->get_tensor(args[0].refs[0]);
  vTensorPtr src = graph->get_tensor(args[1].refs[0]);

  // Take a snapshot of the source sizes before applying them to the output.
  const std::vector<int64_t> new_sizes = src->sizes();
  dst->virtual_resize(new_sizes);
}
28+
29+
/*
 * Appends a dispatch node running the texture variant of the `where` shader.
 *
 * The kernel name is "where" plus the storage-type and dtype suffixes derived
 * from `out`. `out` is bound for writing; `cond`, `self` and `other` are bound
 * for reading, in that order. The workgroup sizes are derived from `out`.
 */
void add_where_texture_node(
    ComputeGraph& graph,
    const ValueRef cond,
    const ValueRef self,
    const ValueRef other,
    const ValueRef out) {
  std::string kernel_name = "where";

  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const utils::uvec3 global_wg_size = graph.create_global_wg_size(out);
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
      // Parameter buffers
      // The shader declares this UBO as `out_limits` and uses it to bound
      // writes to t_out, so bind the limits of the output tensor rather than
      // those of one of the inputs.
      {graph.logical_limits_ubo(out)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out)},
      // Resize Arguments
      {},
      // Resizing Logic
      resize_where_node));
}
63+
64+
/*
 * Appends a dispatch node running the buffer variant of the `where` shader.
 *
 * The kernel name is "where" plus the storage-type and dtype suffixes derived
 * from `out`. `out` is bound for writing; `cond`, `self` and `other` are bound
 * for reading, in that order.
 */
void add_where_buffer_node(
    ComputeGraph& graph,
    const ValueRef cond,
    const ValueRef self,
    const ValueRef other,
    const ValueRef out) {
  // Resolve the shader variant from the output's storage type and dtype.
  std::string shader_name = "where";
  add_storage_type_suffix(shader_name, graph.storage_type_of(out));
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  const utils::uvec3 global_size = graph.create_global_wg_size(out);
  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);

  // Order matches the shader's UBO declarations: element count of the output,
  // then the strides of out, cond, self and other.
  vkapi::ParamsBindList param_ubos = {
      graph.numel_ubo(out),
      graph.strides_ubo(out),
      graph.strides_ubo(cond),
      graph.strides_ubo(self),
      graph.strides_ubo(other)};

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(shader_name),
      // Workgroup sizes
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
      // Parameter buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants: packed dims of out, cond, self, other
      {graph.packed_dim_of(out),
       graph.packed_dim_of(cond),
       graph.packed_dim_of(self),
       graph.packed_dim_of(other)},
      // Resize Arguments
      {},
      // Resizing Logic
      resize_where_node));
}
108+
109+
/*
 * Operator entry point for `where`.
 *
 * Expects `args` to hold, in order: cond, self, other, out. Dispatches to the
 * buffer node builder when `out` uses buffer storage and to the texture node
 * builder otherwise.
 */
void where(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef cond = args[0];
  const ValueRef self = args[1];
  const ValueRef other = args[2];
  const ValueRef out = args[3];

  if (!graph.is_buffer_storage(out)) {
    add_where_texture_node(graph, cond, self, other, out);
    return;
  }
  add_where_buffer_node(graph, cond, self, other, out);
}
121+
122+
REGISTER_OPERATORS {
123+
VK_REGISTER_OP(aten.where.self, where);
124+
}
125+
126+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) {
4949
break;
5050
case vkapi::kByte:
5151
case vkapi::kQUInt8:
52+
case vkapi::kBool:
5253
kernel_name += "_uint8";
5354
break;
5455
default:

backends/vulkan/runtime/vk_api/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
2828
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
2929
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
30-
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
30+
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \
3131
_(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
3232
_(float, VK_FORMAT_FLOAT4, Float) \
3333
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \

0 commit comments

Comments
 (0)