Skip to content

Where layer #11181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"uint": "uimage3D",
"int8": "iimage3D",
"uint8": "uimage3D",
"bool": "uimage3D",
},
2: {
"float": "image2D",
Expand All @@ -70,6 +71,7 @@
"uint": "uimage2D",
"int8": "iimage2D",
"uint8": "uimage2D",
"bool": "uimage2D",
},
},
"SAMPLER_T": {
Expand All @@ -80,6 +82,7 @@
"uint": "usampler3D",
"int8": "isampler3D",
"uint8": "usampler3D",
"bool": "usampler3D",
},
2: {
"float": "sampler2D",
Expand All @@ -88,6 +91,7 @@
"uint": "usampler2D",
"int8": "isampler2D",
"uint8": "usampler2D",
"bool": "usampler2D",
},
},
"IMAGE_FORMAT": {
Expand All @@ -97,6 +101,7 @@
"uint": "rgba32ui",
"int8": "rgba8i",
"uint8": "rgba8ui",
"bool": "rgba8ui",
},
}

Expand All @@ -115,7 +120,8 @@ def buffer_scalar_type(dtype: str) -> str:
return "float16_t"
elif dtype[-1] == "8":
return dtype + "_t"

elif dtype == "bool":
return "uint8_t"
return dtype


Expand All @@ -135,17 +141,19 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
return f"i8vec{n}"
elif dtype == "uint8":
return f"u8vec{n}"
elif dtype == "bool":
return f"u8vec{n}"

raise AssertionError(f"Invalid dtype: {dtype}")


def texel_type(dtype: str) -> str:
    """Return the GLSL 4-component vector type used for texels of ``dtype``.

    The type is derived from the suffix of the image format registered for
    ``dtype`` in ``TYPE_MAPPINGS["IMAGE_FORMAT"]``:

    * ``...f``  -> ``vec4``
    * ``...ui`` -> ``uvec4``
    * ``...i``  -> ``ivec4``

    Raises:
        KeyError: if ``dtype`` has no entry in the image format table.
        AssertionError: if the image format has an unrecognised suffix.
    """
    image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
    # Suffix comparisons must use endswith()/slices: indexing with [-2] yields
    # a single character and can never equal the two-character string "ui".
    if image_format.endswith("f"):
        return "vec4"
    # "ui" has to be tested before "i", since unsigned formats such as
    # "rgba8ui" also end in "i".
    elif image_format.endswith("ui"):
        return "uvec4"
    elif image_format.endswith("i"):
        return "ivec4"
    raise AssertionError(f"Invalid image format: {image_format}")

Expand Down Expand Up @@ -360,7 +368,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
elif dtype == "int16" or dtype == "uint16":
nbit = "16bit"
glsl_type = "int16"
elif dtype == "int8" or dtype == "uint8":
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
nbit = "8bit"
glsl_type = "int8"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ buffer_to_nchw:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
shader_variants:
- NAME: buffer_to_nchw
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ image_to_nchw:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
shader_variants:
- NAME: image_to_nchw_texture3d
- NAME: image_to_nchw_texture2d
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ nchw_to_buffer:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
shader_variants:
- NAME: nchw_to_buffer
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ nchw_to_image:
- VALUE: float
- VALUE: int
- VALUE: int8
- VALUE: uint8
shader_variants:
- NAME: nchw_to_image_texture3d
- NAME: nchw_to_image_texture2d
Expand Down
111 changes: 111 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/where.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// where.glsl

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/


#version 450 core

#define PRECISION ${PRECISION}

// Type aliases filled in by the shader code generator:
//   VEC4_T - texel type used when loading DTYPE tensors from STORAGE
//   T      - scalar element type of DTYPE tensors in buffer storage
//   COND_T - scalar element type of the "bool" condition tensor in buffer
//            storage
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${buffer_scalar_type(DTYPE)}
#define COND_T ${buffer_scalar_type("bool")}

