Skip to content

Commit da9b0a5

Browse files
authored
[webgpu] Support int64 for range (#26673)
### Description - Add new registerInt64Ops option to WebGpuExecutionProviderConfig - Int64 support now enabled when enable_graph_capture OR register_int64_ops is true - Refactor Range kernel registration to support conditional int64 registration - Update kernel registry caching to handle all 4 combinations of flags - Rename parameters from enable_graph_capture to enable_int64 for clarity - Add config parsing in webgpu_provider_factory.cc for registerInt64Ops option ### Motivation Needed by updating position id with an onnx model in genai. Continuous decoding mode: `position_ids[i] = i + total_length - new_kv_length` We can use an onnx model which includes a Range op to implement update the position ids: Inputs: start (total_length - new_kv_length), limit (total_length), delta (1) Output: position_ids (1D tensor of size new_kv_length)
1 parent 85dddea commit da9b0a5

File tree

8 files changed

+120
-50
lines changed

8 files changed

+120
-50
lines changed

onnxruntime/core/providers/webgpu/generator/range.cc

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,53 +24,93 @@ Status Range<T>::ComputeInternal(ComputeContext& context) const {
2424
}
2525

2626
uint32_t output_size = onnxruntime::narrow<uint32_t>(n);
27-
RangeProgram program{};
28-
#if defined(__GNUC__)
29-
#pragma GCC diagnostic push
30-
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
31-
#endif
27+
RangeProgram program{output_tensor->GetElementType()};
28+
29+
// For int64, we need to ensure values fit in int32 range since we use 4 bytes in uniforms
30+
uint32_t start_u32, delta_u32;
31+
if constexpr (std::is_same_v<T, int64_t>) {
32+
// Check if values fit in int32 range
33+
ORT_ENFORCE(start >= std::numeric_limits<int32_t>::min() && start <= std::numeric_limits<int32_t>::max(),
34+
"Range start value ", start, " is out of int32 range");
35+
ORT_ENFORCE(delta >= std::numeric_limits<int32_t>::min() && delta <= std::numeric_limits<int32_t>::max(),
36+
"Range delta value ", delta, " is out of int32 range");
37+
int32_t start_i32 = static_cast<int32_t>(start);
38+
int32_t delta_i32 = static_cast<int32_t>(delta);
39+
start_u32 = std::bit_cast<uint32_t>(start_i32);
40+
delta_u32 = std::bit_cast<uint32_t>(delta_i32);
41+
} else {
42+
start_u32 = std::bit_cast<uint32_t>(start);
43+
delta_u32 = std::bit_cast<uint32_t>(delta);
44+
}
3245

3346
program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type})
3447
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
3548
.AddUniformVariables({
3649
output_size,
37-
*reinterpret_cast<uint32_t*>(&start),
38-
*reinterpret_cast<uint32_t*>(&delta),
50+
start_u32,
51+
delta_u32,
3952
});
4053

41-
#if defined(__GNUC__)
42-
#pragma GCC diagnostic pop
43-
#endif
44-
4554
return context.RunProgram(program);
4655
}
4756

4857
Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const {
4958
const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
5059

51-
sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
52-
<< " let value = bitcast<output_value_t>(uniforms.start) + output_value_t(global_idx) * bitcast<output_value_t>(uniforms.delta);\n"
53-
<< output.SetByOffset("global_idx", "value");
60+
sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
61+
62+
// For int64, we need to cast to i32 first, then assign to output (which handles vec2<u32> conversion)
63+
// For int32 and float, we can use output_value_t directly
64+
if (data_type_ == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
65+
// int64 case: bitcast to i32, compute with i32, then assign (automatic conversion to vec2<u32>)
66+
sh.MainFunctionBody() << " let value = bitcast<i32>(uniforms.start) + i32(global_idx) * bitcast<i32>(uniforms.delta);\n"
67+
<< output.SetByOffset("global_idx", "value");
68+
} else {
69+
// float or int32 case: use output_value_t
70+
sh.MainFunctionBody() << " let value = bitcast<output_value_t>(uniforms.start) + output_value_t(global_idx) * bitcast<output_value_t>(uniforms.delta);\n"
71+
<< output.SetByOffset("global_idx", "value");
72+
}
5473

5574
return Status();
5675
}
5776

