
Commit 63b5cef

[webgpu] Add ScatterElements support (#26903)
1 parent 2573346 commit 63b5cef

5 files changed: +377 −2 lines changed


onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,7 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va
   ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 ||
                         var_type == ProgramVariableDataType::Uint32 ||
                         var_type == ProgramVariableDataType::Float32 ||
+                        var_type == ProgramVariableDataType::Float16 ||
                         var_type == ProgramVariableDataType::Float16x4 ||
                         var_type == ProgramVariableDataType::Float32x4,
                     "Unexpected program variable type ", int(var_type), " for atomic variable");
@@ -482,6 +483,8 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
       ss << "atomic<u32>";
     } else if (output->type_ == ProgramVariableDataType::Int32) {
       ss << "atomic<i32>";
+    } else if (output->type_ == ProgramVariableDataType::Float16) {
+      ss << "atomic<u32>";  // emulate f16 atomic via u32 (storing as packed u16)
     } else {
       ORT_RETURN_IF(true, "Unsupported atomic type: ", int(output->type_));
     }
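
WGSL has no atomic<f16>, so the new Float16 branch declares the output buffer as atomic<u32> and lets the kernel pack two f16 values into each 32-bit word. A minimal sketch of the resulting layout, where load_f16 is an illustrative helper and not code generated by this change:

```wgsl
@group(0) @binding(0) var<storage, read_write> output : array<atomic<u32>>;

// Read the f16 element at logical index i out of the packed u32 array,
// widened to f32 for convenience.
fn load_f16(i : u32) -> f32 {
  let word = atomicLoad(&output[i / 2u]);
  let pair = unpack2x16float(word);               // pair.x = low 16 bits, pair.y = high 16 bits
  return select(pair.y, pair.x, (i % 2u) == 0u);  // even indices sit in the low half
}
```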
onnxruntime/core/providers/webgpu/tensor/scatter_elements.cc

Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/scatter_elements.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

Status ScatterElementsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const auto& updates = shader.AddInput("updates", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride);

  // Helper lambda for atomic reduction operations
  auto atomic_reduction_snippet = [](ScatterElementsReduction reduction, const std::string& base_ptr, const std::string& offset_var, const std::string& value, const std::string& data_type) -> std::string {
    std::ostringstream ss;
    bool is_32_bit_integer = data_type == "i32" || data_type == "u32";
    bool is_unsigned_integer = data_type == "u32";
    bool is_float16 = data_type == "f16";

    std::ostringstream ss_float_start;
    if (is_float16) {
      // For f16, we use u32 atomics where each u32 stores 2 f16 values
      // offset_var is the f16 index, so we need to:
      // 1. Calculate u32_offset = offset_var / 2
      // 2. Determine which half: offset_var % 2
      // 3. Update the appropriate half
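      //    Example: f16 element 5 lives in u32 word 5 / 2 = 2, in the upper half (5 % 2 == 1).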
      ss_float_start << "  {\n"
                     << "    let u32_offset = " << offset_var << " / 2u;\n"
                     << "    let is_lower_half = (" << offset_var << " % 2u) == 0u;\n"
                     << "    var oldValue = 0u;\n"
                     << "    loop {\n"
                     << "      let oldVec = unpack2x16float(oldValue);\n"
                     << "      let oldF16 = f16(select(oldVec.y, oldVec.x, is_lower_half));\n"
                     << "      let newValueF16 = ";
    } else {
      ss_float_start << "  {\n"
                     << "    var oldValue = 0" << (is_unsigned_integer ? "u" : "") << ";\n"
                     << "    loop {\n"
                     << "      let newValueF32 = ";
    }

    std::ostringstream ss_float_end;
    if (is_float16) {
      ss_float_end << ";\n"
                   << "      let updatedVec = select(\n"
                   << "        vec2<f32>(oldVec.x, f32(newValueF16)),\n"
                   << "        vec2<f32>(f32(newValueF16), oldVec.y),\n"
                   << "        is_lower_half\n"
                   << "      );\n"
                   << "      let newValue = pack2x16float(updatedVec);\n"
                   << "      let res = atomicCompareExchangeWeak(&" << base_ptr << "[u32_offset], oldValue, newValue);\n"
                   << "      if res.exchanged {\n"
                   << "        break;\n"
                   << "      }\n"
                   << "      oldValue = res.old_value;\n"
                   << "    }\n"
                   << "  }\n";
    } else {
      ss_float_end << ";\n"
                   << "      let newValue = bitcast<" << (is_unsigned_integer ? "u32" : "i32") << ">(newValueF32);\n"
                   << "      let res = atomicCompareExchangeWeak(&" << base_ptr << "[" << offset_var << "], oldValue, newValue);\n"
                   << "      if res.exchanged {\n"
                   << "        break;\n"
                   << "      }\n"
                   << "      oldValue = res.old_value;\n"
                   << "    }\n"
                   << "  }\n";
    }

    switch (reduction) {
      case ScatterElementsReduction::Add:
        if (is_32_bit_integer) {
          ss << "  atomicAdd(&" << base_ptr << "[" << offset_var << "], bitcast<" << data_type << ">(" << value << "));\n";
        } else if (is_float16) {
          ss << ss_float_start.str() << "oldF16 + (" << value << ")" << ss_float_end.str();
        } else {
          ss << ss_float_start.str() << "bitcast<" << data_type << ">(oldValue) + (" << value << ")" << ss_float_end.str();
        }
        break;
      case ScatterElementsReduction::Mul:
        if (is_float16) {
          ss << ss_float_start.str() << "(oldF16 * (" << value << "))" << ss_float_end.str();
        } else {
          ss << ss_float_start.str() << "(bitcast<" << data_type << ">(oldValue) * (" << value << "))" << ss_float_end.str();
        }
        break;
      case ScatterElementsReduction::Min:
        if (is_32_bit_integer) {
          ss << "  atomicMin(&" << base_ptr << "[" << offset_var << "], bitcast<" << data_type << ">(" << value << "));\n";
        } else if (is_float16) {
          ss << ss_float_start.str() << "min(oldF16, (" << value << "))" << ss_float_end.str();
        } else {
          ss << ss_float_start.str() << "min(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str();
        }
        break;
      case ScatterElementsReduction::Max:
        if (is_32_bit_integer) {
          ss << "  atomicMax(&" << base_ptr << "[" << offset_var << "], bitcast<" << data_type << ">(" << value << "));\n";
        } else if (is_float16) {
          ss << ss_float_start.str() << "max(oldF16, (" << value << "))" << ss_float_end.str();
        } else {
          ss << ss_float_start.str() << "max(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str();
        }
        break;
      default:
        ORT_THROW("Unsupported reduction type: ", static_cast<int>(reduction));
    }
    return ss.str();
  };
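
  // The CAS loops assembled above are the standard WGSL workaround for missing
  // float atomics: load the word, apply the reduction in f32/f16, then retry with
  // atomicCompareExchangeWeak until no other invocation has raced the write.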

  // Determine data type string for atomic operations
  std::string data_type_str;
  bool reducible = false;
  if (data_type_ == DataTypeImpl::GetType<int32_t>()) {
    reducible = true;
    data_type_str = "i32";
  } else if (data_type_ == DataTypeImpl::GetType<uint32_t>()) {
    reducible = true;
    data_type_str = "u32";
  } else if (data_type_ == DataTypeImpl::GetType<float>()) {
    reducible = true;
    data_type_str = "f32";
  } else if (data_type_ == DataTypeImpl::GetType<MLFloat16>()) {
    reducible = true;
    data_type_str = "f16";
  } else {
    data_type_str = "output_value_t";
  }

  if (reduction_ != ScatterElementsReduction::None && !reducible) {
    ORT_THROW("ScatterElements: Reduction is not supported for data type ", data_type_str);
  }

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");

  // Convert linear index to multi-dimensional indices using indices OffsetToIndices
  shader.MainFunctionBody() << "  // Calculate output indices from global_idx\n"
                            << "  let update_indices = " << indices.OffsetToIndices("global_idx") << ";\n";

  // Get the scatter index from indices tensor
  shader.MainFunctionBody() << "  // Get the scatter index\n"
                            << "  var idx = i32(" << indices.GetByOffset("global_idx") << ");\n";

  // Handle negative indices
  shader.MainFunctionBody() << "  // Handle negative indices\n"
                            << "  if (idx < 0) {\n"
                            << "    idx = idx + i32(uniforms.axis_dim_limit);\n"
                            << "  }\n";

  // Bounds checking
  shader.MainFunctionBody() << "  // Bounds checking\n"
                            << "  if (idx < 0 || idx >= i32(uniforms.axis_dim_limit)) {\n"
                            << "    return;\n"
                            << "  }\n";

  // Build output indices by replacing the axis dimension with the scatter index
  shader.MainFunctionBody() << "  // Build output indices\n"
                            << "  var output_indices = update_indices;\n"
                            << output.IndicesSet("output_indices", std::to_string(axis_), "u32(idx)") << ";\n";
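  // e.g. for a rank-2 tensor with axis = 1, the generated shader ends up computing
  //   output[i][indices[i][j]] (op)= updates[i][j], matching the ONNX ScatterElements spec.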

  // Get update value and scatter
  shader.MainFunctionBody() << "  let update_value = " << updates.GetByOffset("global_idx") << ";\n";
  shader.MainFunctionBody() << "  let output_offset = " << output.IndicesToOffset("output_indices") << ";\n";

  // Handle reduction
  if (reduction_ == ScatterElementsReduction::None) {
    // Non-reduction path: use direct assignment
    shader.MainFunctionBody() << "  " << output.SetByOffset("output_offset", "update_value") << ";\n";
  } else {
    // Reduction path: use atomic operations
    shader.MainFunctionBody() << atomic_reduction_snippet(reduction_, "output", "output_offset", "update_value", data_type_str);
  }

  return Status::OK();
}

Status ScatterElements::ComputeInternal(ComputeContext& context) const {
  const Tensor* input = context.Input<Tensor>(0);
  const Tensor* indices = context.Input<Tensor>(1);
  const Tensor* updates = context.Input<Tensor>(2);

  const auto& input_shape = input->Shape();
  const auto& indices_shape = indices->Shape();
  const auto& updates_shape = updates->Shape();

  const int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());
  const int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
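  // e.g. axis_ = -1 on a rank-3 input normalizes to axis = 2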

  // Validate axis
  ORT_RETURN_IF_NOT(axis >= 0 && axis < input_rank, "axis ", axis_, " is out of bounds for tensor of rank ", input_rank);

  // Validate shapes
  ORT_RETURN_IF_NOT(indices_shape.NumDimensions() == updates_shape.NumDimensions(),
                    "Indices and updates must have the same rank");

  for (size_t i = 0; i < indices_shape.NumDimensions(); ++i) {
    ORT_RETURN_IF_NOT(indices_shape[i] == updates_shape[i],
                      "Indices and updates dimensions must match at position ", i);
  }

  auto* output = context.Output(0, input_shape);

  // Copy input to output if not in-place
  const void* source = input->DataRaw();
  void* target = output->MutableDataRaw();
  if (target != source) {
    ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output));
  }

  // Early return if indices/updates are empty
  if (indices_shape.Size() == 0) {
    return Status::OK();
  }

  const uint32_t output_size = onnxruntime::narrow<uint32_t>(indices_shape.Size());
  const uint32_t axis_dim_limit = onnxruntime::narrow<uint32_t>(input_shape[static_cast<size_t>(axis)]);

  MLDataType data_type = input->DataType();
  ScatterElementsProgram program(axis, reduction_, data_type);

  program
      .CacheHint(std::to_string(axis) + "_" + std::to_string(static_cast<uint32_t>(reduction_)))
      .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank},
                  {updates, ProgramTensorMetadataDependency::TypeAndRank}})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({output_size, axis_dim_limit});
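  // This dispatches ceil(output_size / WORKGROUP_SIZE) workgroups, one invocation
  // per indices element; the out-of-bounds tail is guarded inside the shader.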

  // Use atomic output if reduction is enabled and data type supports it
  // Note: f16 uses atomic<u32> for reductions (packing 2 f16 values per u32)
  if (reduction_ != ScatterElementsReduction::None &&
      (data_type == DataTypeImpl::GetType<float>() ||
       data_type == DataTypeImpl::GetType<MLFloat16>() ||
       data_type == DataTypeImpl::GetType<int32_t>() ||
       data_type == DataTypeImpl::GetType<uint32_t>())) {
    program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, ProgramOutput::Atomic});
  } else {
    program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank});
  }

  return context.RunProgram(program);
}

// Register kernels for different opset versions
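// The version splits track the ONNX spec: the reduction attribute appears in
// opset 16 (none/add/mul) and gains min/max in opset 18.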

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    ScatterElements,
    kOnnxDomain,
    11,
    12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>())
        .MayInplace(0, 0),
    ScatterElements);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    ScatterElements,
    kOnnxDomain,
    13,
    15,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>())
        .MayInplace(0, 0),
    ScatterElements);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    ScatterElements,
    kOnnxDomain,
    16,
    17,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>())
        .MayInplace(0, 0),
    ScatterElements);

ONNX_OPERATOR_KERNEL_EX(
    ScatterElements,
    kOnnxDomain,
    18,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>())
        .MayInplace(0, 0),
    ScatterElements);

}  // namespace webgpu
}  // namespace onnxruntime
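
For reference, stitching together the strings that atomic_reduction_snippet emits for an f16 add reduction yields WGSL of the following shape. The names output, output_offset, and update_value are the ones ComputeInternal passes in above; this is a reconstruction from the source strings, not captured compiler output:

```wgsl
// f16 add via a u32 compare-and-swap loop, two f16 values per word.
{
  let u32_offset = output_offset / 2u;
  let is_lower_half = (output_offset % 2u) == 0u;
  var oldValue = 0u;
  loop {
    let oldVec = unpack2x16float(oldValue);
    let oldF16 = f16(select(oldVec.y, oldVec.x, is_lower_half));
    let newValueF16 = oldF16 + (update_value);
    let updatedVec = select(
        vec2<f32>(oldVec.x, f32(newValueF16)),   // cond false: replace the upper half
        vec2<f32>(f32(newValueF16), oldVec.y),   // cond true: replace the lower half
        is_lower_half);
    let newValue = pack2x16float(updatedVec);
    let res = atomicCompareExchangeWeak(&output[u32_offset], oldValue, newValue);
    if res.exchanged {
      break;                                     // our write landed
    }
    oldValue = res.old_value;                    // another invocation won; retry from its value
  }
}
```

Note the first iteration speculatively assumes the word is zero; if it is not, the CAS fails and the loop retries with the actual value, so the result is still correct.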
onnxruntime/core/providers/webgpu/tensor/scatter_elements.h

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

enum class ScatterElementsReduction : int {
  None = 0,
  Add = 1,
  Mul = 2,
  Min = 3,
  Max = 4,
};

class ScatterElementsProgram final : public Program<ScatterElementsProgram> {
 public:
  ScatterElementsProgram(int64_t axis, ScatterElementsReduction reduction, MLDataType data_type)
      : Program{"ScatterElements"}, axis_(axis), reduction_(reduction), data_type_(data_type) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"axis_dim_limit", ProgramUniformVariableDataType::Uint32});

 private:
  int64_t axis_;
  ScatterElementsReduction reduction_;
  MLDataType data_type_;
};

class ScatterElements : public WebGpuKernel {
 public:
  ScatterElements(const OpKernelInfo& info) : WebGpuKernel(info) {
    ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
                "Missing/Invalid 'axis' attribute value");

    std::string reduction = info.GetAttrOrDefault<std::string>("reduction", "none");
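    // Per the ONNX spec, "reduction" first appears in opset 16 (none/add/mul) and
    // gains min/max in opset 18; GetAttrOrDefault keeps older opsets on "none".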
    if (reduction == "add") {
      reduction_ = ScatterElementsReduction::Add;
    } else if (reduction == "mul") {
      reduction_ = ScatterElementsReduction::Mul;
    } else if (reduction == "min") {
      reduction_ = ScatterElementsReduction::Min;
    } else if (reduction == "max") {
      reduction_ = ScatterElementsReduction::Max;
    } else if (reduction == "none") {
      reduction_ = ScatterElementsReduction::None;
    } else {
      ORT_THROW("Reduction '", reduction, "' is not supported.");
    }
  }

  Status ComputeInternal(ComputeContext& context) const override;

 private:
  int64_t axis_;
  ScatterElementsReduction reduction_{ScatterElementsReduction::None};
};

}  // namespace webgpu
}  // namespace onnxruntime
