1010
1111#include " core/providers/webgpu/program_manager.h"
1212#include " core/providers/webgpu/shader_helper.h"
13+ #include " core/providers/webgpu/webgpu_context.h"
1314
1415namespace onnxruntime {
1516namespace webgpu {
@@ -22,7 +23,7 @@ ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeli
2223Status ProgramManager::NormalizeDispatchGroupSize (uint32_t & x, uint32_t & y, uint32_t & z) const {
2324 ORT_RETURN_IF (x == 0 || y == 0 || z == 0 , " Invalid dispatch group size (" , x, " , " , y, " , " , z, " )" );
2425
25- auto limit_per_dimension = limits_ .maxComputeWorkgroupsPerDimension ;
26+ auto limit_per_dimension = webgpu_context_. DeviceLimits () .maxComputeWorkgroupsPerDimension ;
2627 if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) {
2728 double size = static_cast <double >(x) * static_cast <double >(y) * static_cast <double >(z);
2829 double dispatch_avg = std::ceil (std::sqrt (size));
@@ -39,7 +40,7 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint
3940}
4041
4142Status ProgramManager::CalculateSegmentsForInputsAndOutputs (ProgramBase& program) {
42- const uint64_t maxStorageBufferBindingSize = limits_ .maxStorageBufferBindingSize ;
43+ const uint64_t maxStorageBufferBindingSize = webgpu_context_. DeviceLimits () .maxStorageBufferBindingSize ;
4344
4445 // Inputs
4546 for (size_t i = 0 ; i < program.Inputs ().size (); ++i) {
@@ -70,10 +71,11 @@ Status ProgramManager::Build(const ProgramBase& program,
7071 uint32_t normalized_dispatch_z,
7172 wgpu::ComputePipeline& compute_pipeline,
7273 std::vector<int >& shape_uniform_ranks) const {
74+ auto & device = webgpu_context_.Device ();
7375 ShaderHelper shader_helper{program,
7476 program_metadata,
75- device_ ,
76- limits_ ,
77+ device ,
78+ webgpu_context_. DeviceLimits () ,
7779 normalized_dispatch_x,
7880 normalized_dispatch_y,
7981 normalized_dispatch_z};
@@ -110,7 +112,7 @@ Status ProgramManager::Build(const ProgramBase& program,
110112 wgpu::ShaderModuleDescriptor descriptor{};
111113 descriptor.nextInChain = &wgsl_source;
112114
113- auto shader_module = device_ .CreateShaderModule (&descriptor);
115+ auto shader_module = device .CreateShaderModule (&descriptor);
114116
115117 // TODO: a new cache hierarchy for constants.
116118 //
@@ -186,9 +188,26 @@ Status ProgramManager::Build(const ProgramBase& program,
186188 pipeline_descriptor.label = program.Name ().c_str ();
187189#endif
188190
189- compute_pipeline = device_.CreateComputePipeline (&pipeline_descriptor);
190-
191- return Status ();
191+ struct CreateComputePipelineContext {
192+ wgpu::ComputePipeline& pipeline;
193+ Status status;
194+ } create_pipeline_context{compute_pipeline, {}};
195+
196+ ORT_RETURN_IF_ERROR (
197+ webgpu_context_.Wait (
198+ device.CreateComputePipelineAsync (
199+ &pipeline_descriptor,
200+ wgpu::CallbackMode::WaitAnyOnly,
201+ [](wgpu::CreatePipelineAsyncStatus status, wgpu::ComputePipeline pipeline, wgpu::StringView message, CreateComputePipelineContext* context) {
202+ if (status == wgpu::CreatePipelineAsyncStatus::Success) {
203+ context->pipeline = std::move (pipeline);
204+ } else {
205+ context->status = ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Failed to create a WebGPU compute pipeline: " , std::string_view{message});
206+ }
207+ },
208+ &create_pipeline_context)));
209+
210+ return create_pipeline_context.status ;
192211}
193212
194213const ProgramArtifact* ProgramManager::Get (const std::string& key) const {
0 commit comments