
Conversation

@qjia7 (Contributor) commented on Jan 22, 2026

No description provided.

Copilot AI (Contributor) left a comment
Pull request overview

Adds WebGPU support for updating position_ids via the DeviceInterface::UpdatePositionIds hook, enabling device-side updates during continuous decoding without always falling back to the CPU path.

Changes:

  • Implement InterfaceImpl::UpdatePositionIds for the WebGPU device interface.
  • Generate position_ids on CPU and upload to WebGPU device memory via ORT CopyTensors.
  • Return false for non-continuous decoding (batch_beam_size != 1) to trigger the existing CPU fallback (see the sketch after this list).
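
For orientation, here is a minimal sketch of the control flow the overview describes. The parameter list and names are illustrative assumptions, not the PR's code; the real hook signature is whatever DeviceInterface declares.

bool InterfaceImpl::UpdatePositionIds(void* position_ids, int batch_beam_size,
                                      int total_length, int new_kv_length,
                                      ONNXTensorElementDataType type) {
  // Only continuous decoding with a single sequence is handled on device.
  // Returning false routes every other case through the existing CPU fallback.
  if (batch_beam_size != 1)
    return false;

  // Generate the new position_ids on CPU, wrap them in OrtValues, and upload
  // to WebGPU device memory with Ort::api->CopyTensors (see the quoted
  // snippets below for the actual PR code).
  return true;
}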

Comment on lines +222 to +227
// Get WebGPU allocator's memory info
const OrtMemoryInfo* webgpu_mem_info = nullptr;
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));

// Create CPU memory info
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

Copilot AI commented on Jan 27, 2026


UpdatePositionIds is called once per decode step (often with new_kv_length == 1), but this implementation allocates a new std::vector and fetches allocator info on every call, adding avoidable CPU overhead in the hot loop. Consider special-casing new_kv_length == 1 to use a small stack buffer (or std::array<..., 1>) and caching the OrtMemoryInfo* for the WebGPU allocator after InitOrt() (if its lifetime is stable), so the per-token path avoids heap allocation and repeated AllocatorGetInfo calls.
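
A sketch of the suggested fast path, reusing the calls already present in the PR. The cpu_mem_info_ and webgpu_mem_info_ members are hypothetical names for memory infos cached once after InitOrt(); they are not in the PR.

if (new_kv_length == 1 && type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
  // Single-token decode: a stack variable replaces the per-call std::vector,
  // and the cached memory infos avoid repeated AllocatorGetInfo calls.
  int32_t value = static_cast<int32_t>(start);
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info_, &value, sizeof(value), shape, type);
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info_, position_ids, sizeof(value), shape, type);
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
  return true;
}

The same shape works for the int64 case; the point is that the per-token path performs no heap allocation.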

Comment on lines +233 to +267
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
  // Generate int32 position_ids on CPU
  std::vector<int32_t> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++) {
    cpu_data[i] = static_cast<int32_t>(start + i);
  }

  // Create source tensor (CPU memory)
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int32_t), shape, type);

  // Create destination tensor (WebGPU device memory)
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int32_t), shape, type);

  // Copy from CPU to GPU using CopyTensors
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
} else {
  // Generate int64 position_ids on CPU
  std::vector<int64_t> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++) {
    cpu_data[i] = static_cast<int64_t>(start + i);
  }

  // Create source tensor (CPU memory)
  auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int64_t), shape, type);

  // Create destination tensor (WebGPU device memory)
  auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int64_t), shape, type);

  // Copy from CPU to GPU using CopyTensors
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
}

Copilot AI commented on Jan 27, 2026


The int32/int64 branches here duplicate the same flow (fill CPU buffer -> wrap src/dst tensors -> CopyTensors) and differ only by element type/size. Refactoring to a small templated helper (or using Ort::SizeOf(type) plus a typed fill) would reduce duplication and make future changes less error-prone.
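
One possible shape for that helper, folding the two branches into a single template. This is illustrative only: the helper name is invented, and the parameter types (in particular the shape span and start) are assumed to match the surrounding code.

template <typename T>
static void FillAndUploadPositionIds(const OrtMemoryInfo& cpu_mem_info,
                                     const OrtMemoryInfo& webgpu_mem_info,
                                     void* position_ids, int start, int new_kv_length,
                                     std::span<const int64_t> shape,
                                     ONNXTensorElementDataType type) {
  // Fill a CPU staging buffer with [start, start + new_kv_length).
  std::vector<T> cpu_data(new_kv_length);
  for (int i = 0; i < new_kv_length; i++)
    cpu_data[i] = static_cast<T>(start + i);

  // Wrap the source (CPU) and destination (WebGPU) buffers, then copy.
  auto src_tensor = OrtValue::CreateTensor(cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(T), shape, type);
  auto dst_tensor = OrtValue::CreateTensor(webgpu_mem_info, position_ids, new_kv_length * sizeof(T), shape, type);
  OrtValue* src_ptrs[] = {src_tensor.get()};
  OrtValue* dst_ptrs[] = {dst_tensor.get()};
  Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
}

The call site then collapses to a two-way dispatch on the element type:

if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  FillAndUploadPositionIds<int32_t>(*cpu_mem_info, *webgpu_mem_info, position_ids, start, new_kv_length, shape, type);
else
  FillAndUploadPositionIds<int64_t>(*cpu_mem_info, *webgpu_mem_info, position_ids, start, new_kv_length, shape, type);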

@qjia7 marked this pull request as draft on January 27, 2026 at 09:16.
