// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3+
#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/gather_elements.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
9+
10+ namespace onnxruntime {
11+ namespace webgpu {
12+
13+ ONNX_OPERATOR_VERSIONED_KERNEL_EX (
14+ GatherElements,
15+ kOnnxDomain ,
16+ 11 , 12 ,
17+ kWebGpuExecutionProvider ,
18+ (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedFloatTypes()),
19+ GatherElements);
20+
21+ ONNX_OPERATOR_KERNEL_EX (
22+ GatherElements,
23+ kOnnxDomain ,
24+ 13 ,
25+ kWebGpuExecutionProvider ,
26+ (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedFloatTypes()),
27+ GatherElements);
28+
29+ Status GatherElementsProgram::GenerateShaderCode (ShaderHelper& shader) const {
30+ const ShaderVariableHelper& input = shader.AddInput (" input" , ShaderUsage::UseUniform);
31+ const ShaderVariableHelper& indices = shader.AddInput (" indices" , ShaderUsage::UseUniform);
32+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
33+
34+ shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
35+ << " let output_indices = " << output.OffsetToIndices (" global_idx" ) << " ;\n "
36+ << " var idx = " << indices.GetByOffset (" global_idx" ) << " ;\n "
37+ << " if (idx < 0) {\n "
38+ << " idx = idx + uniforms.axis_dim_limit;\n "
39+ << " }\n "
40+ << " var input_indices = output_indices;\n "
41+ << input.IndicesSet (" input_indices" , " uniforms.axis" , " u32(idx)" ) << " ;\n "
42+ << " let value = " << input.GetByIndices (" input_indices" ) << " ;\n "
43+ << output.SetByOffset (" global_idx" , " value" ) << " ;\n " ;
44+
45+ return Status::OK ();
46+ }
47+
48+ Status GatherElements::ComputeInternal (ComputeContext& context) const {
49+ const auto * input_tensor = context.Input (0 );
50+ const TensorShape& input_shape = input_tensor->Shape ();
51+ int64_t input_rank = input_shape.NumDimensions ();
52+
53+ const auto * indices_tensor = context.Input (1 );
54+ const TensorShape& indices_shape = indices_tensor->Shape ();
55+
56+ // Handle negative axis
57+ int64_t axis = axis_;
58+ if (axis < 0 ) {
59+ axis += input_rank;
60+ }
61+
62+ auto axis_dim_limit = input_shape[axis];
63+
64+ auto output_dims = indices_shape.AsShapeVector ();
65+ TensorShape output_shape (output_dims);
66+ auto * output_tensor = context.Output (0 , output_shape);
67+ int64_t output_size = output_tensor->Shape ().Size ();
68+
69+ if (output_size == 0 ) {
70+ return Status::OK ();
71+ }
72+
73+ GatherElementsProgram program{};
74+ program
75+ .AddInputs ({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
76+ .AddInputs ({{indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
77+ .AddOutputs ({output_tensor})
78+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
79+ .AddUniformVariables ({{static_cast <uint32_t >(output_size)},
80+ {static_cast <int32_t >(axis_dim_limit)},
81+ {static_cast <int32_t >(axis)}});
82+ return context.RunProgram (program);
83+ }
84+
85+ } // namespace webgpu
86+ } // namespace onnxruntime
0 commit comments