Skip to content

Commit 8474f5f

Browse files
committed
Support UpdatePositionIds in gpu
1 parent 6cf92ae commit 8474f5f

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

src/webgpu/interface.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,67 @@ struct InterfaceImpl : DeviceInterface {
176176
// Creates the beam-search strategy for `params`. This interface constructs
// the CPU implementation (BeamSearch_Cpu) and returns it as a Search.
std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override {
  return std::make_unique<BeamSearch_Cpu>(params);
}
177177

178178
// Currently a no-op.
// NOTE(review): the original author was unsure whether any work is required
// here ("Nothing to do?") — confirm before relying on this as a barrier.
void Synchronize() override {
}
179+
180+
bool UpdatePositionIds(void* position_ids, int batch_beam_size, int total_length, int new_kv_length, ONNXTensorElementDataType type) override {
181+
if (!ort_allocator_) {
182+
throw std::runtime_error("WebGPU allocator not initialized");
183+
}
184+
185+
// Only support continuous decoding mode (batch_beam_size == 1)
186+
// For batch mode, fall back to CPU implementation
187+
if (batch_beam_size != 1) {
188+
return false;
189+
}
190+
191+
// Get WebGPU allocator's memory info
192+
const OrtMemoryInfo* webgpu_mem_info = nullptr;
193+
Ort::ThrowOnError(Ort::api->AllocatorGetInfo(ort_allocator_, &webgpu_mem_info));
194+
195+
// Create CPU memory info
196+
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
197+
198+
// Compute position_ids on CPU: position_ids[i] = start + i
199+
int start = total_length - new_kv_length;
200+
std::array<int64_t, 1> shape{static_cast<int64_t>(new_kv_length)};
201+
202+
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
203+
// Generate int32 position_ids on CPU
204+
std::vector<int32_t> cpu_data(new_kv_length);
205+
for (int i = 0; i < new_kv_length; i++) {
206+
cpu_data[i] = static_cast<int32_t>(start + i);
207+
}
208+
209+
// Create source tensor (CPU memory)
210+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int32_t), shape, type);
211+
212+
// Create destination tensor (WebGPU device memory)
213+
auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int32_t), shape, type);
214+
215+
// Copy from CPU to GPU using CopyTensors
216+
OrtValue* src_ptrs[] = {src_tensor.get()};
217+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
218+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
219+
} else {
220+
// Generate int64 position_ids on CPU
221+
std::vector<int64_t> cpu_data(new_kv_length);
222+
for (int i = 0; i < new_kv_length; i++) {
223+
cpu_data[i] = static_cast<int64_t>(start + i);
224+
}
225+
226+
// Create source tensor (CPU memory)
227+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_data.data(), new_kv_length * sizeof(int64_t), shape, type);
228+
229+
// Create destination tensor (WebGPU device memory)
230+
auto dst_tensor = OrtValue::CreateTensor(*webgpu_mem_info, position_ids, new_kv_length * sizeof(int64_t), shape, type);
231+
232+
// Copy from CPU to GPU using CopyTensors
233+
OrtValue* src_ptrs[] = {src_tensor.get()};
234+
OrtValue* dst_ptrs[] = {dst_tensor.get()};
235+
Ort::ThrowOnError(Ort::api->CopyTensors(&GetOrtEnv(), src_ptrs, dst_ptrs, nullptr, 1));
236+
}
237+
238+
return true;
239+
}
179240
};
180241

181242
} // namespace WebGPU

0 commit comments

Comments
 (0)