@@ -176,6 +176,73 @@ struct InterfaceImpl : DeviceInterface {
   std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }

   void Synchronize() override {}  // Nothing to do?
+
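+  // Stages position_ids on the host and copies them into device memory.
+  // Returns true when handled here; returning false tells the caller to
+  // fall back to the CPU implementation.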
+  bool UpdatePositionIds(void* position_ids, int batch_beam_size, int total_length, int new_kv_length, ONNXTensorElementDataType type) override {
+    if (!ort_allocator_) {
+      throw std::runtime_error("WebGPU allocator not initialized");
+    }
+
+    // Only support continuous decoding mode (batch_beam_size == 1).
+    // For batch mode, fall back to the CPU implementation.
+    if (batch_beam_size != 1) {
+      return false;
+    }
+
+    // Get the WebGPU allocator's memory info.
+    const OrtMemoryInfo* webgpu_mem_info = nullptr;
+    Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
+
+    // Create CPU memory info.
+    auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    // Compute position_ids on the CPU: position_ids[i] = start + i.
+    int start = total_length - new_kv_length;
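+    // e.g. total_length = 10, new_kv_length = 3 -> start = 7 -> ids {7, 8, 9}.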
+    std::array<int64_t, 1> shape{static_cast<int64_t>(new_kv_length)};
+
+    if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+      // Generate int32 position_ids on the CPU.
+      std::vector<int32_t> cpu_data(new_kv_length);
+      for (int i = 0; i < new_kv_length; i++) {
+        cpu_data[i] = static_cast<int32_t>(start + i);
+      }
+
+      // Create the source tensor (CPU memory).
+      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int32_t), shape, type);
+
+      // Create the destination tensor (WebGPU device memory).
+      auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int32_t), shape, type);
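+      // The destination tensor wraps the caller's pre-allocated device buffer, so no new allocation is made.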
+
+      // Copy from CPU to GPU using CopyTensors.
+      OrtValue* src_ptrs[] = {src_tensor.get()};
+      OrtValue* dst_ptrs[] = {dst_tensor.get()};
+      Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
+    } else {
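+      // Same flow as the int32 branch, but with 64-bit elements.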
+      // Generate int64 position_ids on the CPU.
+      std::vector<int64_t> cpu_data(new_kv_length);
+      for (int i = 0; i < new_kv_length; i++) {
+        cpu_data[i] = static_cast<int64_t>(start + i);
+      }
+
+      // Create the source tensor (CPU memory).
+      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int64_t), shape, type);
+
+      // Create the destination tensor (WebGPU device memory).
+      auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int64_t), shape, type);
+
+      // Copy from CPU to GPU using CopyTensors.
+      OrtValue* src_ptrs[] = {src_tensor.get()};
+      OrtValue* dst_ptrs[] = {dst_tensor.get()};
+      Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
+    }
+
+    return true;
+  }
 };

 } // namespace WebGPU
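
For reference, a minimal standalone sketch of the host-side computation the new override stages before the device copy (the MakePositionIds name is illustrative, not part of this change):

    #include <cstdint>
    #include <vector>

    // position_ids[i] = (total_length - new_kv_length) + i, matching the loops above.
    std::vector<int64_t> MakePositionIds(int total_length, int new_kv_length) {
      std::vector<int64_t> ids(static_cast<size_t>(new_kv_length));
      const int start = total_length - new_kv_length;
      for (int i = 0; i < new_kv_length; ++i) {
        ids[static_cast<size_t>(i)] = start + i;
      }
      return ids;  // e.g. (8, 1) -> {7}: the single new token in continuous decoding
    }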