// Extensions are requested for both the data dtype and the condition dtype,
// because the two tensors may use different element types.
${define_active_storage_type(STORAGE)}
${define_required_extensions(DTYPE)}
${define_required_extensions("bool")}

layout(std430) buffer;

// Tensor bindings: one written output followed by three read-only inputs
// (condition, self, other). The condition is always declared as "bool".
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_condition", "bool", STORAGE)}
${layout_declare_tensor(B, "r", "t_self", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}


#include "indexing_utils.h"

// Buffer storage binds the output element count plus per-tensor strides and
// packed-dim specialization constants; texture storage binds only out_limits.
// NOTE(review): cond_packed_dim, self_packed_dim and other_packed_dim are
// declared below but not referenced by either main() in this file.
$if STORAGE == "buffer":
${layout_declare_ubo(B, "int", "out_numl")}
${layout_declare_ubo(B, "ivec4", "out_strides")}
${layout_declare_ubo(B, "ivec4", "cond_strides")}
${layout_declare_ubo(B, "ivec4", "self_strides")}
${layout_declare_ubo(B, "ivec4", "other_strides")}

${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "cond_packed_dim", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "self_packed_dim", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
$else:
${layout_declare_ubo(B, "ivec3", "out_limits")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#ifdef USING_BUFFER

/*
 * Buffer-storage variant of where(): each invocation produces one output
 * element, selecting t_self or t_other according to t_condition.
 *
 * The x component of the global invocation id is used as the linear output
 * index; invocations at or beyond out_numl do nothing.
 */
void main() {
  int out_bufi = int(gl_GlobalInvocationID.x);
  if (out_bufi >= out_numl) {
    return;
  }

  // Linear index -> tensor index -> linear index using the output strides.
  // NOTE(review): this round trip is presumably an identity for the output
  // tensor; it is kept so the indices below match the original computation.
  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
  out_bufi = tidx_to_bufi(out_tidx, out_strides);

  // All three inputs are addressed with the same tensor index, derived from
  // the output index. Compute it once instead of once per input tensor.
  const ivec4 in_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
  const int cond_bufi = tidx_to_bufi(in_tidx, cond_strides);
  const int self_bufi = tidx_to_bufi(in_tidx, self_strides);
  const int other_bufi = tidx_to_bufi(in_tidx, other_strides);

  COND_T cond = t_condition[cond_bufi];
  T v_self = t_self[self_bufi];
  T v_other = t_other[other_bufi];

  // Any non-zero condition value selects `self`.
  if (cond > 0) {
    t_out[out_bufi] = v_self;
  } else {
    t_out[out_bufi] = v_other;
  }
}

#else // !USING_BUFFER

/*
 * Texture-storage variant of where(): each invocation produces one output
 * texel, choosing each of its four components from t_self or t_other
 * according to the matching component of the t_condition texel.
 *
 * Invocations whose position lies outside out_limits do nothing.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  // The condition tensor is declared with dtype "bool", which the code
  // generator maps to an unsigned integer texture. Load it as uvec4 rather
  // than converting to float, so the test below is an exact integer compare.
  const uvec4 cond = load_texel(t_condition, pos);
  const VEC4_T selftex = load_texel(t_self, pos);
  const VEC4_T othertex = load_texel(t_other, pos);

  VEC4_T outtex;

  for (int idx = 0; idx < 4; ++idx) {
    // Treat any non-zero value as true, matching the buffer variant, instead
    // of requiring the stored value to be exactly 1.
    if (cond[idx] != 0u) {
      outtex[idx] = selftex[idx];
    } else {
      outtex[idx] = othertex[idx];
    }
  }
  write_texel(t_out, pos, outtex);
}
#endif // !USING_BUFFER
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/where.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Shader generation config for where.glsl.
# NOTE(review): presumably one shader is generated per STORAGE x DTYPE
# combination listed under generate_variant_forall - confirm against the
# shader code generator.
where:
parameter_names_with_default_values:
DTYPE: float
generate_variant_forall:
# Storage types the shader is generated for.
STORAGE:
- VALUE: texture3d
- VALUE: buffer
# Data types of the self/other/out tensors; the condition tensor's dtype is
# fixed to "bool" inside where.glsl and is not listed here.
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: where
126 changes: 126 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Where.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Where.cpp

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

/*
 * Resize callback for the where dispatch nodes.
 *
 * Gives the output tensor (first ref of args[0]) the same sizes as the first
 * tensor of the read group (first ref of args[1]). The dispatch nodes in this
 * file pass {cond, self, other} as that read group, so the output follows the
 * sizes of the condition tensor. extra_args is unused.
 */
void resize_where_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  const ValueRef dst_ref = args[0].refs[0];
  const ValueRef src_ref = args[1].refs[0];

  vTensorPtr dst = graph->get_tensor(dst_ref);
  vTensorPtr src = graph->get_tensor(src_ref);

  // Copy the sizes before resizing so the source is read in full first.
  const std::vector<int64_t> new_sizes = src->sizes();
  dst->virtual_resize(new_sizes);
}

/*
 * Adds a dispatch node that computes where(cond, self, other) into `out` for
 * texture-backed tensors.
 *
 * The shader name is "where" plus suffixes derived from the storage type and
 * dtype of `out`. Workgroup sizes are derived from `out` as well.
 */
void add_where_texture_node(
    ComputeGraph& graph,
    const ValueRef cond,
    const ValueRef self,
    const ValueRef other,
    const ValueRef out) {
  std::string kernel_name = "where";

  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const utils::uvec3 global_wg_size = graph.create_global_wg_size(out);
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
      // Parameter buffers
      // The shader declares this UBO as `out_limits` and uses it to bounds
      // check the output position, so bind the limits of `out` (which also
      // determines the dispatch size) rather than those of an input.
      {graph.logical_limits_ubo(out)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out)},
      // Resize Arguments
      {},
      // Resizing Logic
      resize_where_node));
}

