Skip to content

Commit 279c8ee

Browse files
prathikr and guschmue
authored and committed
gather elements webgpu implementation (#23137)
Increases operator coverage for WebGPU EP.
1 parent 674d333 commit 279c8ee

File tree

4 files changed

+126
-3
lines changed

4 files changed

+126
-3
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/common/inlined_containers.h"
5+
#include "core/providers/webgpu/tensor/gather_elements.h"
6+
#include "core/providers/cpu/tensor/utils.h"
7+
#include "core/providers/webgpu/shader_helper.h"
8+
#include "core/providers/webgpu/webgpu_supported_types.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
// Registers the GatherElements kernel class with the WebGPU execution provider
// for versions 11 through 12 of the operator in the ONNX domain. The "T" type
// constraint is bound to WebGpuSupportedFloatTypes().
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
14+
GatherElements,
15+
kOnnxDomain,
16+
11, 12,
17+
kWebGpuExecutionProvider,
18+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
19+
GatherElements);
20+
21+
// Registers the GatherElements kernel class with the WebGPU execution provider
// for version 13 of the operator in the ONNX domain (non-versioned macro, so no
// end version is given). The "T" type constraint is bound to
// WebGpuSupportedFloatTypes().
ONNX_OPERATOR_KERNEL_EX(
22+
GatherElements,
23+
kOnnxDomain,
24+
13,
25+
kWebGpuExecutionProvider,
26+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
27+
GatherElements);
28+
29+
// Emits the shader main-function body for GatherElements.
//
// For the element at `global_idx` the generated code:
//   1. converts the flat output offset into per-dimension output indices,
//   2. reads the gather index stored at the same flat offset in `indices`,
//   3. adds `uniforms.axis_dim_limit` to the index when it is negative,
//   4. copies the output indices, overwrites the `uniforms.axis` component with
//      the gather index, and reads `input` at the resulting indices,
//   5. stores the value to `output` at `global_idx`.
//
// The generated code does not range-check the gather index beyond the
// negative wrap in step 3.
Status GatherElementsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& data = shader.AddInput("input", ShaderUsage::UseUniform);
  const ShaderVariableHelper& gather_indices = shader.AddInput("indices", ShaderUsage::UseUniform);
  const ShaderVariableHelper& result = shader.AddOutput("output", ShaderUsage::UseUniform);

  auto& body = shader.MainFunctionBody();

  // Bounds guard emitted at the top of the main function body.
  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");

  // Output coordinates and the (possibly negative) gather index for this element.
  body << "let output_indices = " << result.OffsetToIndices("global_idx") << ";\n";
  body << "var idx = " << gather_indices.GetByOffset("global_idx") << ";\n";

  // Negative indices count from the end of the gathered axis.
  body << "if (idx < 0) {\n"
          " idx = idx + uniforms.axis_dim_limit;\n"
          "}\n";

  // Same coordinates as the output, except along the gather axis.
  body << "var input_indices = output_indices;\n";
  body << data.IndicesSet("input_indices", "uniforms.axis", "u32(idx)") << ";\n";

  // Load from the input and store to the output.
  body << "let value = " << data.GetByIndices("input_indices") << ";\n";
  body << result.SetByOffset("global_idx", "value") << ";\n";

  return Status::OK();
}
47+
48+
// Runs GatherElements on the WebGPU execution provider.
//
// Inputs:  0 = data tensor, 1 = indices tensor.
// Output:  0 = tensor with the same shape as the indices tensor.
//
// Returns an error Status (instead of indexing the shape out of range) when
// the data tensor is a scalar, when the indices rank differs from the data
// rank, or when the `axis` attribute lies outside [-rank, rank - 1]. Returns
// OK without dispatching when the output has no elements. Otherwise returns
// the Status of running the shader program.
//
// The values stored in the indices tensor are NOT validated here; the shader
// only wraps negative values by adding the size of the gathered dimension.
Status GatherElements::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  const int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());

  const auto* indices_tensor = context.Input(1);
  const TensorShape& indices_shape = indices_tensor->Shape();
  const int64_t indices_rank = static_cast<int64_t>(indices_shape.NumDimensions());

  // The shader copies the output indices into the input indices, so both
  // tensors must have the same rank; the axis must address an existing
  // dimension before it is used to index the input shape below.
  ORT_RETURN_IF_NOT(input_rank >= 1,
                    "GatherElements: input must have rank >= 1, got a scalar.");
  ORT_RETURN_IF_NOT(indices_rank == input_rank,
                    "GatherElements: indices rank (", indices_rank,
                    ") must equal input rank (", input_rank, ").");
  ORT_RETURN_IF_NOT(axis_ >= -input_rank && axis_ < input_rank,
                    "GatherElements: axis ", axis_, " is out of range [",
                    -input_rank, ", ", input_rank - 1, "].");

  // Handle negative axis
  const int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;

  // Size of the gathered dimension; the shader adds it to negative indices.
  const int64_t axis_dim_limit = input_shape[axis];

  // The output has exactly the shape of the indices tensor. It must be
  // requested before the early return so an empty output is still produced.
  auto* output_tensor = context.Output(0, indices_shape);
  const int64_t output_size = indices_shape.Size();

  if (output_size == 0) {
    return Status::OK();
  }

  GatherElementsProgram program{};
  program
      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddInputs({{indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({output_tensor})
      // One invocation per output element, rounded up to whole workgroups.
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      // Order must match the uniform declaration: output_size, axis_dim_limit, axis.
      .AddUniformVariables({{static_cast<uint32_t>(output_size)},
                            {static_cast<int32_t>(axis_dim_limit)},
                            {static_cast<int32_t>(axis)}});
  return context.RunProgram(program);
}
84+
85+
} // namespace webgpu
86+
} // namespace onnxruntime
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/providers/webgpu/program.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// Shader program for the WebGPU GatherElements kernel. It is constructed with
// the program name "GatherElements" and declares the uniforms its shader reads.
class GatherElementsProgram final : public Program<GatherElementsProgram> {
13+
public:
14+
GatherElementsProgram() : Program{"GatherElements"} {}
15+
16+
// Writes the shader code for this program into `sh`.
Status GenerateShaderCode(ShaderHelper& sh) const override;
17+
18+
// Uniforms, in declaration order:
//   output_size    (u32) - number of output elements,
//   axis_dim_limit (i32) - size of the input dimension being gathered along,
//   axis           (i32) - gather axis after negative-axis normalization.
// The kernel supplies values to AddUniformVariables in this same order.
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
19+
{"axis_dim_limit", ProgramUniformVariableDataType::Int32},
20+
{"axis", ProgramUniformVariableDataType::Int32});
21+
};
22+
23+
class GatherElements final : public WebGpuKernel {
24+
public:
25+
GatherElements(const OpKernelInfo& info) : WebGpuKernel(info) {
26+
axis_ = info.GetAttrOrDefault<int64_t>("axis", 0);
27+
}
28+
29+
Status ComputeInternal(ComputeContext& context) const override;
30+
31+
private:
32+
int64_t axis_;
33+
};
34+
35+
} // namespace webgpu
36+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,8 +649,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
649649
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
650650
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather)>,
651651

652-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
653-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,
652+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
653+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,
654654

655655
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
656656
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize)>,

onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,10 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) {
389389
// skip openvino which will not throw error message but will ensure no out-of-bound access
390390
// skip TensorRT because it doesn't support out of bounds indices
391391
// skip QNN because it doesn't support out of bounds indices
392+
// skip WebGPU because it doesn't support out of bounds indices
392393
test.Run(OpTester::ExpectResult::kExpectFailure, "",
393394
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider,
394-
kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider});
395+
kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider});
395396
}
396397

397398
TEST(GatherElementsOpTest, BigIndices) {

0 commit comments

Comments
 (0)