Merged
Changes from 1 commit
41 changes: 41 additions & 0 deletions crates/cubecl-wgpu/src/compute/stream.rs
@@ -35,6 +35,10 @@ pub struct WgpuStream {
encoder: wgpu::CommandEncoder,
poll: WgpuPoll,
submission_load: SubmissionLoad,
/// Number of consecutive `write_buffer` calls without a `queue.submit()`.
/// Used to prevent wgpu staging buffer pool exhaustion during bulk writes
/// (e.g. model loading with hundreds of tensors).
pending_write_count: usize,
}

impl WgpuStream {
@@ -82,6 +86,7 @@ impl WgpuStream {
tasks_max,
poll,
submission_load: SubmissionLoad::default(),
pending_write_count: 0,
}
}

@@ -333,6 +338,41 @@ impl WgpuStream {
.expect("Internal error: Failed to call `write_buffer_with`, this likely means no staging buffer could be allocated.");
buffer[0..data.len()].copy_from_slice(data);
}

self.pending_write_count += 1;

// Prevent wgpu staging buffer pool exhaustion during bulk writes (e.g. model
// loading with hundreds of tensors). queue.write_buffer() is async — wgpu
// copies data into an internal staging buffer, then transfers to GPU on the
// next queue.submit(). Without periodic submits, hundreds of writes accumulate
// and staging buffers get recycled before the GPU copy completes, silently
// corrupting early tensors.
// See: https://github.com/tracel-ai/cubecl/issues/1120
const MAX_PENDING_WRITES: usize = 64;

if self.pending_write_count >= MAX_PENDING_WRITES {
Comment on lines +400 to +409
Copilot AI Feb 28, 2026

This change introduces a new correctness/ordering workaround specific to bulk write_buffer usage (including a tuned threshold and an explicit submit+wait). There doesn’t appear to be a regression test covering the failure mode from #1120 (many consecutive tensor/buffer writes with validation via readback). Adding a targeted test (likely under the existing tests_msl / wgpu testgen setup) would help prevent future refactors from reintroducing silent corruption.
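A regression test along those lines could assert the flush cadence directly. The sketch below is a minimal, self-contained model in plain Rust with no wgpu dependency; `MockQueue` and its fields are hypothetical stand-ins, and only the `MAX_PENDING_WRITES` threshold logic is taken from the diff.

```rust
// Hypothetical stand-in for the wgpu queue: counts pending writes and
// how many times a submit flushed them. Not the real cubecl-wgpu types.
struct MockQueue {
    pending_writes: usize,
    submits: usize,
}

impl MockQueue {
    fn new() -> Self {
        Self { pending_writes: 0, submits: 0 }
    }

    // Mirrors the counter logic from the diff: every write bumps the
    // counter, and reaching MAX_PENDING_WRITES forces a flush.
    fn write_buffer(&mut self) {
        const MAX_PENDING_WRITES: usize = 64;
        self.pending_writes += 1;
        if self.pending_writes >= MAX_PENDING_WRITES {
            self.submits += 1;       // queue.submit(...) in the real code
            self.pending_writes = 0; // reset, as the real flush path does
        }
    }
}

fn main() {
    let mut q = MockQueue::new();
    // Simulate bulk-loading a model with 300 tensor writes.
    for _ in 0..300 {
        q.write_buffer();
    }
    // Flushes fire at writes 64, 128, 192, 256; 44 writes remain pending.
    assert_eq!(q.submits, 4);
    assert_eq!(q.pending_writes, 44);
    println!("submits={} pending={}", q.submits, q.pending_writes);
}
```

A real test in the wgpu testgen setup would replace the mock with actual buffer writes followed by a readback, asserting that every tensor's contents survive intact.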

// End any active compute pass before submitting.
self.compute_pass = None;

// Submit the current command buffer to flush all pending write_buffer work.
let encoder = std::mem::replace(
&mut self.encoder,
self.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("CubeCL Write Flush Encoder"),
}),
);
let index = self.queue.submit([encoder.finish()]);

Copilot AI Feb 28, 2026


The periodic write flush submits encoder.finish() but doesn’t update the normal submission bookkeeping (tasks_count, submission_load.regulate, mem_manage.memory_cleanup/release_uniforms). If this flush ever happens while there are encoded tasks in the current encoder, the stream’s internal counters can drift from what has actually been submitted, which can cause extra empty submits and mis-regulate GPU waiting. Consider factoring this into a shared “submit encoder” helper that also updates the same accounting as flush() (or ensure this path can only run when tasks_count == 0).
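One possible shape for such a shared submit helper, sketched in plain Rust with hypothetical field names modeled on the diff rather than the actual cubecl-wgpu code:

```rust
// Sketch of a single submit path that both flush() and the periodic
// write flush go through, so the bookkeeping cannot diverge.
// All names here are hypothetical stand-ins for the real stream state.
struct Stream {
    tasks_count: usize,
    pending_write_count: usize,
    submissions: usize, // stands in for queue.submit() side effects
}

impl Stream {
    // The one place that submits: resets every counter that tracks
    // work encoded into the current command buffer.
    fn submit_encoder(&mut self) {
        self.submissions += 1;        // queue.submit([encoder.finish()])
        self.tasks_count = 0;         // same accounting as flush()
        self.pending_write_count = 0;
        // The real code would also run submission_load.regulate(..)
        // and mem_manage.memory_cleanup()/release_uniforms() here.
    }

    fn flush(&mut self) {
        self.submit_encoder();
    }

    fn maybe_flush_writes(&mut self) {
        const MAX_PENDING_WRITES: usize = 64;
        if self.pending_write_count >= MAX_PENDING_WRITES {
            self.submit_encoder();
        }
    }
}

fn main() {
    let mut s = Stream { tasks_count: 3, pending_write_count: 64, submissions: 0 };
    s.maybe_flush_writes();
    assert_eq!((s.submissions, s.tasks_count, s.pending_write_count), (1, 0, 0));
    s.flush();
    assert_eq!(s.submissions, 2);
}
```

Routing both paths through one helper means a flush that happens to carry encoded tasks still updates `tasks_count` and the load regulation exactly as a normal submission would.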

Suggested change
-        // Submit the current command buffer to flush all pending write_buffer work.
-        let encoder = std::mem::replace(
-            &mut self.encoder,
-            self.device
-                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
-                    label: Some("CubeCL Write Flush Encoder"),
-                }),
-        );
-        let index = self.queue.submit([encoder.finish()]);
+        // Submit a fresh, empty command buffer to flush all pending write_buffer work.
+        // wgpu flushes its internal staging-buffer copies on any queue.submit(),
+        // so we don't need to touch the main compute encoder here.
+        let write_flush_encoder = self
+            .device
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                label: Some("CubeCL Write Flush Encoder"),
+            });
+        let index = self.queue.submit([write_flush_encoder.finish()]);


// Wait for the GPU to finish processing these writes before continuing.
if let Err(e) = self.device.poll(wgpu::PollType::Wait {
submission_index: Some(index),
timeout: None,
}) {
log::warn!("wgpu: write flush poll timed out ({e})");

Copilot AI Feb 28, 2026


device.poll(PollType::Wait { .. }) is called unconditionally in the bulk-write flush path. The rest of the crate explicitly avoids blocking polling on target_family = "wasm" (see compute/poll.rs and the wasm SubmissionLoad impl), so this will likely either error repeatedly or be ineffective on wasm builds. Consider gating the wait behind cfg(not(target_family = "wasm")) and using a non-blocking alternative (or just the queue.submit flush) on wasm, and avoid wording the log as a timeout if the error can be non-timeout (e.g. unsupported wait).

Suggested change
-        if let Err(e) = self.device.poll(wgpu::PollType::Wait {
-            submission_index: Some(index),
-            timeout: None,
-        }) {
-            log::warn!("wgpu: write flush poll timed out ({e})");
+        #[cfg(not(target_family = "wasm"))]
+        {
+            if let Err(e) = self.device.poll(wgpu::PollType::Wait {
+                submission_index: Some(index),
+                timeout: None,
+            }) {
+                log::warn!("wgpu: write flush poll failed ({e})");
+            }

}
self.pending_write_count = 0;
}
}

fn flush_if_needed(&mut self) {
@@ -374,6 +414,7 @@ impl WgpuStream {
self.mem_manage.release_uniforms();

self.tasks_count = 0;
self.pending_write_count = 0;
}

fn register_pipeline<'a>(