/*
 * Adds a dispatch node that computes where(cond, self, other) into `out` for
 * buffer-backed tensors.
 *
 * Besides the tensors themselves, the shader receives the element count and
 * strides of `out`, the strides of each input, and the packed dim of all four
 * tensors as specialization constants.
 */
void add_where_buffer_node(
    ComputeGraph& graph,
    const ValueRef cond,
    const ValueRef self,
    const ValueRef other,
    const ValueRef out) {
  // Shader name: "where" + storage suffix + dtype suffix, both taken from out.
  std::string kernel_name("where");
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const utils::uvec3 global_wg_size = graph.create_global_wg_size(out);
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  // Order must match the UBO declarations in the shader:
  // out_numl, out_strides, cond_strides, self_strides, other_strides.
  vkapi::ParamsBindList param_buffers = {
      graph.numel_ubo(out),
      graph.strides_ubo(out),
      graph.strides_ubo(cond),
      graph.strides_ubo(self),
      graph.strides_ubo(other)};

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
      // Parameter buffers
      param_buffers,
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out),
       graph.packed_dim_of(cond),
       graph.packed_dim_of(self),
       graph.packed_dim_of(other)},
      // Resize Arguments
      {},
      // Resizing Logic
      resize_where_node));
}

/*
 * Operator entry point for where.
 *
 * Expects args in the order {cond, self, other, out} and forwards to the
 * buffer or texture implementation depending on the storage type of `out`.
 */
void where(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef cond = args[0];
  const ValueRef self = args[1];
  const ValueRef other = args[2];
  const ValueRef out = args[3];

  if (!graph.is_buffer_storage(out)) {
    add_where_texture_node(graph, cond, self, other, out);
    return;
  }
  add_where_buffer_node(graph, cond, self, other, out);
}

// Registers where() as the implementation of the aten.where.self operator.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.where.self, where);
}

} // namespace vkcompute
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) {
break;
case vkapi::kByte:
case vkapi::kQUInt8:
case vkapi::kBool:
kernel_name += "_uint8";
break;
default:
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/vk_api/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \
_(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
_(float, VK_FORMAT_FLOAT4, Float) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
Expand Down
Loading
Loading