
fix(wgpu): flush staging buffers periodically during bulk writes#1204

Open
holg wants to merge 3 commits into tracel-ai:main from holg:fix/wgpu-write-buffer-flush

Conversation


@holg holg commented Feb 28, 2026

Closes #1120

Summary

queue.write_buffer() is async — wgpu copies data into an internal staging buffer, then
transfers to GPU on the next queue.submit(). Without periodic submits, hundreds of writes
accumulate and staging buffers get recycled before the GPU copy completes, silently corrupting
early tensors.

This PR adds a pending_write_count counter to WgpuStream. After every 64 consecutive
write_buffer calls, the stream submits an empty command buffer to flush all pending staging
work and waits for completion.
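The threshold logic described above can be sketched without wgpu as a plain counter. This is a minimal, hypothetical model: `MockStream` stands in for `WgpuStream`, and the `flushes` counter stands in for the empty `queue.submit()` plus `device.poll(Wait)` the real stream performs.

```rust
// Minimal sketch of the pending-write counter described in the PR.
// `MockStream` is a hypothetical stand-in for `WgpuStream`; `flushes`
// stands in for queue.submit() + device.poll(Wait) in the real code.
const MAX_PENDING_WRITES: usize = 64;

struct MockStream {
    pending_write_count: usize,
    flushes: usize,
}

impl MockStream {
    fn write_buffer(&mut self) {
        self.pending_write_count += 1;
        if self.pending_write_count >= MAX_PENDING_WRITES {
            // Real stream: submit an empty command buffer and wait, so
            // wgpu's staging buffers are flushed before they are reused.
            self.flushes += 1;
            self.pending_write_count = 0;
        }
    }
}

fn main() {
    let mut s = MockStream { pending_write_count: 0, flushes: 0 };
    for _ in 0..200 {
        s.write_buffer();
    }
    // 200 writes with a threshold of 64 => flushes at writes 64, 128, 192.
    println!("{} {}", s.flushes, s.pending_write_count); // 3 8
}
```

With 700+ tensor writes during model loading, this bounds the number of outstanding staging buffers to 64 at any time.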

Review feedback addressed

  • Uses a separate empty encoder for the write flush instead of swapping the main compute
    encoder. This avoids desync between tasks_count / submission_load bookkeeping and the
    actual encoder state.
  • device.poll(Wait) is gated behind #[cfg(not(target_family = "wasm"))] to match the
    existing SubmissionLoad pattern. On wasm, queue.submit() alone still flushes staging
    buffers.

Test plan

  • cargo test -p cubecl-wgpu --features msl --lib -- tests_msl::f32_ty::unary — 22/22
    pass (no regressions)
  • End-to-end model loading with 700+ tensors on wgpu/Metal — no corruption

queue.write_buffer() is async — wgpu copies data into an internal
staging buffer, then transfers to GPU on the next queue.submit().
When loading large models with hundreds of tensors, staging buffers
accumulate without being submitted, and can get recycled before the
GPU copy completes, silently corrupting early tensor data.

Add a pending_write_count counter to WgpuStream. After every 64
consecutive write_buffer calls, submit the current command buffer
and poll the device to completion. This ensures staging buffers
are freed before reuse.

Discovered while loading a 36-layer transformer (~700 parameter
tensors, 11GB) on wgpu/Metal where early tensors were silently
corrupted, producing garbage inference output.
Copilot AI review requested due to automatic review settings February 28, 2026 14:07

Copilot AI left a comment


Pull request overview

This PR addresses silent data corruption on Metal/wgpu during bulk tensor/buffer creation by periodically flushing queued queue.write_buffer() work so wgpu’s internal staging buffers aren’t recycled before GPU copies complete.

Changes:

  • Added pending_write_count to WgpuStream to track consecutive write_buffer calls without a submission.
  • After 64 pending writes, forces a submit and waits for completion via device.poll(...) to ensure staging copies finish.
  • Resets the pending-write counter on regular flush().


Comment on lines +344 to +353
// Prevent wgpu staging buffer pool exhaustion during bulk writes (e.g. model
// loading with hundreds of tensors). queue.write_buffer() is async — wgpu
// copies data into an internal staging buffer, then transfers to GPU on the
// next queue.submit(). Without periodic submits, hundreds of writes accumulate
// and staging buffers get recycled before the GPU copy completes, silently
// corrupting early tensors.
// See: https://github.com/tracel-ai/cubecl/issues/1120
const MAX_PENDING_WRITES: usize = 64;

if self.pending_write_count >= MAX_PENDING_WRITES {

Copilot AI Feb 28, 2026


This change introduces a new correctness/ordering workaround specific to bulk write_buffer usage (including a tuned threshold and an explicit submit+wait). There doesn’t appear to be a regression test covering the failure mode from #1120 (many consecutive tensor/buffer writes with validation via readback). Adding a targeted test (likely under the existing tests_msl / wgpu testgen setup) would help prevent future refactors from reintroducing silent corruption.
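As a wgpu-free illustration of the failure mode such a regression test would need to cover, here is a hypothetical model: a fixed pool of staging slots that are recycled on reuse. The pool size, struct names, and flush semantics are assumptions for illustration, not the real wgpu staging allocator.

```rust
// Hypothetical model of the #1120 failure mode: a fixed pool of staging
// slots recycled on reuse. Without periodic flushes, early pending
// writes are overwritten before they ever "reach the GPU".
const POOL_SLOTS: usize = 64;

struct StagingPool {
    slots: Vec<Option<u32>>, // pending writes; None = free slot
    gpu: Vec<u32>,           // data that actually made it to the "GPU"
}

impl StagingPool {
    fn new() -> Self {
        Self { slots: vec![None; POOL_SLOTS], gpu: Vec::new() }
    }

    // Writing into an already-pending slot models a staging buffer being
    // recycled before its copy completed: the old value is silently lost.
    fn write(&mut self, i: usize, value: u32) {
        self.slots[i % POOL_SLOTS] = Some(value);
    }

    // Models queue.submit(): all pending staging copies reach the GPU.
    fn flush(&mut self) {
        for slot in &mut self.slots {
            if let Some(v) = slot.take() {
                self.gpu.push(v);
            }
        }
    }
}

fn main() {
    // No periodic flush: only the last pool's worth of writes survives.
    let mut a = StagingPool::new();
    for i in 0..200u32 {
        a.write(i as usize, i);
    }
    a.flush();

    // Flush every POOL_SLOTS writes, as the PR does: nothing is lost.
    let mut b = StagingPool::new();
    for i in 0..200u32 {
        b.write(i as usize, i);
        if (i as usize + 1) % POOL_SLOTS == 0 {
            b.flush();
        }
    }
    b.flush();

    println!("{} {}", a.gpu.len(), b.gpu.len()); // 64 200
}
```

A real regression test would replace the mock with `write_buffer` calls followed by a readback comparison, as the review suggests.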

Comment on lines +368 to +372
if let Err(e) = self.device.poll(wgpu::PollType::Wait {
    submission_index: Some(index),
    timeout: None,
}) {
    log::warn!("wgpu: write flush poll timed out ({e})");

Copilot AI Feb 28, 2026


device.poll(PollType::Wait { .. }) is called unconditionally in the bulk-write flush path. The rest of the crate explicitly avoids blocking polling on target_family = "wasm" (see compute/poll.rs and the wasm SubmissionLoad impl), so this will likely either error repeatedly or be ineffective on wasm builds. Consider gating the wait behind cfg(not(target_family = "wasm")) and using a non-blocking alternative (or just the queue.submit flush) on wasm, and avoid wording the log as a timeout if the error can be non-timeout (e.g. unsupported wait).

Suggested change

- if let Err(e) = self.device.poll(wgpu::PollType::Wait {
-     submission_index: Some(index),
-     timeout: None,
- }) {
-     log::warn!("wgpu: write flush poll timed out ({e})");
+ #[cfg(not(target_family = "wasm"))]
+ {
+     if let Err(e) = self.device.poll(wgpu::PollType::Wait {
+         submission_index: Some(index),
+         timeout: None,
+     }) {
+         log::warn!("wgpu: write flush poll failed ({e})");
+     }
+ }
Comment on lines +357 to +365
// Submit the current command buffer to flush all pending write_buffer work.
let encoder = std::mem::replace(
    &mut self.encoder,
    self.device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("CubeCL Write Flush Encoder"),
        }),
);
let index = self.queue.submit([encoder.finish()]);

Copilot AI Feb 28, 2026


The periodic write flush submits encoder.finish() but doesn’t update the normal submission bookkeeping (tasks_count, submission_load.regulate, mem_manage.memory_cleanup/release_uniforms). If this flush ever happens while there are encoded tasks in the current encoder, the stream’s internal counters can drift from what has actually been submitted, which can cause extra empty submits and mis-regulate GPU waiting. Consider factoring this into a shared “submit encoder” helper that also updates the same accounting as flush() (or ensure this path can only run when tasks_count == 0).

Suggested change

- // Submit the current command buffer to flush all pending write_buffer work.
- let encoder = std::mem::replace(
-     &mut self.encoder,
-     self.device
-         .create_command_encoder(&wgpu::CommandEncoderDescriptor {
-             label: Some("CubeCL Write Flush Encoder"),
-         }),
- );
- let index = self.queue.submit([encoder.finish()]);
+ // Submit a fresh, empty command buffer to flush all pending write_buffer work.
+ // wgpu flushes its internal staging-buffer copies on any queue.submit(),
+ // so we don't need to touch the main compute encoder here.
+ let write_flush_encoder = self
+     .device
+     .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+         label: Some("CubeCL Write Flush Encoder"),
+     });
+ let index = self.queue.submit([write_flush_encoder.finish()]);
holg and others added 2 commits March 2, 2026 19:38
- Use a separate empty encoder for the write flush instead of swapping
  the main compute encoder. This avoids desync between tasks_count /
  submission_load bookkeeping and the actual encoder state.
- Gate device.poll(Wait) behind #[cfg(not(target_family = "wasm"))]
  to match the existing SubmissionLoad pattern. On wasm, queue.submit()
  alone still flushes staging buffers.
@louisfd louisfd (Member) left a comment


lgtm


Development

Successfully merging this pull request may close these issues.

Bulk tensor creation corrupts early tensors on Metal/wgpu

3 participants