-
Notifications
You must be signed in to change notification settings - Fork 160
fix(wgpu): flush staging buffers periodically during bulk writes #1204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -35,6 +35,10 @@ pub struct WgpuStream { | |||||||||||||||||||||||||||||||||||||
| encoder: wgpu::CommandEncoder, | ||||||||||||||||||||||||||||||||||||||
| poll: WgpuPoll, | ||||||||||||||||||||||||||||||||||||||
| submission_load: SubmissionLoad, | ||||||||||||||||||||||||||||||||||||||
| /// Number of consecutive `write_buffer` calls without a `queue.submit()`. | ||||||||||||||||||||||||||||||||||||||
| /// Used to prevent wgpu staging buffer pool exhaustion during bulk writes | ||||||||||||||||||||||||||||||||||||||
| /// (e.g. model loading with hundreds of tensors). | ||||||||||||||||||||||||||||||||||||||
| pending_write_count: usize, | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| impl WgpuStream { | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -82,6 +86,7 @@ impl WgpuStream { | |||||||||||||||||||||||||||||||||||||
| tasks_max, | ||||||||||||||||||||||||||||||||||||||
| poll, | ||||||||||||||||||||||||||||||||||||||
| submission_load: SubmissionLoad::default(), | ||||||||||||||||||||||||||||||||||||||
| pending_write_count: 0, | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -333,6 +338,41 @@ impl WgpuStream { | |||||||||||||||||||||||||||||||||||||
| .expect("Internal error: Failed to call `write_buffer_with`, this likely means no staging buffer could be allocated."); | ||||||||||||||||||||||||||||||||||||||
| buffer[0..data.len()].copy_from_slice(data); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| self.pending_write_count += 1; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Prevent wgpu staging buffer pool exhaustion during bulk writes (e.g. model | ||||||||||||||||||||||||||||||||||||||
| // loading with hundreds of tensors). queue.write_buffer() is async — wgpu | ||||||||||||||||||||||||||||||||||||||
| // copies data into an internal staging buffer, then transfers to GPU on the | ||||||||||||||||||||||||||||||||||||||
| // next queue.submit(). Without periodic submits, hundreds of writes accumulate | ||||||||||||||||||||||||||||||||||||||
| // and staging buffers get recycled before the GPU copy completes, silently | ||||||||||||||||||||||||||||||||||||||
| // corrupting early tensors. | ||||||||||||||||||||||||||||||||||||||
| // See: https://github.com/tracel-ai/cubecl/issues/1120 | ||||||||||||||||||||||||||||||||||||||
| const MAX_PENDING_WRITES: usize = 64; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if self.pending_write_count >= MAX_PENDING_WRITES { | ||||||||||||||||||||||||||||||||||||||
| // End any active compute pass before submitting. | ||||||||||||||||||||||||||||||||||||||
| self.compute_pass = None; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Submit the current command buffer to flush all pending write_buffer work. | ||||||||||||||||||||||||||||||||||||||
| let encoder = std::mem::replace( | ||||||||||||||||||||||||||||||||||||||
| &mut self.encoder, | ||||||||||||||||||||||||||||||||||||||
| self.device | ||||||||||||||||||||||||||||||||||||||
| .create_command_encoder(&wgpu::CommandEncoderDescriptor { | ||||||||||||||||||||||||||||||||||||||
| label: Some("CubeCL Write Flush Encoder"), | ||||||||||||||||||||||||||||||||||||||
| }), | ||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||
| let index = self.queue.submit([encoder.finish()]); | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
| // Submit the current command buffer to flush all pending write_buffer work. | |
| let encoder = std::mem::replace( | |
| &mut self.encoder, | |
| self.device | |
| .create_command_encoder(&wgpu::CommandEncoderDescriptor { | |
| label: Some("CubeCL Write Flush Encoder"), | |
| }), | |
| ); | |
| let index = self.queue.submit([encoder.finish()]); | |
| // Submit a fresh, empty command buffer to flush all pending write_buffer work. | |
| // wgpu flushes its internal staging-buffer copies on any queue.submit(), | |
| // so we don't need to touch the main compute encoder here. | |
| let write_flush_encoder = self | |
| .device | |
| .create_command_encoder(&wgpu::CommandEncoderDescriptor { | |
| label: Some("CubeCL Write Flush Encoder"), | |
| }); | |
| let index = self.queue.submit([write_flush_encoder.finish()]); |
Outdated
Copilot
AI
Feb 28, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device.poll(PollType::Wait { .. }) is called unconditionally in the bulk-write flush path. The rest of the crate explicitly avoids blocking polling on target_family = "wasm" (see compute/poll.rs and the wasm SubmissionLoad impl), so this will likely either error repeatedly or be ineffective on wasm builds. Consider gating the wait behind cfg(not(target_family = "wasm")) and using a non-blocking alternative (or just the queue.submit flush) on wasm, and avoid wording the log as a timeout if the error can be non-timeout (e.g. unsupported wait).
| if let Err(e) = self.device.poll(wgpu::PollType::Wait { | |
| submission_index: Some(index), | |
| timeout: None, | |
| }) { | |
| log::warn!("wgpu: write flush poll timed out ({e})"); | |
| #[cfg(not(target_family = "wasm"))] | |
| { | |
| if let Err(e) = self.device.poll(wgpu::PollType::Wait { | |
| submission_index: Some(index), | |
| timeout: None, | |
| }) { | |
| log::warn!("wgpu: write flush poll failed ({e})"); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change introduces a new correctness/ordering workaround specific to bulk
`write_buffer` usage (including a tuned threshold and an explicit submit + wait). There doesn’t appear to be a regression test covering the failure mode from #1120 (many consecutive tensor/buffer writes with validation via readback). Adding a targeted test (likely under the existing `tests_msl` / wgpu testgen setup) would help prevent future refactors from reintroducing silent corruption.