Skip to content

Commit 954bb7b

Browse files
authored
[WebGPU] allow async shader compilation (#25941)
### Description
Reduce the time spent blocked while waiting for shaders to be compiled.

### Motivation and Context
Improve the responsiveness of the application when running ort-web on the main thread. See #25882.
1 parent 6e19cbd commit 954bb7b

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

onnxruntime/core/providers/webgpu/program_manager.cc

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "core/providers/webgpu/program_manager.h"
1212
#include "core/providers/webgpu/shader_helper.h"
13+
#include "core/providers/webgpu/webgpu_context.h"
1314

1415
namespace onnxruntime {
1516
namespace webgpu {
@@ -22,7 +23,7 @@ ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeli
2223
Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const {
2324
ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")");
2425

25-
auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension;
26+
auto limit_per_dimension = webgpu_context_.DeviceLimits().maxComputeWorkgroupsPerDimension;
2627
if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) {
2728
double size = static_cast<double>(x) * static_cast<double>(y) * static_cast<double>(z);
2829
double dispatch_avg = std::ceil(std::sqrt(size));
@@ -39,7 +40,7 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint
3940
}
4041

4142
Status ProgramManager::CalculateSegmentsForInputsAndOutputs(ProgramBase& program) {
42-
const uint64_t maxStorageBufferBindingSize = limits_.maxStorageBufferBindingSize;
43+
const uint64_t maxStorageBufferBindingSize = webgpu_context_.DeviceLimits().maxStorageBufferBindingSize;
4344

4445
// Inputs
4546
for (size_t i = 0; i < program.Inputs().size(); ++i) {
@@ -70,10 +71,11 @@ Status ProgramManager::Build(const ProgramBase& program,
7071
uint32_t normalized_dispatch_z,
7172
wgpu::ComputePipeline& compute_pipeline,
7273
std::vector<int>& shape_uniform_ranks) const {
74+
auto& device = webgpu_context_.Device();
7375
ShaderHelper shader_helper{program,
7476
program_metadata,
75-
device_,
76-
limits_,
77+
device,
78+
webgpu_context_.DeviceLimits(),
7779
normalized_dispatch_x,
7880
normalized_dispatch_y,
7981
normalized_dispatch_z};
@@ -110,7 +112,7 @@ Status ProgramManager::Build(const ProgramBase& program,
110112
wgpu::ShaderModuleDescriptor descriptor{};
111113
descriptor.nextInChain = &wgsl_source;
112114

113-
auto shader_module = device_.CreateShaderModule(&descriptor);
115+
auto shader_module = device.CreateShaderModule(&descriptor);
114116

115117
// TODO: a new cache hierarchy for constants.
116118
//
@@ -186,9 +188,26 @@ Status ProgramManager::Build(const ProgramBase& program,
186188
pipeline_descriptor.label = program.Name().c_str();
187189
#endif
188190

189-
compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor);
190-
191-
return Status();
191+
struct CreateComputePipelineContext {
192+
wgpu::ComputePipeline& pipeline;
193+
Status status;
194+
} create_pipeline_context{compute_pipeline, {}};
195+
196+
ORT_RETURN_IF_ERROR(
197+
webgpu_context_.Wait(
198+
device.CreateComputePipelineAsync(
199+
&pipeline_descriptor,
200+
wgpu::CallbackMode::WaitAnyOnly,
201+
[](wgpu::CreatePipelineAsyncStatus status, wgpu::ComputePipeline pipeline, wgpu::StringView message, CreateComputePipelineContext* context) {
202+
if (status == wgpu::CreatePipelineAsyncStatus::Success) {
203+
context->pipeline = std::move(pipeline);
204+
} else {
205+
context->status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create a WebGPU compute pipeline: ", std::string_view{message});
206+
}
207+
},
208+
&create_pipeline_context)));
209+
210+
return create_pipeline_context.status;
192211
}
193212

194213
const ProgramArtifact* ProgramManager::Get(const std::string& key) const {

onnxruntime/core/providers/webgpu/program_manager.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace onnxruntime {
1616
class Tensor;
1717

1818
namespace webgpu {
19+
class WebGpuContext;
1920

2021
class ProgramArtifact {
2122
public:
@@ -34,7 +35,7 @@ class ProgramArtifact {
3435

3536
class ProgramManager {
3637
public:
37-
ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {}
38+
ProgramManager(WebGpuContext& webgpu_context) : webgpu_context_(webgpu_context) {}
3839

3940
Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const;
4041
Status CalculateSegmentsForInputsAndOutputs(ProgramBase& program);
@@ -54,8 +55,7 @@ class ProgramManager {
5455

5556
private:
5657
std::unordered_map<std::string, ProgramArtifact> programs_;
57-
const wgpu::Device& device_;
58-
const wgpu::Limits& limits_;
58+
WebGpuContext& webgpu_context_;
5959
};
6060

6161
} // namespace webgpu

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
145145
BufferCacheMode::Disabled);
146146

147147
// create program manager
148-
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
148+
program_mgr_ = std::make_unique<ProgramManager>(*this);
149149

150150
// set query type
151151
#if !defined(__wasm__)

0 commit comments

Comments
 (0)