Skip to content

Commit 7219877

Browse files
remove mid pipeline copy to CPU with indirect dispatch
1 parent 7a5e686 commit 7219877

3 files changed

Lines changed: 115 additions & 20 deletions

File tree

src/gpu.rs

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,48 @@ impl Gpu {
101101
bytemuck::cast_slice(&bytes).to_vec()
102102
}
103103

104-
pub fn dispatch(
104+
pub fn indirect_buffer(&self, label: &str) -> wgpu::Buffer {
105+
self.device.create_buffer(&wgpu::BufferDescriptor {
106+
label: Some(label),
107+
size: 12,
108+
usage: wgpu::BufferUsages::STORAGE
109+
| wgpu::BufferUsages::INDIRECT
110+
| wgpu::BufferUsages::COPY_SRC
111+
| wgpu::BufferUsages::COPY_DST,
112+
mapped_at_creation: false,
113+
})
114+
}
115+
116+
pub fn read_buffer_at<T: bytemuck::Pod>(&self, buffer: &wgpu::Buffer, byte_offset: u64) -> T {
117+
let size = std::mem::size_of::<T>() as u64;
118+
let staging = self.staging_buffer(size);
119+
120+
let mut encoder = self
121+
.device
122+
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
123+
124+
encoder.copy_buffer_to_buffer(buffer, byte_offset, &staging, 0, size);
125+
self.queue.submit(Some(encoder.finish()));
126+
127+
let slice = staging.slice(..);
128+
let (tx, rx) = std::sync::mpsc::channel();
129+
slice.map_async(wgpu::MapMode::Read, move |result| {
130+
tx.send(result).unwrap();
131+
});
132+
133+
self.device.poll(wgpu::Maintain::Wait);
134+
rx.recv().unwrap().unwrap();
135+
136+
let data = slice.get_mapped_range();
137+
*bytemuck::from_bytes(&data)
138+
}
139+
140+
fn build_pipeline(
105141
&self,
106142
shader_src: &str,
107143
entry_point: &str,
108144
buffers: &[(&wgpu::Buffer, bool)],
109-
workgroups: u32,
110-
) {
145+
) -> (wgpu::ComputePipeline, wgpu::BindGroup) {
111146
let module = self
112147
.device
113148
.create_shader_module(wgpu::ShaderModuleDescriptor {
@@ -171,6 +206,18 @@ impl Gpu {
171206
entries: &bind_group_entries,
172207
});
173208

209+
(pipeline, bind_group)
210+
}
211+
212+
pub fn dispatch(
213+
&self,
214+
shader_src: &str,
215+
entry_point: &str,
216+
buffers: &[(&wgpu::Buffer, bool)],
217+
workgroups: u32,
218+
) {
219+
let (pipeline, bind_group) = self.build_pipeline(shader_src, entry_point, buffers);
220+
174221
let mut encoder = self
175222
.device
176223
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
@@ -185,6 +232,29 @@ impl Gpu {
185232
self.queue.submit(Some(encoder.finish()));
186233
}
187234

235+
pub fn dispatch_indirect(
236+
&self,
237+
shader_src: &str,
238+
entry_point: &str,
239+
buffers: &[(&wgpu::Buffer, bool)],
240+
indirect: &wgpu::Buffer,
241+
) {
242+
let (pipeline, bind_group) = self.build_pipeline(shader_src, entry_point, buffers);
243+
244+
let mut encoder = self
245+
.device
246+
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
247+
248+
{
249+
let mut cpass = encoder.begin_compute_pass(&Default::default());
250+
cpass.set_pipeline(&pipeline);
251+
cpass.set_bind_group(0, &bind_group, &[]);
252+
cpass.dispatch_workgroups_indirect(indirect, 0);
253+
}
254+
255+
self.queue.submit(Some(encoder.finish()));
256+
}
257+
188258
pub fn multi_scan_u32(&self, buf: &wgpu::Buffer, n: usize) {
189259
self.multi_scan_generic(
190260
buf,

src/main.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,30 @@ fn parse(gpu: &Gpu, json: &str) -> Vec<TapeEntry> {
5555
n_wg.try_into().unwrap(),
5656
);
5757

58-
let scanned_mask = gpu.read_buffer_as::<u32>(&mask_buf);
59-
let struct_count = scanned_mask[padded_len - 1] as usize;
60-
let struct_wg = struct_count.div_ceil(256).max(1);
61-
let struct_padded = struct_wg * 256;
62-
63-
let depth_buf = gpu.storage_buffer_empty("depth", (struct_padded * 4) as u64);
58+
let indirect_buf = gpu.indirect_buffer("indirect");
6459
gpu.dispatch(
60+
include_str!("shaders/multi/prepare_indirect.wgsl"),
61+
"main",
62+
&[(&mask_buf, true), (&indirect_buf, false)],
63+
1,
64+
);
65+
66+
let depth_buf = gpu.storage_buffer_empty("depth", buf_size(4));
67+
gpu.dispatch_indirect(
6568
include_str!("shaders/multi/map_depth.wgsl"),
6669
"main",
6770
&[
6871
(&input_buf, true),
6972
(&compact_buf, true),
7073
(&depth_buf, false),
7174
],
72-
struct_wg.try_into().unwrap(),
75+
&indirect_buf,
7376
);
7477

75-
gpu.multi_scan_i32(&depth_buf, struct_padded);
78+
gpu.multi_scan_i32(&depth_buf, padded_len);
7679

77-
let parent_buf = gpu.storage_buffer_empty("parents", (struct_padded * 4) as u64);
78-
gpu.dispatch(
80+
let parent_buf = gpu.storage_buffer_empty("parents", buf_size(4));
81+
gpu.dispatch_indirect(
7982
include_str!("shaders/multi/parent_link.wgsl"),
8083
"main",
8184
&[
@@ -84,14 +87,12 @@ fn parse(gpu: &Gpu, json: &str) -> Vec<TapeEntry> {
8487
(&depth_buf, true),
8588
(&parent_buf, false),
8689
],
87-
struct_wg.try_into().unwrap(),
90+
&indirect_buf,
8891
);
8992

90-
let tape_buf = gpu.storage_buffer_empty(
91-
"tape",
92-
(struct_padded * std::mem::size_of::<TapeEntry>()) as u64,
93-
);
94-
gpu.dispatch(
93+
let tape_buf =
94+
gpu.storage_buffer_empty("tape", buf_size(std::mem::size_of::<TapeEntry>()));
95+
gpu.dispatch_indirect(
9596
include_str!("shaders/multi/assemble_tape.wgsl"),
9697
"main",
9798
&[
@@ -101,9 +102,12 @@ fn parse(gpu: &Gpu, json: &str) -> Vec<TapeEntry> {
101102
(&parent_buf, true),
102103
(&tape_buf, false),
103104
],
104-
struct_wg.try_into().unwrap(),
105+
&indirect_buf,
105106
);
106107

108+
let struct_count =
109+
gpu.read_buffer_at::<u32>(&mask_buf, ((padded_len - 1) * 4) as u64) as usize;
110+
107111
let tape = gpu.read_buffer_as::<TapeEntry>(&tape_buf);
108112
tape[..struct_count].to_vec()
109113
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
@group(0) @binding(0)
2+
var<storage, read> mask: array<u32>;
3+
4+
struct Indirect {
5+
x: u32,
6+
y: u32,
7+
z: u32,
8+
}
9+
10+
@group(0) @binding(1)
11+
var<storage, read_write> indirect: Indirect;
12+
13+
@compute
14+
@workgroup_size(1)
15+
fn main() {
16+
let n = arrayLength(&mask);
17+
let struct_count = mask[n - 1u];
18+
indirect.x = (struct_count + 255u) / 256u;
19+
indirect.y = 1u;
20+
indirect.z = 1u;
21+
}

0 commit comments

Comments
 (0)