58-
#define WEBGPU_RANGE_KERNEL(TYPE) \
59-
ONNX_OPERATOR_TYPED_KERNEL_EX( \
60-
Range, \
61-
kOnnxDomain, \
62-
11, \
63-
TYPE, \
64-
kWebGpuExecutionProvider, \
65-
KernelDefBuilder() \
66-
.TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()) \
67-
.InputMemoryType(OrtMemTypeCPU, 0) \
68-
.InputMemoryType(OrtMemTypeCPU, 1) \
69-
.InputMemoryType(OrtMemTypeCPU, 2), \
70-
Range<TYPE>);
71-
72-
WEBGPU_RANGE_KERNEL(float)
73-
WEBGPU_RANGE_KERNEL(int32_t)
77+
// Explicit template instantiations (needed for linking)
78+
template class Range<float>;
79+
template class Range<int32_t>;
80+
template class Range<int64_t>;
81+
82+
void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64) {
83+
// Helper lambda to create kernel
84+
auto create_range_kernel_info = [](auto type_tag) {
85+
using T = decltype(type_tag);
86+
KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
87+
out = std::make_unique<Range<T>>(info);
88+
return Status::OK();
89+
};
90+
91+
return KernelCreateInfo(
92+
KernelDefBuilder()
93+
.SetName("Range")
94+
.SetDomain(kOnnxDomain)
95+
.SinceVersion(11)
96+
.Provider(kWebGpuExecutionProvider)
97+
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>())
98+
.InputMemoryType(OrtMemTypeCPU, 0)
99+
.InputMemoryType(OrtMemTypeCPU, 1)
100+
.InputMemoryType(OrtMemTypeCPU, 2)
101+
.Build(),
102+
kernel_create_fn);
103+
};
104+
105+
// Always register float and int32_t
106+
ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(float{})));
107+
ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int32_t{})));
108+
109+
// Register int64_t only if int64 support is enabled
110+
if (enable_int64) {
111+
ORT_THROW_IF_ERROR(kernel_registry.Register(create_range_kernel_info(int64_t{})));
112+
}
113+
}
74114

75115
} // namespace webgpu
76116
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/generator/range.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include "core/framework/kernel_registry.h"
67
#include "core/providers/webgpu/webgpu_kernel.h"
78

89
namespace onnxruntime {
@@ -19,13 +20,20 @@ class Range : public WebGpuKernel {
1920
class RangeProgram : public Program<RangeProgram> {
2021
public:
2122
RangeProgram() : Program{"Range"} {}
23+
RangeProgram(int32_t data_type) : Program{"Range"}, data_type_(data_type) {}
2224

2325
Status GenerateShaderCode(ShaderHelper& sh) const override;
2426

2527
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
2628
{"start", ProgramUniformVariableDataType::Uint32},
2729
{"delta", ProgramUniformVariableDataType::Uint32});
30+
31+
private:
32+
int32_t data_type_{0};
2833
};
2934

35+
// Register Range kernels with conditional int64 support
36+
void RegisterRangeKernels(KernelRegistry& kernel_registry, bool enable_int64);
37+
3038
} // namespace webgpu
3139
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/tensor/cast.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const {
110110
}
111111

112112
template <int StartVersion, int EndVersion>
113-
KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture) {
114-
const auto& type_constraints = CastOpTypeConstraints(enable_graph_capture);
113+
KernelCreateInfo CreateCastKernelInfo(bool enable_int64) {
114+
const auto& type_constraints = CastOpTypeConstraints(enable_int64);
115115

116116
KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
117117
out = std::make_unique<Cast>(info);

onnxruntime/core/providers/webgpu/tensor/cast.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class Cast final : public WebGpuKernel {
4040
int32_t to_;
4141
};
4242

43-
// Create Cast kernel info with appropriate type constraints based on graph capture support
43+
// Create Cast kernel info with appropriate type constraints based on int64 support
4444
template <int StartVersion, int EndVersion = StartVersion>
45-
KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture);
45+
KernelCreateInfo CreateCastKernelInfo(bool enable_int64);
4646

4747
} // namespace webgpu
4848
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "core/providers/webgpu/external_data_loader.h"
3030
#include "core/providers/webgpu/webgpu_profiler.h"
3131
#include "core/providers/webgpu/tensor/cast.h"
32+
#include "core/providers/webgpu/generator/range.h"
3233

3334
namespace onnxruntime {
3435

@@ -390,9 +391,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInt
390391
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, InstanceNormalization);
391392
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 22, InstanceNormalization);
392393

393-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range);
394-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range);
395-
396394
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, Einsum);
397395

398396
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad);
@@ -436,7 +434,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
436434
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
437435
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
438436

439-
std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = false) {
437+
std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = false, bool enable_int64 = false) {
440438
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
441439

442440
static const BuildKernelCreateInfoFn function_table[] = {
@@ -746,9 +744,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = fals
746744
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, InstanceNormalization)>,
747745
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 22, InstanceNormalization)>,
748746

749-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range)>,
750-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range)>,
751-
752747
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, Einsum)>,
753748

