Support UpdatePositionIds in webgpu's interface #1952
Conversation
Pull request overview
Adds WebGPU support for updating position_ids via the DeviceInterface::UpdatePositionIds hook, enabling device-side updates during continuous decoding without always falling back to the CPU path.
Changes:
- Implement InterfaceImpl::UpdatePositionIds for the WebGPU device interface (control flow sketched below).
- Generate position_ids on CPU and upload to WebGPU device memory via ORT CopyTensors.
- Return false for non-continuous decoding (batch_beam_size != 1) to trigger the existing CPU fallback.
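Taken together, the hook is expected to look roughly like this. This is a sketch only: the exact signature and the derivation of start are assumptions inferred from the diff excerpts below, not quoted from the patch.

```cpp
// Sketch only: signature and the computation of `start` are assumptions
// inferred from the review excerpts below.
bool InterfaceImpl::UpdatePositionIds(void* position_ids, int batch_beam_size,
                                      int total_length, int new_kv_length,
                                      ONNXTensorElementDataType type) {
  if (batch_beam_size != 1)
    return false;  // non-continuous decoding: caller falls back to the CPU path

  const int64_t start = total_length - new_kv_length;  // first new position id
  std::array<int64_t, 1> shape{new_kv_length};         // 1-D tensor of new ids
  // Fill a CPU staging buffer with [start, start + new_kv_length) and copy it
  // into position_ids (WebGPU device memory) via CopyTensors; see the diff
  // excerpts in the review comments below.
  return true;
}
```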
```cpp
// Get WebGPU allocator's memory info
const OrtMemoryInfo* webgpu_mem_info = nullptr;
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));

// Create CPU memory info
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
```
Copilot AI commented on Jan 27, 2026:
UpdatePositionIds is called once per decode step (often with new_kv_length == 1), but this implementation allocates a new std::vector and fetches allocator info on every call. This adds avoidable CPU overhead in the hot loop. Consider special-casing new_kv_length == 1 to use a small stack buffer (or std::array<..., 1>) and caching the OrtMemoryInfo* for the WebGPU allocator after InitOrt() (if its lifetime is stable), so the per-token path avoids heap allocation and repeated AllocatorGetInfo calls.
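A sketch of that suggestion, assuming the allocator's OrtMemoryInfo* stays valid for the interface's lifetime; the cached member names are hypothetical, and only the int64 branch is shown:

```cpp
// Hypothetical members, cached once after InitOrt() instead of per call:
//   const OrtMemoryInfo* webgpu_mem_info_ = nullptr;   // via AllocatorGetInfo
//   std::unique_ptr<OrtMemoryInfo> cpu_mem_info_;      // via CreateCpu

if (new_kv_length == 1) {
  // Per-token fast path: one element on the stack, no heap allocation.
  // (int64 case shown; the int32 case is analogous.)
  std::array<int64_t, 1> cpu_data{static_cast<int64_t>(start)};
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info_, cpu_data.data(),
                                           sizeof(int64_t), shape, type);
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info_, position_ids,
                                           sizeof(int64_t), shape, type);
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
  return true;
}
```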
```cpp
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
  // Generate int32 position_ids on CPU
  std::vector<int32_t> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++) {
    cpu_data[i] = static_cast<int32_t>(start + i);
  }

  // Create source tensor (CPU memory)
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int32_t), shape, type);

  // Create destination tensor (WebGPU device memory)
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int32_t), shape, type);

  // Copy from CPU to GPU using CopyTensors
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
} else {
  // Generate int64 position_ids on CPU
  std::vector<int64_t> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++) {
    cpu_data[i] = static_cast<int64_t>(start + i);
  }

  // Create source tensor (CPU memory)
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int64_t), shape, type);

  // Create destination tensor (WebGPU device memory)
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int64_t), shape, type);

  // Copy from CPU to GPU using CopyTensors
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
}
```
Copilot AI commented on Jan 27, 2026:
The int32/int64 branches here duplicate the same flow (fill CPU buffer -> wrap src/dst tensors -> CopyTensors) and differ only by element type/size. Refactoring to a small templated helper (or using Ort::SizeOf(type) plus a typed fill) would reduce duplication and make future changes less error-prone.
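One possible shape for that helper. The name, the exact parameter list, and the std::span shape type are illustrative assumptions, not part of the PR:

```cpp
// Illustrative helper: fills a CPU staging buffer with sequential position ids
// and copies it into WebGPU device memory. Name and signature are hypothetical.
template <typename T>
void CopyPositionIdsToDevice(void* position_ids, int64_t start, int new_kv_length,
                             const OrtMemoryInfo& cpu_mem_info,
                             const OrtMemoryInfo& webgpu_mem_info,
                             std::span<const int64_t> shape,
                             ONNXTensorElementDataType type) {
  std::vector<T> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++)
    cpu_data[i] = static_cast<T>(start + i);

  auto src_tensor = OrtValue::CreateTensor(cpu_mem_info, cpu_data.data(),
                                           new_kv_length * sizeof(T), shape, type);
  auto dst_tensor = OrtValue::CreateTensor(webgpu_mem_info, position_ids,
                                           new_kv_length * sizeof(T), shape, type);
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
}
```

The two branches would then collapse to a single dispatch on the element type:

```cpp
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  CopyPositionIdsToDevice<int32_t>(position_ids, start, new_kv_length,
                                   *cpu_mem_info, *webgpu_mem_info, shape, type);
else
  CopyPositionIdsToDevice<int64_t>(position_ids, start, new_kv_length,
                                   *cpu_mem_info, *webgpu_mem_info, shape, type);
```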