@@ -255,7 +255,14 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
255255 uint32_t x = program.DispatchGroupSizeX ();
256256 uint32_t y = program.DispatchGroupSizeY ();
257257 uint32_t z = program.DispatchGroupSizeZ ();
258- ORT_RETURN_IF_ERROR (program_mgr_->NormalizeDispatchGroupSize (x, y, z));
258+
259+ // Skip normalization for indirect dispatch since dimensions are determined by the indirect buffer
260+ if (program.IndirectDispatchTensor () == nullptr ) {
261+ ORT_RETURN_IF_ERROR (program_mgr_->NormalizeDispatchGroupSize (x, y, z));
262+ } else {
263+ ORT_ENFORCE (x == 0 && y == 0 && z == 0 ,
264+ " Only one of SetIndirectDispatchTensor and SetDispatchGroupSize should be called for program" , program.Name ());
265+ }
259266
260267 bool is_1d_dispatch = (y == 1 && z == 1 );
261268
@@ -442,7 +449,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
442449 bind_buffers.push_back (uniform_buffer);
443450 }
444451
445- LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
452+ LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z, program. IndirectDispatchTensor () );
446453 if (uniform_buffer) {
447454 buffer_mgr.Release (uniform_buffer);
448455 }
@@ -722,7 +729,8 @@ void WebGpuContext::OnRunEnd() {
722729void WebGpuContext::LaunchComputePipeline (const wgpu::ComputePassEncoder& compute_pass_encoder,
723730 const std::vector<WGPUBuffer>& bind_buffers,
724731 const ProgramArtifact& program_artifact,
725- uint32_t x, uint32_t y, uint32_t z) {
732+ uint32_t x, uint32_t y, uint32_t z,
733+ const Tensor* indirect_dispatch_tensor) {
726734 uint32_t entry_index = 0 ;
727735 std::vector<WGPUBindGroupEntry> bind_group_entries;
728736 for (WGPUBuffer buffer : bind_buffers) {
@@ -738,14 +746,27 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
738746
739747 auto bind_group = wgpuDeviceCreateBindGroup (Device ().Get (), &bind_group_desc);
740748 if (graph_capture_state_ == GraphCaptureState::Capturing) {
749+ WGPUBuffer indirect_buffer = nullptr ;
750+ if (indirect_dispatch_tensor != nullptr ) {
751+ indirect_buffer = reinterpret_cast <WGPUBuffer>(const_cast <void *>(indirect_dispatch_tensor->DataRaw ()));
752+ }
741753 external_captured_commands_->push_back ({program_artifact.compute_pipeline ,
742754 bind_group,
743755 bind_group_layout,
744- {x, y, z}});
756+ {x, y, z},
757+ indirect_buffer});
745758 } else {
746759 compute_pass_encoder.SetPipeline (program_artifact.compute_pipeline );
747760 wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , bind_group, 0 , nullptr );
748- compute_pass_encoder.DispatchWorkgroups (x, y, z);
761+
762+ if (indirect_dispatch_tensor != nullptr ) {
763+ // Use indirect dispatch
764+ WGPUBuffer indirect_buffer = reinterpret_cast <WGPUBuffer>(const_cast <void *>(indirect_dispatch_tensor->DataRaw ()));
765+ compute_pass_encoder.DispatchWorkgroupsIndirect (indirect_buffer, 0 );
766+ } else {
767+ // Use direct dispatch
768+ compute_pass_encoder.DispatchWorkgroups (x, y, z);
769+ }
749770
750771 wgpuBindGroupRelease (bind_group);
751772 wgpuBindGroupLayoutRelease (bind_group_layout);
@@ -781,7 +802,15 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
781802 WriteTimestamp (num_pending_dispatches_ * 2 );
782803 compute_pass_encoder.SetPipeline (command.compute_pipeline );
783804 wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , command.bind_group , 0 , nullptr );
784- compute_pass_encoder.DispatchWorkgroups (command.dispatch_group [0 ], command.dispatch_group [1 ], command.dispatch_group [2 ]);
805+
806+ if (command.indirect_buffer != nullptr ) {
807+ // Use indirect dispatch
808+ compute_pass_encoder.DispatchWorkgroupsIndirect (command.indirect_buffer , 0 );
809+ } else {
810+ // Use direct dispatch
811+ compute_pass_encoder.DispatchWorkgroups (command.dispatch_group [0 ], command.dispatch_group [1 ], command.dispatch_group [2 ]);
812+ }
813+
785814 WriteTimestamp (num_pending_dispatches_ * 2 + 1 );
786815 ++num_pending_dispatches_;
787816 if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||
0 commit comments