754749
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
@@ -801,13 +796,16 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = fals
801796
}
802797
}
803798

804-
// Register Cast kernels with conditional int64 support based on graph capture
805-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<6, 8>(enable_graph_capture)));
806-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<9, 12>(enable_graph_capture)));
807-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<13, 18>(enable_graph_capture)));
808-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<19, 20>(enable_graph_capture)));
809-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<21, 22>(enable_graph_capture)));
810-
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<23>(enable_graph_capture)));
799+
// Register Cast kernels with conditional int64 support
800+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<6, 8>(enable_int64)));
801+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<9, 12>(enable_int64)));
802+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<13, 18>(enable_int64)));
803+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<19, 20>(enable_int64)));
804+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<21, 22>(enable_int64)));
805+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<23>(enable_int64)));
806+
807+
// Register Range kernels with conditional int64 support
808+
RegisterRangeKernels(*kernel_registry, enable_int64);
811809

812810
#ifndef DISABLE_CONTRIB_OPS
813811
Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry, enable_graph_capture);
@@ -830,6 +828,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
830828
preferred_data_layout_{config.data_layout},
831829
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
832830
enable_graph_capture_{config.enable_graph_capture},
831+
enable_int64_{config.enable_graph_capture || config.enable_int64},
833832
prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), false)} {
834833
// If graph capture is enabled, create a dedicated buffer manager for graph mode
835834
if (enable_graph_capture_) {
@@ -952,11 +951,16 @@ std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapa
952951
}
953952

954953
std::shared_ptr<KernelRegistry> WebGpuExecutionProvider::GetKernelRegistry() const {
954+
// Cache registries based on enable_graph_capture_ and enable_int64_ flags
955+
// Note: enable_int64_ is always true when enable_graph_capture_ is true
955956
if (enable_graph_capture_) {
956-
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(true);
957+
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(true, true);
958+
return registry;
959+
} else if (enable_int64_) {
960+
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(false, true);
957961
return registry;
958962
} else {
959-
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(false);
963+
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(false, false);
960964
return registry;
961965
}
962966
}

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ struct WebGpuExecutionProviderConfig {
3434
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
3535
bool enable_graph_capture{false}; // graph capture feature is disabled by default
3636
bool enable_pix_capture{false}; // PIX capture is disabled by default
37+
bool enable_int64{false}; // int64 ops are not enabled by default
3738
std::vector<std::string> force_cpu_node_names{};
3839
};
3940

@@ -92,6 +93,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
9293
DataLayout preferred_data_layout_;
9394
std::vector<std::string> force_cpu_node_names_;
9495
bool enable_graph_capture_ = false;
96+
bool enable_int64_ = false;
9597
bool is_graph_captured_ = false;
9698
int regular_run_count_before_graph_capture_ = 0;
9799
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
6161
}
6262
}
6363

64+
std::string enable_int64_str;
65+
if (config_options.TryGetConfigEntry(kEnableInt64, enable_int64_str)) {
66+
if (enable_int64_str == kEnableInt64_ON) {
67+
webgpu_ep_config.enable_int64 = true;
68+
} else if (enable_int64_str == kEnableInt64_OFF) {
69+
webgpu_ep_config.enable_int64 = false;
70+
} else {
71+
ORT_THROW("Invalid enableInt64 value: ", enable_int64_str);
72+
}
73+
}
74+
6475
// parse force CPU node names
6576
// The force CPU node names are separated by EOL (\n or \r\n) in the config entry.
6677
// each line is a node name that will be forced to run on CPU.
@@ -96,6 +107,7 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
96107
LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture;
97108
LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size();
98109
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture;
110+
LOGS_DEFAULT(VERBOSE) << "WebGPU EP enable int64: " << webgpu_ep_config.enable_int64;
99111

100112
return webgpu_ep_config;
101113
}

onnxruntime/core/providers/webgpu/webgpu_provider_options.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace options {
1111

1212
constexpr const char* kPreferredLayout = "ep.webgpuexecutionprovider.preferredLayout";
1313
constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGraphCapture";
14+
constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64";
1415

1516
constexpr const char* kDawnProcTable = "ep.webgpuexecutionprovider.dawnProcTable";
1617

@@ -49,6 +50,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
4950
constexpr const char* kEnableGraphCapture_ON = "1";
5051
constexpr const char* kEnableGraphCapture_OFF = "0";
5152

53+
constexpr const char* kEnableInt64_ON = "1";
54+
constexpr const char* kEnableInt64_OFF = "0";
55+
5256
constexpr const char* kEnablePIXCapture_ON = "1";
5357
constexpr const char* kEnablePIXCapture_OFF = "0";
5458

0 commit comments

Comments
 (0)