fix(wgpu): flush staging buffers periodically during bulk writes#1204
holg wants to merge 3 commits into tracel-ai:main
Conversation
`queue.write_buffer()` is async: wgpu copies the data into an internal staging buffer, then transfers it to the GPU on the next `queue.submit()`. When loading large models with hundreds of tensors, staging buffers accumulate without being submitted and can get recycled before the GPU copy completes, silently corrupting early tensor data.

This PR adds a `pending_write_count` counter to `WgpuStream`. After every 64 consecutive `write_buffer` calls, it submits the current command buffer and polls the device to completion, ensuring staging buffers are freed before reuse.

Discovered while loading a 36-layer transformer (~700 parameter tensors, 11 GB) on wgpu/Metal, where early tensors were silently corrupted, producing garbage inference output.
Pull request overview
This PR addresses silent data corruption on Metal/wgpu during bulk tensor/buffer creation by periodically flushing queued queue.write_buffer() work so wgpu’s internal staging buffers aren’t recycled before GPU copies complete.
Changes:
- Added a `pending_write_count` field to `WgpuStream` to track consecutive `write_buffer` calls without a submission.
- After 64 pending writes, forces a submit and waits for completion via `device.poll(...)` to ensure staging copies finish.
- Resets the pending-write counter on a regular `flush()`.
```rust
// Prevent wgpu staging buffer pool exhaustion during bulk writes (e.g. model
// loading with hundreds of tensors). queue.write_buffer() is async — wgpu
// copies data into an internal staging buffer, then transfers to GPU on the
// next queue.submit(). Without periodic submits, hundreds of writes accumulate
// and staging buffers get recycled before the GPU copy completes, silently
// corrupting early tensors.
// See: https://github.com/tracel-ai/cubecl/issues/1120
const MAX_PENDING_WRITES: usize = 64;

if self.pending_write_count >= MAX_PENDING_WRITES {
```
This change introduces a new correctness/ordering workaround specific to bulk write_buffer usage (including a tuned threshold and an explicit submit+wait). There doesn’t appear to be a regression test covering the failure mode from #1120 (many consecutive tensor/buffer writes with validation via readback). Adding a targeted test (likely under the existing tests_msl / wgpu testgen setup) would help prevent future refactors from reintroducing silent corruption.
```rust
if let Err(e) = self.device.poll(wgpu::PollType::Wait {
    submission_index: Some(index),
    timeout: None,
}) {
    log::warn!("wgpu: write flush poll timed out ({e})");
}
```
device.poll(PollType::Wait { .. }) is called unconditionally in the bulk-write flush path. The rest of the crate explicitly avoids blocking polling on target_family = "wasm" (see compute/poll.rs and the wasm SubmissionLoad impl), so this will likely either error repeatedly or be ineffective on wasm builds. Consider gating the wait behind cfg(not(target_family = "wasm")) and using a non-blocking alternative (or just the queue.submit flush) on wasm, and avoid wording the log as a timeout if the error can be non-timeout (e.g. unsupported wait).
Suggested change:

```diff
-if let Err(e) = self.device.poll(wgpu::PollType::Wait {
-    submission_index: Some(index),
-    timeout: None,
-}) {
-    log::warn!("wgpu: write flush poll timed out ({e})");
-}
+#[cfg(not(target_family = "wasm"))]
+{
+    if let Err(e) = self.device.poll(wgpu::PollType::Wait {
+        submission_index: Some(index),
+        timeout: None,
+    }) {
+        log::warn!("wgpu: write flush poll failed ({e})");
+    }
+}
```
```rust
// Submit the current command buffer to flush all pending write_buffer work.
let encoder = std::mem::replace(
    &mut self.encoder,
    self.device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("CubeCL Write Flush Encoder"),
        }),
);
let index = self.queue.submit([encoder.finish()]);
```
The periodic write flush submits encoder.finish() but doesn’t update the normal submission bookkeeping (tasks_count, submission_load.regulate, mem_manage.memory_cleanup/release_uniforms). If this flush ever happens while there are encoded tasks in the current encoder, the stream’s internal counters can drift from what has actually been submitted, which can cause extra empty submits and mis-regulate GPU waiting. Consider factoring this into a shared “submit encoder” helper that also updates the same accounting as flush() (or ensure this path can only run when tasks_count == 0).
Suggested change:

```diff
-// Submit the current command buffer to flush all pending write_buffer work.
-let encoder = std::mem::replace(
-    &mut self.encoder,
-    self.device
-        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
-            label: Some("CubeCL Write Flush Encoder"),
-        }),
-);
-let index = self.queue.submit([encoder.finish()]);
+// Submit a fresh, empty command buffer to flush all pending write_buffer work.
+// wgpu flushes its internal staging-buffer copies on any queue.submit(),
+// so we don't need to touch the main compute encoder here.
+let write_flush_encoder = self
+    .device
+    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+        label: Some("CubeCL Write Flush Encoder"),
+    });
+let index = self.queue.submit([write_flush_encoder.finish()]);
```
- Use a separate empty encoder for the write flush instead of swapping the main compute encoder. This avoids desync between `tasks_count` / `submission_load` bookkeeping and the actual encoder state.
- Gate `device.poll(Wait)` behind `#[cfg(not(target_family = "wasm"))]` to match the existing `SubmissionLoad` pattern. On wasm, `queue.submit()` alone still flushes staging buffers.
Closes #1120
Summary

`queue.write_buffer()` is async: wgpu copies data into an internal staging buffer, then transfers it to the GPU on the next `queue.submit()`. Without periodic submits, hundreds of writes accumulate and staging buffers get recycled before the GPU copy completes, silently corrupting early tensors.

This PR adds a `pending_write_count` counter to `WgpuStream`. After every 64 consecutive `write_buffer` calls, the stream submits an empty command buffer to flush all pending staging work and waits for completion.
Review feedback addressed

- The write flush now uses a separate empty encoder. This avoids desync between `tasks_count` / `submission_load` bookkeeping and the actual encoder state.
- `device.poll(Wait)` is gated behind `#[cfg(not(target_family = "wasm"))]` to match the existing `SubmissionLoad` pattern. On wasm, `queue.submit()` alone still flushes staging buffers.
Test plan

- `cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary`: 22/22 pass (no regressions)