diff --git a/examples/features/src/ray_aabb_compute/mod.rs b/examples/features/src/ray_aabb_compute/mod.rs index f45167beab8..691ae7ed0ac 100644 --- a/examples/features/src/ray_aabb_compute/mod.rs +++ b/examples/features/src/ray_aabb_compute/mod.rs @@ -288,6 +288,7 @@ impl crate::framework::Example for Example { )), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); let mut encoder = diff --git a/examples/features/src/ray_cube_compute/mod.rs b/examples/features/src/ray_cube_compute/mod.rs index 76c0960ae8c..7371e692317 100644 --- a/examples/features/src/ray_cube_compute/mod.rs +++ b/examples/features/src/ray_cube_compute/mod.rs @@ -341,6 +341,7 @@ impl crate::framework::Example for Example { )), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); } } diff --git a/examples/features/src/ray_cube_fragment/mod.rs b/examples/features/src/ray_cube_fragment/mod.rs index 1a5396609dc..43ab6801ff2 100644 --- a/examples/features/src/ray_cube_fragment/mod.rs +++ b/examples/features/src/ray_cube_fragment/mod.rs @@ -306,7 +306,13 @@ impl crate::framework::Example for Example { .try_into() .unwrap(); - *instance = Some(wgpu::TlasInstance::new(&self.blas, transform, 0, 0xff)); + *instance = Some(wgpu::TlasInstance::new( + &self.blas, + transform, + 0, + 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), + )); } } } diff --git a/examples/features/src/ray_cube_normals/mod.rs b/examples/features/src/ray_cube_normals/mod.rs index 05793d604b2..cdc138407de 100644 --- a/examples/features/src/ray_cube_normals/mod.rs +++ b/examples/features/src/ray_cube_normals/mod.rs @@ -332,6 +332,7 @@ impl crate::framework::Example for Example { )), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); } } diff --git a/examples/features/src/ray_scene/mod.rs b/examples/features/src/ray_scene/mod.rs index b4a1db8d934..99403ad4ee3 100644 --- a/examples/features/src/ray_scene/mod.rs +++ b/examples/features/src/ray_scene/mod.rs @@ -486,6 +486,7 @@ impl crate::framework::Example for 
Example { transform, blas_index as u32, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); } } diff --git a/examples/features/src/ray_shadows/mod.rs b/examples/features/src/ray_shadows/mod.rs index 6cd76c59a31..5868eb87b3d 100644 --- a/examples/features/src/ray_shadows/mod.rs +++ b/examples/features/src/ray_shadows/mod.rs @@ -223,6 +223,7 @@ impl crate::framework::Example for Example { [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], 0, 0xFF, + wgpu::IntersectionShaderIndex::QueryData(0), )); let mut encoder = diff --git a/examples/features/src/ray_traced_triangle/mod.rs b/examples/features/src/ray_traced_triangle/mod.rs index b154ce9068c..6c4eb565614 100644 --- a/examples/features/src/ray_traced_triangle/mod.rs +++ b/examples/features/src/ray_traced_triangle/mod.rs @@ -165,6 +165,7 @@ impl crate::framework::Example for Example { .unwrap(), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); tlas[1] = Some(TlasInstance::new( @@ -180,6 +181,7 @@ impl crate::framework::Example for Example { .unwrap(), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); tlas[2] = Some(TlasInstance::new( @@ -195,6 +197,7 @@ impl crate::framework::Example for Example { .unwrap(), 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); let uniforms = { diff --git a/examples/standalone/custom_backend/src/custom.rs b/examples/standalone/custom_backend/src/custom.rs index a643ee8bd4b..6277c2f4454 100644 --- a/examples/standalone/custom_backend/src/custom.rs +++ b/examples/standalone/custom_backend/src/custom.rs @@ -191,6 +191,13 @@ impl DeviceInterface for CustomDevice { wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone())) } + fn create_ray_tracing_pipeline( + &self, + _desc: &wgpu::RayTracingPipelineDescriptor<'_>, + ) -> wgpu::custom::DispatchRayTracingPipeline { + unimplemented!() + } + unsafe fn create_pipeline_cache( &self, _desc: &wgpu::PipelineCacheDescriptor<'_>, diff --git a/player/src/lib.rs 
b/player/src/lib.rs index e6d7139549a..23da9509571 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -46,6 +46,10 @@ pub struct Player { wgc::id::PointerId, Arc, >, + ray_tracing_pipelines: HashMap< + wgc::id::PointerId, + Arc, + >, pipeline_caches: HashMap< wgc::id::PointerId, Arc, @@ -75,6 +79,7 @@ impl Default for Player { render_bundles: HashMap::new(), render_pipelines: HashMap::new(), compute_pipelines: HashMap::new(), + ray_tracing_pipelines: HashMap::new(), pipeline_caches: HashMap::new(), query_sets: HashMap::new(), buffers: HashMap::new(), @@ -230,6 +235,17 @@ impl Player { .expect("invalid compute pipeline"); self.bind_group_layouts.insert(id, bgl); } + Action::GetRayTracingPipelineBindGroupLayout { + id, + pipeline, + index, + } => { + let pipeline = self.resolve_ray_tracing_pipeline_id(pipeline); + let bgl = pipeline + .get_bind_group_layout(index) + .expect("invalid ray tracing pipeline"); + self.bind_group_layouts.insert(id, bgl); + } Action::DestroyBindGroupLayout(id) => { self.bind_group_layouts .remove(&id) @@ -481,6 +497,18 @@ impl Player { Action::DestroyTlas(id) => { self.tlas_s.remove(&id).expect("invalid tlas"); } + Action::CreateRayTracingPipeline { id, desc } => { + let descriptor = self.resolve_ray_tracing_pipeline_descriptor(desc); + let rt_pipeline = device + .create_ray_tracing_pipeline(descriptor) + .expect("create_ray_tracing_pipeline error"); + self.ray_tracing_pipelines.insert(id, rt_pipeline); + } + Action::DestroyRayTracingPipeline(id) => { + self.ray_tracing_pipelines + .remove(&id) + .expect("invalid ray tracing pipeline"); + } } } @@ -600,6 +628,16 @@ impl Player { .clone() } + fn resolve_ray_tracing_pipeline_id( + &self, + id: wgc::id::PointerId, + ) -> Arc { + self.ray_tracing_pipelines + .get(&id) + .expect("invalid ray tracing pipeline") + .clone() + } + fn resolve_pipeline_cache_id( &self, id: wgc::id::PointerId, @@ -744,6 +782,61 @@ impl Player { } } + fn resolve_ray_tracing_pipeline_descriptor<'a>( + &self, + 
desc: wgc::device::trace::TraceRayTracingPipelineDescriptor<'a>, + ) -> wgc::pipeline::ResolvedRayTracingPipelineDescriptor<'a> { + let layout = desc.layout.map(|id| self.resolve_pipeline_layout_id(id)); + + wgc::pipeline::ResolvedRayTracingPipelineDescriptor { + label: desc.label, + layout, + cache: desc.cache.map(|id| self.resolve_pipeline_cache_id(id)), + ray_generation: wgc::pipeline::ResolvedProgrammableStageDescriptor { + module: self.resolve_shader_module_id(desc.ray_generation.module), + entry_point: desc.ray_generation.entry_point, + constants: desc.ray_generation.constants, + zero_initialize_workgroup_memory: desc + .ray_generation + .zero_initialize_workgroup_memory, + }, + miss: wgc::pipeline::ResolvedProgrammableStageDescriptor { + module: self.resolve_shader_module_id(desc.miss.module), + entry_point: desc.miss.entry_point, + constants: desc.miss.constants, + zero_initialize_workgroup_memory: desc.miss.zero_initialize_workgroup_memory, + }, + intersections: desc + .intersections + .into_iter() + .map(|intersections| match intersections { + wgc::pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => wgc::pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit: wgc::pipeline::ResolvedProgrammableStageDescriptor { + module: self.resolve_shader_module_id(closest_hit.module), + entry_point: closest_hit.entry_point, + constants: closest_hit.constants, + zero_initialize_workgroup_memory: closest_hit + .zero_initialize_workgroup_memory, + }, + any_hit: any_hit.map(|any_hit| { + wgc::pipeline::ResolvedProgrammableStageDescriptor { + module: self.resolve_shader_module_id(any_hit.module), + entry_point: any_hit.entry_point, + constants: any_hit.constants, + zero_initialize_workgroup_memory: any_hit + .zero_initialize_workgroup_memory, + } + }), + }, + }) + .collect(), + max_recursion_depth: desc.max_recursion_depth, + } + } + fn resolve_bind_group_descriptor<'a>( &self, desc: 
wgc::device::trace::TraceBindGroupDescriptor<'a>, @@ -936,6 +1029,9 @@ impl Player { occlusion_query_set: occlusion_query_set.map(|qs| self.resolve_query_set_id(qs)), multiview_mask, }, + Command::RunRayTracingPass { pass } => Command::RunRayTracingPass { + pass: self.resolve_ray_tracing_pass(pass), + }, Command::BuildAccelerationStructures { blas, tlas } => { Command::BuildAccelerationStructures { blas: blas @@ -1031,6 +1127,32 @@ impl Player { } } + fn resolve_ray_tracing_pass( + &self, + pass: BasePass, Infallible>, + ) -> BasePass, Infallible> { + let BasePass { + label, + error, + commands, + dynamic_offsets, + immediates_data, + string_data, + } = pass; + + BasePass { + label, + error, + commands: commands + .into_iter() + .map(|cmd| self.resolve_ray_tracing_command(cmd)) + .collect(), + dynamic_offsets, + immediates_data, + string_data, + } + } + fn resolve_compute_command( &self, command: wgc::command::ComputeCommand, @@ -1248,6 +1370,38 @@ impl Player { } } + fn resolve_ray_tracing_command( + &self, + command: wgc::command::RayTracingCommand, + ) -> wgc::command::RayTracingCommand { + use wgc::command::RayTracingCommand as C; + match command { + C::SetBindGroup { + index, + num_dynamic_offsets, + bind_group, + } => C::SetBindGroup { + index, + num_dynamic_offsets, + bind_group: bind_group.map(|bg| self.resolve_bind_group_id(bg)), + }, + C::SetPipeline(id) => C::SetPipeline(self.resolve_ray_tracing_pipeline_id(id)), + C::SetImmediate { + offset, + size_bytes, + values_offset, + } => C::SetImmediate { + offset, + size_bytes, + values_offset, + }, + C::TraceRays(groups) => C::TraceRays(groups), + C::PushDebugGroup { color, len } => C::PushDebugGroup { color, len }, + C::PopDebugGroup => C::PopDebugGroup, + C::InsertDebugMarker { color, len } => C::InsertDebugMarker { color, len }, + } + } + fn resolve_pass_timestamp_writes( &self, writes: wgc::command::PassTimestampWrites>, @@ -1380,6 +1534,7 @@ impl Player { transform: instance.transform, custom_data: 
instance.custom_data, mask: instance.mask, + intersection_index: instance.intersection_index, } } diff --git a/tests/tests/wgpu-gpu/binding_array/tlas.rs b/tests/tests/wgpu-gpu/binding_array/tlas.rs index f44f6dddaeb..d87303693fa 100644 --- a/tests/tests/wgpu-gpu/binding_array/tlas.rs +++ b/tests/tests/wgpu-gpu/binding_array/tlas.rs @@ -131,12 +131,14 @@ async fn binding_array_tlas(ctx: TestingContext) { [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); tlas_b[0] = Some(wgpu::TlasInstance::new( &blas, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xff, + wgpu::IntersectionShaderIndex::QueryData(0), )); // Build BLAS and TLASes. diff --git a/tests/tests/wgpu-gpu/ray_tracing/as_aabb.rs b/tests/tests/wgpu-gpu/ray_tracing/as_aabb.rs index d9e173e6671..404325754b8 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/as_aabb.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/as_aabb.rs @@ -86,6 +86,7 @@ fn aabb_blas_build_and_trace(ctx: TestingContext) { [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xFF, + wgpu::IntersectionShaderIndex::QueryData(0), )); let mut encoder = ctx diff --git a/tests/tests/wgpu-gpu/ray_tracing/as_build.rs b/tests/tests/wgpu-gpu/ray_tracing/as_build.rs index 3d5228c02fb..43cc4175d7d 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/as_build.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/as_build.rs @@ -43,6 +43,7 @@ fn unbuilt_blas(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); // Build the TLAS package with an unbuilt BLAS. 
@@ -77,6 +78,7 @@ fn unbuilt_blas_compaction(ctx: TestingContext) { &ctx, AccelerationStructureFlags::ALLOW_COMPACTION, AccelerationStructureFlags::empty(), + false, ); fail( @@ -104,6 +106,7 @@ fn blas_compaction_without_flags(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); let mut encoder = ctx @@ -140,6 +143,7 @@ fn unprepared_blas_compaction(ctx: TestingContext) { &ctx, AccelerationStructureFlags::ALLOW_COMPACTION, AccelerationStructureFlags::empty(), + false, ); let mut encoder = ctx @@ -168,6 +172,7 @@ fn blas_compaction(ctx: TestingContext) { &ctx, AccelerationStructureFlags::ALLOW_COMPACTION, AccelerationStructureFlags::empty(), + false, ); let mut encoder = ctx @@ -225,6 +230,7 @@ fn out_of_order_as_build(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); // @@ -260,6 +266,7 @@ fn out_of_order_as_build(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); // @@ -312,6 +319,7 @@ fn out_of_order_as_build_use(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); // @@ -404,6 +412,7 @@ fn out_of_order_as_build_use(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); // @@ -556,6 +565,7 @@ fn build_with_transform(ctx: TestingContext) { [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xFF, + wgpu_types::IntersectionShaderIndex::QueryData(0), )); let mut encoder_build = ctx @@ -602,6 +612,7 @@ fn only_blas_vertex_return(ctx: TestingContext) { &ctx, AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN, AccelerationStructureFlags::empty(), + false, ); let mut encoder_blas = ctx @@ -727,6 +738,7 @@ fn only_tlas_vertex_return(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), 
AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN, + false, ); let mut encoder_blas = ctx diff --git a/tests/tests/wgpu-gpu/ray_tracing/as_use_after_free.rs b/tests/tests/wgpu-gpu/ray_tracing/as_use_after_free.rs index e5213b3e36e..cc816f26bf3 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/as_use_after_free.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/as_use_after_free.rs @@ -65,6 +65,7 @@ fn acceleration_structure_use_after_free(ctx: TestingContext) { [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xFF, + wgpu_types::IntersectionShaderIndex::QueryData(0), )); // Actually build the BLAS. diff --git a/tests/tests/wgpu-gpu/ray_tracing/mod.rs b/tests/tests/wgpu-gpu/ray_tracing/mod.rs index 8219569b26e..d88ce3f4538 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/mod.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/mod.rs @@ -16,6 +16,7 @@ mod as_build; mod as_create; mod as_use_after_free; mod limits; +mod pipelines; mod scene; mod shader; @@ -27,6 +28,7 @@ pub fn all_tests(tests: &mut Vec) { limits::all_tests(tests); scene::all_tests(tests); shader::all_tests(tests); + pipelines::all_tests(tests); } fn acceleration_structure_limits() -> wgpu::Limits { @@ -46,6 +48,7 @@ impl AsBuildContext { ctx: &TestingContext, additional_blas_flags: AccelerationStructureFlags, additional_tlas_flags: AccelerationStructureFlags, + for_ray_tracing_pipeline: bool, ) -> Self { let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor { label: None, @@ -84,6 +87,11 @@ impl AsBuildContext { [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], 0, 0xFF, + if for_ray_tracing_pipeline { + wgpu_types::IntersectionShaderIndex::IntersectionIndex(0) + } else { + wgpu_types::IntersectionShaderIndex::QueryData(0) + }, )); Self { diff --git a/tests/tests/wgpu-gpu/ray_tracing/pipelines.rs b/tests/tests/wgpu-gpu/ray_tracing/pipelines.rs new file mode 100644 index 00000000000..13869c309f7 --- /dev/null +++ b/tests/tests/wgpu-gpu/ray_tracing/pipelines.rs @@ -0,0 +1,204 @@ 
+use wgpu::{ + BindGroupDescriptor, BindGroupEntry, CommandEncoderDescriptor, Features, Limits, + RayTracingIntersectionDescriptor, RayTracingPassDescriptor, RayTracingPipelineDescriptor, + RayTracingStage, ShaderModuleDescriptor, +}; +use wgpu_test::{ + fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext, +}; +use wgpu_types::AccelerationStructureFlags; + +pub fn all_tests(tests: &mut Vec) { + tests.push(PIPELINE_CREATE_USE); +} + +#[gpu_test] +static PIPELINE_CREATE_USE: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + .limits( + Limits::defaults() + .using_minimum_supported_acceleration_structure_values() + .using_minimum_supported_ray_tracing_pipeline_values(), + ), + ) + .run_sync(pipeline_create_use); + +fn pipeline_create_use(ctx: TestingContext) { + let mut as_ctx = super::AsBuildContext::new( + &ctx, + AccelerationStructureFlags::empty(), + AccelerationStructureFlags::empty(), + true, + ); + + let mut encoder = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor::default()); + + // Build the BLAS and the TLAS. 
+ encoder.build_acceleration_structures([&as_ctx.blas_build_entry()], [&as_ctx.tlas]); + + ctx.queue.submit([encoder.finish()]); + + let ray_gen_source = " + enable wgpu_ray_tracing_pipeline; + + @group(0) @binding(0) var acc_struct: acceleration_structure; + + var payload: u32; + + @ray_generation + fn gen() { + traceRay(acc_struct, RayDesc(0u, 0xFFu, 0.001, 100.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0)), &payload); + } + "; + + let ray_closest_source = " + enable wgpu_ray_tracing_pipeline; + + var payload: u32; + + @closest_hit + @incoming_payload(payload) + fn closest() { + + } + "; + + let ray_any_source = " + enable wgpu_ray_tracing_pipeline; + + var payload: u32; + + @any_hit + @incoming_payload(payload) + fn any() { + + } + "; + + let ray_miss_source = " + enable wgpu_ray_tracing_pipeline; + + var payload: u32; + + @miss + @incoming_payload(payload) + fn miss() { + + } + "; + + let ray_gen = ctx.device.create_shader_module(ShaderModuleDescriptor { + label: Some("ray generation shader"), + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(ray_gen_source)), + }); + + let ray_closest = ctx.device.create_shader_module(ShaderModuleDescriptor { + label: Some("ray closest hit shader"), + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(ray_closest_source)), + }); + + let ray_any = ctx.device.create_shader_module(ShaderModuleDescriptor { + label: Some("ray any hit shader"), + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(ray_any_source)), + }); + + let ray_miss = ctx.device.create_shader_module(ShaderModuleDescriptor { + label: Some("ray miss shader"), + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(ray_miss_source)), + }); + + let pipeline = ctx + .device + .create_ray_tracing_pipeline(&RayTracingPipelineDescriptor { + label: None, + layout: None, + ray_generation: RayTracingStage { + module: &ray_gen, + entry_point: None, + compilation_options: Default::default(), + }, + miss: RayTracingStage { + module: 
&ray_miss, + entry_point: None, + compilation_options: Default::default(), + }, + intersection_descs: &[RayTracingIntersectionDescriptor::Triangle { + closest_hit: RayTracingStage { + module: &ray_closest, + entry_point: None, + compilation_options: Default::default(), + }, + any_hit: Some(RayTracingStage { + module: &ray_any, + entry_point: None, + compilation_options: Default::default(), + }), + }], + max_recersion_depth: 1, + cache: None, + }); + + let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor { + label: Some("ray tracing pipeline bind group"), + layout: &pipeline.get_bind_group_layout(0), + entries: &[BindGroupEntry { + binding: 0, + resource: as_ctx.tlas.as_binding(), + }], + }); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + { + let mut pass = encoder.begin_ray_tracing_pass(&RayTracingPassDescriptor::default()); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.trace_rays(1, 1, 2); + } + + // Change the intersection index to the other valid one + as_ctx.tlas[0].as_mut().unwrap().intersection_index = + wgpu::IntersectionShaderIndex::IntersectionIndex(1); + + // Build the TLAS with the other index. + encoder.build_acceleration_structures([], [&as_ctx.tlas]); + + // Rerun with the new intersection index. + { + let mut pass = encoder.begin_ray_tracing_pass(&RayTracingPassDescriptor::default()); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.trace_rays(1, 1, 2); + } + + ctx.queue.submit([encoder.finish()]); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + // Change the intersection index to be invalid. + as_ctx.tlas[0].as_mut().unwrap().intersection_index = + wgpu::IntersectionShaderIndex::IntersectionIndex(2); + + // Build the TLAS with the invalid index. + encoder.build_acceleration_structures([], [&as_ctx.tlas]); + + // Rerun with the new intersection index (should fail). 
+ { + let mut pass = encoder.begin_ray_tracing_pass(&RayTracingPassDescriptor::default()); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.trace_rays(1, 1, 2); + } + + fail( + &ctx.device, + || { + ctx.queue.submit([encoder.finish()]); + }, + None, + ); +} diff --git a/tests/tests/wgpu-gpu/ray_tracing/scene/mod.rs b/tests/tests/wgpu-gpu/ray_tracing/scene/mod.rs index 9b1760a0c72..1b087caa591 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/scene/mod.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/scene/mod.rs @@ -76,6 +76,7 @@ fn acceleration_structure_build(ctx: &TestingContext, use_index_buffer: bool) { )), 0, 0xff, + wgpu_types::IntersectionShaderIndex::QueryData(0), )); } diff --git a/tests/tests/wgpu-gpu/ray_tracing/shader.rs b/tests/tests/wgpu-gpu/ray_tracing/shader.rs index f61e9dd8545..3b0d3056a43 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/shader.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/shader.rs @@ -41,6 +41,7 @@ fn access_all_struct_members(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); let mut encoder_build = ctx @@ -135,6 +136,7 @@ fn prevent_invalid_ray_query_calls(ctx: TestingContext) { &ctx, AccelerationStructureFlags::empty(), AccelerationStructureFlags::empty(), + false, ); let mut encoder_build = ctx diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index 7ad13e25e25..eeb1384ef02 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -21,7 +21,7 @@ use crate::{ device::{bgl, Device, DeviceError, MissingDownlevelFlags, MissingFeatures}, id::{BindGroupLayoutId, BufferId, ExternalTextureId, SamplerId, TextureViewId, TlasId}, init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction}, - pipeline::{ComputePipeline, RenderPipeline}, + pipeline::{ComputePipeline, RayTracingPipeline, RenderPipeline}, resource::{ Buffer, DestroyedResourceError, ExternalTexture, InvalidResourceError, Labeled, 
MissingBufferUsageError, MissingTextureUsageError, RawResourceAccess, ResourceErrorIdent, @@ -722,6 +722,7 @@ pub(crate) enum ExclusivePipeline { None, Render(Weak), Compute(Weak), + RayTracing(Weak), } impl From<&Arc> for ExclusivePipeline { @@ -736,6 +737,12 @@ impl From<&Arc> for ExclusivePipeline { } } +impl From<&Arc> for ExclusivePipeline { + fn from(pipeline: &Arc) -> Self { + Self::RayTracing(Arc::downgrade(pipeline)) + } +} + impl fmt::Display for ExclusivePipeline { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -754,6 +761,13 @@ impl fmt::Display for ExclusivePipeline { f.write_str("ComputePipeline") } } + ExclusivePipeline::RayTracing(p) => { + if let Some(p) = p.upgrade() { + p.error_ident().fmt(f) + } else { + f.write_str("RayTracingPipeline") + } + } } } } diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 761aaca5e5e..2adf9b57d66 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -378,7 +378,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> { .remove_usage(buffer, wgt::BufferUses::INDIRECT); } - flush_bindings_helper(&mut self.pass)?; + flush_bindings_helper(&mut self.pass, resource::MaxIntersectionIndex::Query)?; CommandEncoder::drain_barriers( self.pass.base.raw_encoder, diff --git a/wgpu-core/src/command/encoder_command.rs b/wgpu-core/src/command/encoder_command.rs index 2dfe24590c8..52df9440244 100644 --- a/wgpu-core/src/command/encoder_command.rs +++ b/wgpu-core/src/command/encoder_command.rs @@ -22,6 +22,7 @@ pub trait ReferenceType { type RenderPipeline: Clone + core::fmt::Debug; type RenderBundle: Clone + core::fmt::Debug; type ComputePipeline: Clone + core::fmt::Debug; + type RayTracingPipeline: Clone + core::fmt::Debug; type Blas: Clone + core::fmt::Debug; type Tlas: Clone + core::fmt::Debug; } @@ -55,6 +56,7 @@ impl ReferenceType for IdReferences { type RenderPipeline = id::RenderPipelineId; type 
RenderBundle = id::RenderBundleId; type ComputePipeline = id::ComputePipelineId; + type RayTracingPipeline = id::RayTracingPipelineId; type Blas = id::BlasId; type Tlas = id::TlasId; } @@ -71,6 +73,7 @@ impl ReferenceType for PointerReferences { type RenderPipeline = id::PointerId; type RenderBundle = id::PointerId; type ComputePipeline = id::PointerId; + type RayTracingPipeline = id::PointerId; type Blas = id::PointerId; type Tlas = id::PointerId; } @@ -86,6 +89,7 @@ impl ReferenceType for ArcReferences { type RenderPipeline = Arc; type RenderBundle = Arc; type ComputePipeline = Arc; + type RayTracingPipeline = Arc; type Blas = Arc; type Tlas = Arc; } @@ -105,6 +109,7 @@ attribute_alias! { R::RenderPipeline: serde::Serialize + for<'d> serde::Deserialize<'d>,\ R::RenderBundle: serde::Serialize + for<'d> serde::Deserialize<'d>,\ R::ComputePipeline: serde::Serialize + for<'d> serde::Deserialize<'d>,\ + R::RayTracingPipeline: serde::Serialize + for<'d> serde::Deserialize<'d>,\ R::Blas: serde::Serialize + for<'d> serde::Deserialize<'d>,\ R::Tlas: serde::Serialize + for<'d> serde::Deserialize<'d>,\ wgt::BufferTransition: serde::Serialize + for<'d> serde::Deserialize<'d>,\ @@ -164,6 +169,9 @@ pub enum Command { pass: crate::command::BasePass, Infallible>, timestamp_writes: Option>, }, + RunRayTracingPass { + pass: crate::command::BasePass, Infallible>, + }, RunRenderPass { pass: crate::command::BasePass, Infallible>, color_attachments: ColorAttachments, diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index fe5e0dd2c65..21d45816fe4 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -22,6 +22,8 @@ mod memory_init; mod pass; mod query; mod ray_tracing; +mod ray_tracing_pass; +mod ray_tracing_pass_commands; mod render; mod render_command; mod timestamp_writes; @@ -56,6 +58,11 @@ pub use self::{ draw::{DrawError, Rect, RenderCommandError}, encoder_command::{ArcCommand, ArcReferences, Command, IdReferences, ReferenceType}, 
query::{QueryError, QueryUseError, ResolveError, SimplifiedQueryType}, + ray_tracing_pass::{ + RayTracingBasePass, RayTracingPass, RayTracingPassDescriptor, RayTracingPassError, + RayTracingPassErrorInner, TraceRayError, + }, + ray_tracing_pass_commands::ArcRayTracingCommand, render::{ ArcRenderPassColorAttachment, AttachmentError, AttachmentErrorLocation, ColorAttachmentError, ColorAttachments, LoadOp, PassChannel, RenderBasePass, RenderPass, @@ -81,7 +88,10 @@ pub(crate) use self::{ pub(crate) use allocator::CommandAllocator; /// cbindgen:ignore -pub use self::{compute_command::ComputeCommand, render_command::RenderCommand}; +pub use self::{ + compute_command::ComputeCommand, ray_tracing_pass_commands::RayTracingCommand, + render_command::RenderCommand, +}; pub(crate) use timestamp_writes::ArcPassTimestampWrites; pub use timestamp_writes::PassTimestampWrites; @@ -1033,7 +1043,9 @@ impl CommandEncoder { for command in commands { if matches!( command, - ArcCommand::RunRenderPass { .. } | ArcCommand::RunComputePass { .. } + ArcCommand::RunRenderPass { .. } + | ArcCommand::RunComputePass { .. } + | ArcCommand::RunRayTracingPass { .. } ) { // Compute passes and render passes can accept either an // open or closed encoder. This state object holds an @@ -1104,6 +1116,22 @@ impl CommandEncoder { } res?; } + ArcCommand::RunRayTracingPass { pass } => { + api_log!( + "Begin encoding ray tracing pass with '{}' label", + pass.label.as_deref().unwrap_or("") + ); + let res = ray_tracing_pass::encode_ray_tracing_pass(&mut state, pass); + match res.as_ref() { + Err(err) => { + api_log!("Finished encoding ray tracing pass ({err:?})") + } + Ok(_) => { + api_log!("Finished encoding ray tracing pass (success)") + } + } + res?; + } _ => unreachable!(), } } else { @@ -1200,7 +1228,9 @@ impl CommandEncoder { texture_transitions, )?; } - ArcCommand::RunComputePass { .. } | ArcCommand::RunRenderPass { .. } => { + ArcCommand::RunComputePass { .. } + | ArcCommand::RunRenderPass { .. 
} + | ArcCommand::RunRayTracingPass { .. } => { unreachable!() } } @@ -1570,6 +1600,8 @@ pub enum CommandEncoderError { ComputePass(#[from] ComputePassError), #[error(transparent)] RenderPass(#[from] RenderPassError), + #[error(transparent)] + RayTracingPass(#[from] RayTracingPassError), } impl CommandEncoderError { @@ -1620,6 +1652,7 @@ impl WebGpuError for CommandEncoderError { Self::ResourceUsage(e) => e.webgpu_error_type(), Self::ComputePass(e) => e.webgpu_error_type(), Self::RenderPass(e) => e.webgpu_error_type(), + Self::RayTracingPass(e) => e.webgpu_error_type(), } } } @@ -2059,6 +2092,8 @@ pub enum PassErrorScope { PopDebugGroup, #[error("In a insert_debug_marker command")] InsertDebugMarker, + #[error("In a trace rays command")] + TraceRays, } /// Variant of `EncoderStateError` that includes the pass scope. diff --git a/wgpu-core/src/command/pass.rs b/wgpu-core/src/command/pass.rs index 67ba4ba6db0..b792127256a 100644 --- a/wgpu-core/src/command/pass.rs +++ b/wgpu-core/src/command/pass.rs @@ -8,7 +8,9 @@ use crate::command::{ }; use crate::device::{Device, DeviceError, MissingFeatures}; use crate::pipeline::LateSizedBufferGroup; -use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet}; +use crate::resource::{ + DestroyedResourceError, Labeled, MaxIntersectionIndex, ParentDevice, QuerySet, +}; use crate::track::{ResourceUsageCompatibilityError, UsageScope}; use crate::{api_log, binding_model}; use alloc::sync::Arc; @@ -141,7 +143,12 @@ where /// /// See the compute pass version of `State::flush_bindings` for an explanation /// of some differences in handling the two types of passes. 
-pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), DestroyedResourceError> { +/// +/// `tlas_max_allowed` should be either [MaxIntersectionIndex::Query] or [MaxIntersectionIndex::Intersection] +pub(super) fn flush_bindings_helper( + state: &mut PassState, + tlas_max_allowed: MaxIntersectionIndex, +) -> Result<(), DestroyedResourceError> { let start = state.binder.take_rebind_start_index(); let entries = state.binder.list_valid_with_start(start); let pipeline_layout = state.binder.pipeline_layout.as_ref().unwrap(); @@ -169,7 +176,7 @@ pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), Destroy .used .acceleration_structures .into_iter() - .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone())); + .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone(), tlas_max_allowed)); state.base.as_actions.extend(used_resource); diff --git a/wgpu-core/src/command/ray_tracing.rs b/wgpu-core/src/command/ray_tracing.rs index f8bba308b98..6a0d0e9be75 100644 --- a/wgpu-core/src/command/ray_tracing.rs +++ b/wgpu-core/src/command/ray_tracing.rs @@ -10,7 +10,7 @@ use wgt::{math::align_to, BufferUsages, BufferUses, Features}; use crate::{ command::encoder::EncodingState, ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError}, - resource::InvalidResourceError, + resource::{InvalidResourceError, MaxIntersectionIndex}, }; use crate::{command::EncoderStateError, device::resource::CommandIndices}; use crate::{ @@ -76,7 +76,11 @@ impl Global { |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> { let device = &cmd_enc.device; device.check_is_valid()?; - device.require_features(Features::EXPERIMENTAL_RAY_QUERY)?; + device + .require_features(Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| { + device.require_features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + })?; let mut build_command = AsBuild::with_capacity(blas_ids.len(), tlas_ids.len()); @@ -90,6 +94,7 @@ impl Global { 
build_command.tlas_s_built.push(TlasBuild { tlas, dependencies: Vec::new(), + max_intersection_idx: MaxIntersectionIndex::Unused, }); } @@ -173,6 +178,7 @@ impl Global { transform: *instance.transform, custom_data: instance.custom_data, mask: instance.mask, + intersection_index: instance.intersection_index, }) }) .transpose() @@ -198,7 +204,12 @@ pub(crate) fn build_acceleration_structures( ) -> Result<(), BuildAccelerationStructureError> { state .device - .require_features(Features::EXPERIMENTAL_RAY_QUERY)?; + .require_features(Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| { + state + .device + .require_features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + })?; let mut build_command = AsBuild::with_capacity(blas.len(), tlas.len()); let mut input_barriers = Vec::>::new(); @@ -232,21 +243,60 @@ pub(crate) fn build_acceleration_structures( let mut dependencies = Vec::new(); let mut instance_count = 0; - for instance in package.instances.iter().flatten() { + + let mut max_intersection_idx = MaxIntersectionIndex::Unused; + for (instance_idx, instance) in package.instances.iter().flatten().enumerate() { + let Some(()) = max_intersection_idx.set_intersection_index(instance.intersection_index) + else { + return Err( + BuildAccelerationStructureError::TlasInstancesIntersectionIndicesDiffer( + tlas.error_ident(), + instance_idx, + ), + ); + }; + if instance.custom_data >= (1u32 << 24u32) { - return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex( + return Err(BuildAccelerationStructureError::TlasInvalidCustomData( tlas.error_ident(), + instance_idx, )); } let blas = &instance.blas; state.tracker.blas_s.insert_single(blas.clone()); + let intersection_data_offset = match instance.intersection_index { + wgt::IntersectionShaderIndex::IntersectionIndex(v) => { + if v >= state.device.limits.max_intersection_group_count { + return Err( + BuildAccelerationStructureError::TlasInvalidIntersectionIndex( + tlas.error_ident(), + instance_idx, + v, + 
state.device.limits.max_intersection_group_count, + ), + ); + } + v * state.device.alignments.ray_tracing_pipeline_group_data_size + } + wgt::IntersectionShaderIndex::QueryData(v) => { + if v >= (1u32 << 24u32) { + return Err(BuildAccelerationStructureError::TlasInvalidQueryData( + tlas.error_ident(), + instance_idx, + )); + } + v + } + }; + instance_buffer_staging_source.extend(state.device.raw().tlas_instance_to_bytes( hal::TlasInstance { transform: instance.transform, custom_data: instance.custom_data, mask: instance.mask, blas_address: blas.handle, + pipeline_intersection_data_offset: intersection_data_offset, }, )); @@ -273,6 +323,7 @@ pub(crate) fn build_acceleration_structures( build_command.tlas_s_built.push(TlasBuild { tlas: tlas.clone(), dependencies, + max_intersection_idx, }); if instance_count > tlas.max_instance_count { @@ -508,10 +559,12 @@ impl CommandBufferMutable { .tlas .dependencies .write() - .clone_from(&tlas_build.dependencies) + .clone_from(&tlas_build.dependencies); + *tlas_build.tlas.max_intersection_index.write() = + tlas_build.max_intersection_idx; } } - AsAction::UseTlas(tlas) => { + AsAction::UseTlas(tlas, max_intersection_idx) => { let tlas_build_index = tlas.built_index.read(); let dependencies = tlas.dependencies.read(); @@ -534,6 +587,16 @@ impl CommandBufferMutable { } blas.try_raw(snatch_guard)?; } + + let current_max = tlas.max_intersection_index.read(); + + if !current_max.at_most(*max_intersection_idx) { + return Err(ValidateAsActionsError::TlasIntersectionInvalid( + tlas.error_ident(), + *current_max, + *max_intersection_idx, + )); + } } } } @@ -546,7 +609,7 @@ impl CommandBufferMutable { .as_actions .iter() .filter_map(|action| { - if let AsAction::UseTlas(tlas) = action { + if let AsAction::UseTlas(tlas, _) = action { Some(tlas.dependencies.read()) } else { None @@ -566,7 +629,7 @@ impl CommandBufferMutable { } } } - AsAction::UseTlas(_tlas) => { + AsAction::UseTlas(_tlas, _) => { let tlas_dependencies = 
tlas_dependencies_lock_iter.next().unwrap(); // _tlas.dependencies.read(); for dependency in tlas_dependencies.iter() { if let Some(dependency) = dependency.raw(snatch_guard) { diff --git a/wgpu-core/src/command/ray_tracing_pass.rs b/wgpu-core/src/command/ray_tracing_pass.rs new file mode 100644 index 00000000000..2469d64ced7 --- /dev/null +++ b/wgpu-core/src/command/ray_tracing_pass.rs @@ -0,0 +1,859 @@ +use core::{convert::Infallible, fmt}; + +use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec}; + +use thiserror::Error; +use wgt::{ + error::{ErrorType, WebGpuError}, + BufferAddress, DynamicOffset, +}; + +use crate::{ + api_log, + binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch}, + command::{ + bind::{Binder, BinderError}, + memory_init::SurfacesInDiscardState, + pass::{self, flush_bindings_helper}, + pass_base, pass_try, + ray_tracing_pass_commands::ArcRayTracingCommand, + ArcCommand, BasePass, BindGroupStateChange, CommandEncoder, CommandEncoderError, + DebugGroupError, EncoderStateError, EncodingState, InnerCommandEncoder, MapPassErr, + PassErrorScope, PassStateError, StateChange, + }, + device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures}, + global::Global, + hal_label, id, + pipeline::RayTracingPipeline, + resource::{ + DestroyedResourceError, InvalidResourceError, Labeled, MissingBufferUsageError, + ParentDevice, + }, + track::{ResourceUsageCompatibilityError, Tracker}, + Label, +}; + +pub type RayTracingBasePass = BasePass; + +/// Very similar to [`super::compute::ComputePass`] +pub struct RayTracingPass { + /// All pass data & records is stored here. + base: RayTracingBasePass, + + /// Parent command encoder that this pass records commands into. 
+ /// + /// Implications are the same as [`super::compute::ComputePass::parent`] + parent: Option>, + + current_bind_groups: BindGroupStateChange, + current_pipeline: StateChange, +} + +impl fmt::Debug for RayTracingPass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.parent { + Some(ref cmd_enc) => { + write!(f, "RayTracingPass {{ parent: {} }}", cmd_enc.error_ident()) + } + None => write!(f, "RayTracingPass {{ parent: None }}"), + } + } +} + +impl RayTracingPass { + /// If the parent command encoder is invalid, the returned pass will be invalid. + fn new(parent: Arc, desc: RayTracingPassDescriptor) -> Self { + let RayTracingPassDescriptor { label } = desc; + + Self { + base: BasePass::new(&label), + parent: Some(parent), + + current_bind_groups: BindGroupStateChange::new(), + current_pipeline: StateChange::new(), + } + } + + fn new_invalid(parent: Arc, label: &Label, err: RayTracingPassError) -> Self { + Self { + base: BasePass::new_invalid(label, err), + parent: Some(parent), + current_bind_groups: BindGroupStateChange::new(), + current_pipeline: StateChange::new(), + } + } + + #[inline] + pub fn label(&self) -> Option<&str> { + self.base.label.as_deref() + } +} + +#[derive(Clone, Debug, Default)] +pub struct RayTracingPassDescriptor<'a> { + pub label: Label<'a>, +} + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum TraceRayError { + #[error("Ray tracing pipeline must be set")] + MissingPipeline(pass::MissingPipeline), + #[error(transparent)] + IncompatibleBindGroup(#[from] Box), + #[error("Trace ray size ({current:?}) must be less or equal to {limit:?}")] + InvalidDimensionSize { current: [u32; 3], limit: [u32; 3] }, + #[error("The total count of rays invocations ({current:?}) must be less or equal to {limit}")] + TooManyTotal { current: u32, limit: u32 }, + #[error(transparent)] + BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch), +} + +impl WebGpuError for TraceRayError { + fn webgpu_error_type(&self) -> 
ErrorType { + ErrorType::Validation + } +} + +/// Error encountered when performing a ray tracing pass. +#[derive(Clone, Debug, Error)] +pub enum RayTracingPassErrorInner { + #[error(transparent)] + Device(#[from] DeviceError), + #[error(transparent)] + EncoderState(#[from] EncoderStateError), + #[error("Parent encoder is invalid")] + InvalidParentEncoder, + #[error(transparent)] + DebugGroupError(#[from] DebugGroupError), + #[error(transparent)] + BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange), + #[error(transparent)] + DestroyedResource(#[from] DestroyedResourceError), + #[error("Indirect buffer offset {0:?} is not a multiple of 4")] + UnalignedIndirectBufferOffset(BufferAddress), + #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")] + IndirectBufferOverrun { + offset: u64, + end_offset: u64, + buffer_size: u64, + }, + #[error(transparent)] + ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError), + #[error(transparent)] + MissingBufferUsage(#[from] MissingBufferUsageError), + #[error(transparent)] + TraceRay(#[from] TraceRayError), + #[error(transparent)] + Bind(#[from] BindError), + #[error(transparent)] + ImmediateData(#[from] ImmediateUploadError), + #[error("Immediate data offset must be aligned to 4 bytes")] + ImmediateOffsetAlignment, + #[error("Immediate data size must be aligned to 4 bytes")] + ImmediateDataizeAlignment, + #[error("Ran out of immediate data space. 
Don't set 4gb of immediates per RayTracingPass.")] + ImmediateOutOfMemory, + #[error(transparent)] + MissingFeatures(#[from] MissingFeatures), + #[error(transparent)] + MissingDownlevelFlags(#[from] MissingDownlevelFlags), + #[error("The ray tracing pass has already been ended and no further commands can be recorded")] + PassEnded, + #[error(transparent)] + InvalidResource(#[from] InvalidResourceError), + // This one is unreachable, but required for generic pass support + #[error(transparent)] + InvalidValuesOffset(#[from] pass::InvalidValuesOffset), +} + +/// Error encountered when performing a ray tracing pass, stored for later reporting +/// when encoding ends. +#[derive(Clone, Debug, Error)] +#[error("{scope}")] +pub struct RayTracingPassError { + pub scope: PassErrorScope, + #[source] + pub(super) inner: RayTracingPassErrorInner, +} + +impl From for RayTracingPassErrorInner { + fn from(value: pass::MissingPipeline) -> Self { + Self::TraceRay(TraceRayError::MissingPipeline(value)) + } +} + +impl MapPassErr for E +where + E: Into, +{ + fn map_pass_err(self, scope: PassErrorScope) -> RayTracingPassError { + RayTracingPassError { + scope, + inner: self.into(), + } + } +} + +impl WebGpuError for RayTracingPassError { + fn webgpu_error_type(&self) -> ErrorType { + let Self { scope: _, inner } = self; + match inner { + RayTracingPassErrorInner::Device(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::EncoderState(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::DebugGroupError(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::DestroyedResource(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::ResourceUsageCompatibility(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::MissingBufferUsage(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::TraceRay(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::Bind(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::ImmediateData(e) => e.webgpu_error_type(), + 
RayTracingPassErrorInner::MissingFeatures(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::MissingDownlevelFlags(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::InvalidResource(e) => e.webgpu_error_type(), + RayTracingPassErrorInner::InvalidValuesOffset(e) => e.webgpu_error_type(), + + RayTracingPassErrorInner::InvalidParentEncoder + | RayTracingPassErrorInner::BindGroupIndexOutOfRange { .. } + | RayTracingPassErrorInner::UnalignedIndirectBufferOffset(_) + | RayTracingPassErrorInner::IndirectBufferOverrun { .. } + | RayTracingPassErrorInner::ImmediateOffsetAlignment + | RayTracingPassErrorInner::ImmediateDataizeAlignment + | RayTracingPassErrorInner::ImmediateOutOfMemory + | RayTracingPassErrorInner::PassEnded => ErrorType::Validation, + } + } +} + +struct State<'scope, 'snatch_guard, 'cmd_enc> { + pipeline: Option>, + + pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>, + + intermediate_trackers: Tracker, +} + +impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> { + fn is_ready(&self) -> Result<(), TraceRayError> { + if let Some(pipeline) = self.pipeline.as_ref() { + self.pass.binder.check_compatibility(pipeline.as_ref())?; + self.pass.binder.check_late_buffer_bindings()?; + Ok(()) + } else { + Err(TraceRayError::MissingPipeline(pass::MissingPipeline)) + } + } + + /// Flush binding state in preparation for a trace rays call. + /// + /// # Differences between render and compute (from which ray tracing passes inherit functionality) passes + /// + /// There are differences between the `flush_bindings` implementations for + /// render and compute passes, because render passes have a single usage + /// scope for the entire pass, and compute passes have a separate usage + /// scope for each dispatch. + /// + /// For compute passes, bind groups are merged into a fresh usage scope + /// here, not into the pass usage scope within calls to `set_bind_group`. 
As + /// specified by WebGPU, for compute passes, we merge only the bind groups + /// that are actually used by the pipeline, unlike render passes, which + /// merge every bind group that is ever set, even if it is not ultimately + /// used by the pipeline. + /// + /// For compute passes, we call `drain_barriers` here, because barriers may + /// be needed before each dispatch if a previous dispatch had a conflicting + /// usage. For render passes, barriers are emitted once at the start of the + /// render pass. + fn flush_bindings(&mut self) -> Result<(), RayTracingPassErrorInner> { + for bind_group in self.pass.binder.list_active() { + unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? }; + } + // For compute, usage scopes are associated with each dispatch and not + // with the pass as a whole. However, because the cost of creating and + // dropping `UsageScope`s is significant (even with the pool), we + // add and then remove usage from a single usage scope. + + for bind_group in self.pass.binder.list_active() { + self.intermediate_trackers + .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used); + } + + flush_bindings_helper( + &mut self.pass, + crate::resource::MaxIntersectionIndex::Intersection( + self.pipeline + .as_ref() + .unwrap() + .shader_binding_data + .num_intersection_groups as _, + ), + )?; + + CommandEncoder::drain_barriers( + self.pass.base.raw_encoder, + &mut self.intermediate_trackers, + self.pass.base.snatch_guard, + ); + Ok(()) + } +} + +// Ray tracing pass commands + +impl Global { + /// Creates a ray tracing pass. + /// + /// If creation fails, an invalid pass is returned. Attempting to record + /// commands into an invalid pass is permitted, but a validation error will + /// ultimately be generated when the parent encoder is finished, and it is + /// not possible to run any commands from the invalid pass. + /// + /// If successful, puts the encoder into the [`Locked`] state. 
+ /// + /// [`Locked`]: crate::command::CommandEncoderStatus::Locked + pub fn command_encoder_begin_ray_tracing_pass( + &self, + encoder_id: id::CommandEncoderId, + desc: &RayTracingPassDescriptor<'_>, + ) -> (RayTracingPass, Option) { + use EncoderStateError as SErr; + + let scope = PassErrorScope::Pass; + let hub = &self.hub; + + let label = desc.label.as_deref().map(Cow::Borrowed); + + let cmd_enc = hub.command_encoders.get(encoder_id); + let mut cmd_buf_data = cmd_enc.data.lock(); + + match cmd_buf_data.lock_encoder() { + Ok(()) => { + drop(cmd_buf_data); + if let Err(err) = cmd_enc.device.check_is_valid() { + return ( + RayTracingPass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)), + None, + ); + } + + ( + RayTracingPass::new(cmd_enc, RayTracingPassDescriptor { label }), + None, + ) + } + Err(err @ SErr::Locked) => { + // Attempting to open a new pass while the encoder is locked + // invalidates the encoder, but does not generate a validation + // error. + cmd_buf_data.invalidate(err.clone()); + drop(cmd_buf_data); + ( + RayTracingPass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)), + None, + ) + } + Err(err @ (SErr::Ended | SErr::Submitted)) => { + // Attempting to open a new pass after the encode has ended + // generates an immediate validation error. + drop(cmd_buf_data); + ( + RayTracingPass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)), + Some(err.into()), + ) + } + Err(err @ SErr::Invalid) => { + // Passes can be opened even on an invalid encoder. Such passes + // are even valid, but since there's no visible side-effect of + // the pass being valid and there's no point in storing recorded + // commands that will ultimately be discarded, we open an + // invalid pass to save that work. 
+ drop(cmd_buf_data); + ( + RayTracingPass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)), + None, + ) + } + Err(SErr::Unlocked) => { + unreachable!("lock_encoder cannot fail due to the encoder being unlocked") + } + } + } + + /// Records a `SetPipeline` command for `pipeline_id` into `pass`. + /// + /// Setting the pipeline that is already current on this pass is a no-op, + /// but the check that the pass is still open always runs first. + pub fn ray_tracing_pass_set_pipeline( + &self, + pass: &mut RayTracingPass, + pipeline_id: id::RayTracingPipelineId, + ) -> Result<(), PassStateError> { + // `encode_ray_tracing_pass` reports `SetPipelineCompute` for this + // command; use the same scope here rather than the render one so both + // halves of the command attribute errors identically. + let scope = PassErrorScope::SetPipelineCompute; + + let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id); + + // This statement will return an error if the pass is ended. + // It's important the error check comes before the early-out for `redundant`. + let base = pass_base!(pass, scope); + + if redundant { + return Ok(()); + } + + let hub = &self.hub; + let pipeline = pass_try!( + base, + scope, + hub.ray_tracing_pipelines.get(pipeline_id).get() + ); + + base.commands + .push(ArcRayTracingCommand::SetPipeline(pipeline)); + + Ok(()) + } + + pub fn ray_tracing_pass_set_bind_group( + &self, + pass: &mut RayTracingPass, + index: u32, + bind_group_id: Option, + offsets: &[DynamicOffset], + ) -> Result<(), PassStateError> { + let scope = PassErrorScope::SetBindGroup; + + // This statement will return an error if the pass is ended. It's + // important the error check comes before the early-out for + // `set_and_check_redundant`.
+ let base = pass_base!(pass, scope); + + if pass.current_bind_groups.set_and_check_redundant( + bind_group_id, + index, + &mut base.dynamic_offsets, + offsets, + ) { + return Ok(()); + } + + let mut bind_group = None; + if let Some(bind_group_id) = bind_group_id { + let hub = &self.hub; + bind_group = Some(pass_try!( + base, + scope, + hub.bind_groups.get(bind_group_id).get(), + )); + } + + base.commands.push(ArcRayTracingCommand::SetBindGroup { + index, + num_dynamic_offsets: offsets.len(), + bind_group, + }); + + Ok(()) + } + + pub fn ray_tracing_pass_set_immediates( + &self, + pass: &mut RayTracingPass, + offset: u32, + data: &[u8], + ) -> Result<(), PassStateError> { + let scope = PassErrorScope::SetImmediate; + let base = pass_base!(pass, scope); + + if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 { + pass_try!( + base, + scope, + Err(RayTracingPassErrorInner::ImmediateOffsetAlignment) + ); + } + if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 { + pass_try!( + base, + scope, + Err(RayTracingPassErrorInner::ImmediateDataizeAlignment) + ); + } + + let value_offset = pass_try!( + base, + scope, + base.immediates_data + .len() + .try_into() + .map_err(|_| RayTracingPassErrorInner::ImmediateOutOfMemory), + ); + + base.immediates_data.extend( + data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize) + .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])), + ); + + base.commands.push(ArcRayTracingCommand::SetImmediate { + offset, + size_bytes: data.len() as u32, + values_offset: value_offset, + }); + + Ok(()) + } + + pub fn ray_tracing_pass_push_debug_group( + &self, + pass: &mut RayTracingPass, + label: &str, + color: u32, + ) -> Result<(), PassStateError> { + let base = pass_base!(pass, PassErrorScope::PushDebugGroup); + + let bytes = label.as_bytes(); + base.string_data.extend_from_slice(bytes); + + base.commands.push(ArcRayTracingCommand::PushDebugGroup { + color, + len: bytes.len(), + }); + + Ok(()) + } + + pub fn 
ray_tracing_pass_pop_debug_group( + &self, + pass: &mut RayTracingPass, + ) -> Result<(), PassStateError> { + let base = pass_base!(pass, PassErrorScope::PopDebugGroup); + + base.commands.push(ArcRayTracingCommand::PopDebugGroup); + + Ok(()) + } + + pub fn ray_tracing_pass_insert_debug_marker( + &self, + pass: &mut RayTracingPass, + label: &str, + color: u32, + ) -> Result<(), PassStateError> { + let base = pass_base!(pass, PassErrorScope::InsertDebugMarker); + + let bytes = label.as_bytes(); + base.string_data.extend_from_slice(bytes); + + base.commands.push(ArcRayTracingCommand::InsertDebugMarker { + color, + len: bytes.len(), + }); + + Ok(()) + } + + pub fn ray_tracing_pass_trace_rays( + &self, + pass: &mut RayTracingPass, + count_x: u32, + count_y: u32, + count_z: u32, + ) -> Result<(), PassStateError> { + let scope = PassErrorScope::TraceRays; + + pass_base!(pass, scope) + .commands + .push(ArcRayTracingCommand::TraceRays([count_x, count_y, count_z])); + + Ok(()) + } + + pub fn ray_tracing_pass_end(&self, pass: &mut RayTracingPass) -> Result<(), EncoderStateError> { + profiling::scope!( + "CommandEncoder::encode_ray_tracing_pass {}", + pass.base.label.as_deref().unwrap_or("") + ); + + let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?; + let mut cmd_buf_data = cmd_enc.data.lock(); + + cmd_buf_data.unlock_encoder()?; + + let base = pass.base.take(); + + if let Err(RayTracingPassError { + inner: + RayTracingPassErrorInner::EncoderState( + err @ (EncoderStateError::Locked | EncoderStateError::Ended), + ), + scope: _, + }) = base + { + // Most encoding errors are detected and raised within `finish()`. + // + // However, we raise a validation error here if the pass was opened + // within another pass, or on a finished encoder. The latter is + // particularly important, because in that case reporting errors via + // `CommandEncoder::finish` is not possible. 
+ return Err(err.clone()); + } + + cmd_buf_data.push_with(|| -> Result<_, RayTracingPassError> { + Ok(ArcCommand::RunRayTracingPass { pass: base? }) + }) + } +} + +pub(super) fn encode_ray_tracing_pass( + parent_state: &mut EncodingState, + mut base: BasePass, +) -> Result<(), RayTracingPassError> { + let pass_scope = PassErrorScope::Pass; + + let device = parent_state.device; + + // We automatically keep extending command buffers over time, and because + // we want to insert a command buffer _before_ what we're about to record, + // we need to make sure to close the previous one. + parent_state + .raw_encoder + .close_if_open() + .map_pass_err(pass_scope)?; + let raw_encoder = parent_state + .raw_encoder + .open_pass(base.label.as_deref()) + .map_pass_err(pass_scope)?; + + let mut debug_scope_depth = 0; + + let mut state = State { + pipeline: None, + + pass: pass::PassState { + base: EncodingState { + device, + raw_encoder, + tracker: parent_state.tracker, + buffer_memory_init_actions: parent_state.buffer_memory_init_actions, + texture_memory_actions: parent_state.texture_memory_actions, + as_actions: parent_state.as_actions, + temp_resources: parent_state.temp_resources, + indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources, + snatch_guard: parent_state.snatch_guard, + debug_scope_depth: &mut debug_scope_depth, + }, + binder: Binder::new(), + temp_offsets: Vec::new(), + dynamic_offset_count: 0, + pending_discard_init_fixups: SurfacesInDiscardState::new(), + scope: device.new_usage_scope(), + string_offset: 0, + }, + + intermediate_trackers: Tracker::new( + device.ordered_buffer_usages, + device.ordered_texture_usages, + ), + }; + + let indices = &device.tracker_indices; + state + .pass + .base + .tracker + .buffers + .set_size(indices.buffers.size()); + state + .pass + .base + .tracker + .textures + .set_size(indices.textures.size()); + + let hal_desc = hal::RayTracingPassDescriptor { + label: hal_label(base.label.as_deref(), 
device.instance_flags), + }; + + unsafe { + state + .pass + .base + .raw_encoder + .begin_ray_tracing_pass(&hal_desc); + } + + for command in base.commands.drain(..) { + match command { + ArcRayTracingCommand::SetBindGroup { + index, + num_dynamic_offsets, + bind_group, + } => { + let scope = PassErrorScope::SetBindGroup; + pass::set_bind_group::( + &mut state.pass, + device, + &base.dynamic_offsets, + index, + num_dynamic_offsets, + bind_group, + false, + ) + .map_pass_err(scope)?; + } + ArcRayTracingCommand::SetPipeline(pipeline) => { + let scope = PassErrorScope::SetPipelineCompute; + set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?; + } + ArcRayTracingCommand::SetImmediate { + offset, + size_bytes, + values_offset, + } => { + let scope = PassErrorScope::SetImmediate; + + pass::set_immediates::( + &mut state.pass, + &base.immediates_data, + offset, + size_bytes, + Some(values_offset), + |_| {}, + ) + .map_pass_err(scope)?; + } + ArcRayTracingCommand::PushDebugGroup { color: _, len } => { + pass::push_debug_group(&mut state.pass, &base.string_data, len); + } + ArcRayTracingCommand::PopDebugGroup => { + let scope = PassErrorScope::PopDebugGroup; + pass::pop_debug_group::(&mut state.pass) + .map_pass_err(scope)?; + } + ArcRayTracingCommand::InsertDebugMarker { color: _, len } => { + pass::insert_debug_marker(&mut state.pass, &base.string_data, len); + } + ArcRayTracingCommand::TraceRays(groups) => { + let scope = PassErrorScope::TraceRays; + trace_rays(&mut state, groups, device).map_pass_err(scope)?; + } + } + } + + Ok(()) +} + +fn set_pipeline( + state: &mut State, + device: &Arc, + pipeline: Arc, +) -> Result<(), RayTracingPassErrorInner> { + pipeline.same_device(device)?; + + state.pipeline = Some(pipeline.clone()); + + let pipeline = state + .pass + .base + .tracker + .ray_tracing_pipelines + .insert_single(pipeline) + .clone(); + + unsafe { + state + .pass + .base + .raw_encoder + .set_ray_tracing_pipeline(pipeline.raw()); + } + + // Rebind 
resources + pass::change_pipeline_layout::( + &mut state.pass, + &pipeline.layout, + &pipeline.late_sized_buffer_groups, + || {}, + ) +} + +/// Validates a `TraceRays` command and records it on the raw encoder. +/// +/// Checks that a compatible pipeline is bound (`State::is_ready`), flushes the +/// bind group state, and rejects dispatches whose per-dimension size or total +/// ray count exceeds the device limits before calling the HAL `trace_rays`. +/// +/// # Errors +/// +/// Returns [`TraceRayError::InvalidDimensionSize`] if any dimension exceeds its +/// limit, and [`TraceRayError::TooManyTotal`] if the total ray count exceeds +/// `max_ray_dispatch_count` or does not fit in a `u32`. Errors from `is_ready` +/// and `flush_bindings` are propagated unchanged. +fn trace_rays( + state: &mut State, + dims: [u32; 3], + device: &Device, +) -> Result<(), RayTracingPassErrorInner> { + api_log!("RayTracingPass::trace_rays {dims:?}"); + + state.is_ready()?; + + state.flush_bindings()?; + + let limits = &state.pass.base.device.limits; + + // TODO: the per-dimension bound is currently derived from the compute + // workgroup limits. Saturate so that large limits cannot overflow `u32`. + let dim_size_limit = [ + limits + .max_compute_workgroup_size_x + .saturating_mul(limits.max_compute_workgroups_per_dimension), + limits + .max_compute_workgroup_size_y + .saturating_mul(limits.max_compute_workgroups_per_dimension), + limits + .max_compute_workgroup_size_z + .saturating_mul(limits.max_compute_workgroups_per_dimension), + ]; + + if dims + .iter() + .zip(&dim_size_limit) + .any(|(dim, limit)| dim > limit) + { + return Err(RayTracingPassErrorInner::TraceRay( + TraceRayError::InvalidDimensionSize { + current: dims, + limit: dim_size_limit, + }, + )); + } + + // The product of three `u32`s can overflow, which would panic in debug + // builds and wrap in release builds, letting an oversized dispatch slip + // past the limit check. Treat an overflowing product as over the limit. + let tot_rays = dims + .iter() + .try_fold(1u32, |acc, &dim| acc.checked_mul(dim)); + + if !matches!(tot_rays, Some(total) if total <= limits.max_ray_dispatch_count) { + return Err(RayTracingPassErrorInner::TraceRay( + TraceRayError::TooManyTotal { + current: tot_rays.unwrap_or(u32::MAX), + limit: limits.max_ray_dispatch_count, + }, + )); + } + + // `is_ready` above has already verified that a pipeline is set. + let current_pipeline = state.pipeline.as_ref().unwrap(); + + unsafe { + state.pass.base.raw_encoder.trace_rays( + dims, + hal::PipelineGroupData { + buffer: current_pipeline.shader_binding_data.raw.as_ref(), + offset: 0, + stride: device.alignments.ray_tracing_pipeline_group_data_alignment as _, + count: 1, + }, + hal::PipelineGroupData { + buffer: current_pipeline.shader_binding_data.raw.as_ref(), + offset: current_pipeline.shader_binding_data.miss_offset, + stride:
device.alignments.ray_tracing_pipeline_group_data_alignment as _, + count: 1, + }, + hal::PipelineGroupData { + buffer: current_pipeline.shader_binding_data.raw.as_ref(), + offset: current_pipeline.shader_binding_data.intersection_offset, + stride: device.alignments.ray_tracing_pipeline_group_data_alignment as _, + count: current_pipeline.shader_binding_data.num_intersection_groups, + }, + ); + } + Ok(()) +} diff --git a/wgpu-core/src/command/ray_tracing_pass_commands.rs b/wgpu-core/src/command/ray_tracing_pass_commands.rs new file mode 100644 index 00000000000..ca25a7ec364 --- /dev/null +++ b/wgpu-core/src/command/ray_tracing_pass_commands.rs @@ -0,0 +1,53 @@ +#[cfg(feature = "serde")] +use crate::command::serde_object_reference_struct; +use crate::command::{ArcReferences, ReferenceType}; + +#[cfg(feature = "serde")] +use macro_rules_attribute::apply; + +/// cbindgen:ignore +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", apply(serde_object_reference_struct))] +pub enum RayTracingCommand { + SetBindGroup { + index: u32, + num_dynamic_offsets: usize, + bind_group: Option, + }, + + SetPipeline(R::RayTracingPipeline), + + /// Set a range of immediates to values stored in `immediates_data`. + SetImmediate { + /// The byte offset within the immediate data storage to write to. This + /// must be a multiple of four. + offset: u32, + + /// The number of bytes to write. This must be a multiple of four. + size_bytes: u32, + + /// Index in `immediates_data` of the start of the data + /// to be written. + /// + /// Note: this is not a byte offset like `offset`. Rather, it is the + /// index of the first `u32` element in `immediates_data` to read. 
+ values_offset: u32, + }, + + TraceRays([u32; 3]), + + PushDebugGroup { + color: u32, + len: usize, + }, + + PopDebugGroup, + + InsertDebugMarker { + color: u32, + len: usize, + }, +} + +/// cbindgen:ignore +pub type ArcRayTracingCommand = RayTracingCommand; diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index 46c75c7cdb5..07b87e1f650 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -700,7 +700,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> { /// See the compute pass version for an explanation of some ways that /// `flush_bindings` differs between the two types of passes. fn flush_bindings(&mut self) -> Result<(), RenderPassErrorInner> { - flush_bindings_helper(&mut self.pass)?; + flush_bindings_helper(&mut self.pass, crate::resource::MaxIntersectionIndex::Query)?; Ok(()) } diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 288e90aa338..ca2638863d0 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -18,7 +18,8 @@ use crate::{ pipeline::{ self, RenderPipelineVertexProcessor, ResolvedComputePipelineDescriptor, ResolvedFragmentState, ResolvedGeneralRenderPipelineDescriptor, ResolvedMeshState, - ResolvedProgrammableStageDescriptor, ResolvedTaskState, ResolvedVertexState, + ResolvedProgrammableStageDescriptor, ResolvedRayTracingPipelineDescriptor, + ResolvedTaskState, ResolvedVertexState, }, present, resource::{ @@ -1776,6 +1777,251 @@ impl Global { } } + pub fn device_create_ray_tracing_pipeline( + &self, + device_id: DeviceId, + desc: &pipeline::RayTracingPipelineDescriptor, + id_in: Option, + ) -> ( + id::RayTracingPipelineId, + Option, + ) { + profiling::scope!("Device::create_ray_tracing_pipeline"); + + let hub = &self.hub; + + let fid = hub.ray_tracing_pipelines.prepare(id_in); + + let error = 'error: { + let device = self.hub.devices.get(device_id); + + if let Err(e) = device.check_is_valid() { + 
break 'error e.into(); + } + + let layout = desc + .layout + .map(|layout| hub.pipeline_layouts.get(layout).get()) + .transpose(); + let layout = match layout { + Ok(layout) => layout, + Err(e) => break 'error e.into(), + }; + + let cache = desc + .cache + .map(|cache| hub.pipeline_caches.get(cache).get()) + .transpose(); + let cache = match cache { + Ok(cache) => cache, + Err(e) => break 'error e.into(), + }; + + let ray_gen_stage = { + let module = hub.shader_modules.get(desc.ray_generation.module).get(); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e.into(), + }; + if module.interface.interface().is_none() && layout.is_none() { + break 'error pipeline::CreateRayTracingPipelineError::Implicit( + pipeline::ImplicitLayoutError::Passthrough( + wgt::ShaderStages::RAY_GENERATION, + ), + ); + } + ResolvedProgrammableStageDescriptor { + module, + entry_point: desc.ray_generation.entry_point.clone(), + constants: desc.ray_generation.constants.clone(), + zero_initialize_workgroup_memory: desc + .ray_generation + .zero_initialize_workgroup_memory, + } + }; + + let ray_miss_stage = { + let module = hub.shader_modules.get(desc.miss.module).get(); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e.into(), + }; + if module.interface.interface().is_none() && layout.is_none() { + break 'error pipeline::CreateRayTracingPipelineError::Implicit( + pipeline::ImplicitLayoutError::Passthrough(wgt::ShaderStages::MISS), + ); + } + ResolvedProgrammableStageDescriptor { + module, + entry_point: desc.miss.entry_point.clone(), + constants: desc.miss.constants.clone(), + zero_initialize_workgroup_memory: desc.miss.zero_initialize_workgroup_memory, + } + }; + + let mut intersection_stages = Vec::with_capacity(desc.intersections.len()); + + for intersection_desc in desc.intersections.iter() { + match intersection_desc { + pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => { + let 
closest_hit_stage = { + let module = hub.shader_modules.get(closest_hit.module).get(); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e.into(), + }; + if module.interface.interface().is_none() && layout.is_none() { + break 'error pipeline::CreateRayTracingPipelineError::Implicit( + pipeline::ImplicitLayoutError::Passthrough( + wgt::ShaderStages::CLOSEST_HIT, + ), + ); + } + ResolvedProgrammableStageDescriptor { + module, + entry_point: closest_hit.entry_point.clone(), + constants: closest_hit.constants.clone(), + zero_initialize_workgroup_memory: closest_hit + .zero_initialize_workgroup_memory, + } + }; + + let any_hit_stage = match any_hit { + Some(any_hit) => { + let module = hub.shader_modules.get(any_hit.module).get(); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e.into(), + }; + if module.interface.interface().is_none() && layout.is_none() { + break 'error pipeline::CreateRayTracingPipelineError::Implicit( + pipeline::ImplicitLayoutError::Passthrough( + wgt::ShaderStages::ANY_HIT, + ), + ); + } + Some(ResolvedProgrammableStageDescriptor { + module, + entry_point: any_hit.entry_point.clone(), + constants: any_hit.constants.clone(), + zero_initialize_workgroup_memory: any_hit + .zero_initialize_workgroup_memory, + }) + } + None => None, + }; + + intersection_stages.push( + pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit: closest_hit_stage, + any_hit: any_hit_stage, + }, + ); + } + } + } + + let desc = ResolvedRayTracingPipelineDescriptor { + label: desc.label.clone(), + layout, + ray_generation: ray_gen_stage, + miss: ray_miss_stage, + intersections: intersection_stages, + max_recursion_depth: desc.max_recursion_depth, + cache, + }; + + #[cfg(feature = "trace")] + let trace_desc = desc.clone().into_trace(); + + let pipeline = match device.create_ray_tracing_pipeline(desc) { + Ok(pair) => pair, + Err(e) => break 'error e, + }; + + #[cfg(feature = "trace")] + if let Some(ref mut 
+ trace) = *device.trace.lock() { + trace.add(trace::Action::CreateRayTracingPipeline { + id: pipeline.to_trace(), + desc: trace_desc, + }); + } + + let id = fid.assign(Fallible::Valid(pipeline)); + api_log!("Device::create_ray_tracing_pipeline -> {id:?}"); + + return (id, None); + }; + + let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string()))); + + (id, Some(error)) + } + + /// Get an ID of one of the bind group layouts. The ID adds a refcount, + /// which needs to be released by calling `bind_group_layout_drop`. + pub fn ray_tracing_pipeline_get_bind_group_layout( + &self, + pipeline_id: id::RayTracingPipelineId, + index: u32, + id_in: Option, + ) -> ( + id::BindGroupLayoutId, + Option, + ) { + let hub = &self.hub; + + let fid = hub.bind_group_layouts.prepare(id_in); + + let error = 'error: { + let pipeline = match hub.ray_tracing_pipelines.get(pipeline_id).get() { + Ok(pipeline) => pipeline, + Err(e) => break 'error e.into(), + }; + + match pipeline.get_bind_group_layout(index) { + Ok(bgl) => { + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *pipeline.device.trace.lock() { + trace.add(trace::Action::GetRayTracingPipelineBindGroupLayout { + id: bgl.to_trace(), + pipeline: pipeline.to_trace(), + index, + }); + } + + let id = fid.assign(Fallible::Valid(bgl.clone())); + return (id, None); + } + Err(err) => break 'error err, + }; + }; + + let id = fid.assign(Fallible::Invalid(Arc::new(String::new()))); + (id, Some(error)) + } + + pub fn ray_tracing_pipeline_drop(&self, ray_tracing_pipeline_id: id::RayTracingPipelineId) { + profiling::scope!("RayTracingPipeline::drop"); + api_log!("RayTracingPipeline::drop {ray_tracing_pipeline_id:?}"); + + let hub = &self.hub; + + let _pipeline = hub.ray_tracing_pipelines.remove(ray_tracing_pipeline_id); + + #[cfg(feature = "trace")] + if let Ok(pipeline) = _pipeline.get() { + if let Some(t) = pipeline.device.trace.lock().as_mut() { + t.add(trace::Action::DestroyRayTracingPipeline( + pipeline.to_trace(), + )); + 
} + } + } + /// # Safety /// The `data` argument of `desc` must have been returned by /// [Self::pipeline_cache_get_data] for the same adapter diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index a66ff71bea7..be93b46f37b 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -1857,7 +1857,12 @@ impl Global { // TODO: Tracing let error = 'error: { - match device.require_features(wgpu_types::Features::EXPERIMENTAL_RAY_QUERY) { + match device + .require_features(wgpu_types::Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| { + device + .require_features(wgpu_types::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + }) { Ok(_) => {} Err(err) => break 'error err.into(), } diff --git a/wgpu-core/src/device/ray_tracing.rs b/wgpu-core/src/device/ray_tracing.rs index da7c0da667e..0f0265692fc 100644 --- a/wgpu-core/src/device/ray_tracing.rs +++ b/wgpu-core/src/device/ray_tracing.rs @@ -3,7 +3,8 @@ use core::mem::{size_of, ManuallyDrop}; #[cfg(feature = "trace")] use crate::device::trace::{Action, IntoTrace}; -use crate::device::DeviceError; +use crate::device::{DeviceError, ENTRYPOINT_FAILURE_ERROR}; +use crate::resource::ParentDevice; use crate::{ api_log, device::Device, @@ -21,6 +22,7 @@ use crate::{ snatch::Snatchable, LabelHelpers, }; +use crate::{binding_model, pipeline, FastHashMap}; use hal::AccelerationStructureTriangleIndices; use wgt::{Features, AABB_GEOMETRY_MIN_STRIDE}; @@ -31,7 +33,8 @@ impl Device { sizes: wgt::BlasGeometrySizeDescriptors, ) -> Result, CreateBlasError> { self.check_is_valid()?; - self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?; + self.require_features(Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| self.require_features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES))?; if blas_desc .flags @@ -214,7 +217,8 @@ impl Device { desc: &resource::TlasDescriptor, ) -> Result, CreateTlasError> { self.check_is_valid()?; - self.require_features(Features::EXPERIMENTAL_RAY_QUERY)?; + 
self.require_features(Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| self.require_features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES))?; if desc.max_instances > self.limits.max_tlas_instance_count { return Err(CreateTlasError::TooManyInstances( @@ -289,12 +293,422 @@ impl Device { update_mode: desc.update_mode, built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None), dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()), + max_intersection_index: RwLock::new( + rank::TLAS_MAX_INTERSECTION_IDX, + resource::MaxIntersectionIndex::Unused, + ), instance_buffer: ManuallyDrop::new(instance_buffer), label: desc.label.to_string(), max_instance_count: desc.max_instances, tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()), })) } + + pub fn create_ray_tracing_pipeline( + self: &Arc, + desc: pipeline::ResolvedRayTracingPipelineDescriptor, + ) -> Result, pipeline::CreateRayTracingPipelineError> { + use crate::validation; + + self.check_is_valid()?; + self.require_features(Features::EXPERIMENTAL_RAY_TRACING_PIPELINES)?; + + let mut shader_binding_sizes = FastHashMap::default(); + + let mut io = validation::StageIo::default(); + + let is_auto_layout = desc.layout.is_none(); + + // Get the pipeline layout from the desc if it is provided. 
+ let pipeline_layout = match desc.layout { + Some(pipeline_layout) => { + pipeline_layout.same_device(self)?; + Some(pipeline_layout) + } + None => None, + }; + + let mut binding_layout_source = match pipeline_layout { + Some(pipeline_layout) => validation::BindingLayoutSource::Provided(pipeline_layout), + None => validation::BindingLayoutSource::new_derived(&self.limits), + }; + + // The size the immediates will be if there is no pipeline layout + let mut derived_immediate_size = 0; + + let final_ray_gen_name; + let ray_generation = { + let stage = validation::ShaderStageForValidation::RayGeneration; + let stage_bit = stage.to_wgt_bit(); + + final_ray_gen_name = desc + .ray_generation + .module + .finalize_entry_point_name( + stage.to_naga(), + desc.ray_generation + .entry_point + .as_ref() + .map(|ep| ep.as_ref()), + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bit, + error: e, + })?; + + if let Some(interface) = desc.ray_generation.module.interface.interface() { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &final_ray_gen_name, + stage, + io, + None, + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bit, + error: e, + })?; + + derived_immediate_size = derived_immediate_size.max(interface.immediate_size); + } + + hal::ProgrammableStage { + module: desc.ray_generation.module.raw(), + entry_point: &final_ray_gen_name, + constants: &desc.ray_generation.constants, + zero_initialize_workgroup_memory: desc + .ray_generation + .zero_initialize_workgroup_memory, + } + }; + + let final_miss_name; + let miss = { + let stage = validation::ShaderStageForValidation::Miss; + let stage_bit = stage.to_wgt_bit(); + + final_miss_name = desc + .miss + .module + .finalize_entry_point_name( + stage.to_naga(), + desc.miss.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bit, + error: e, + 
})?; + + if let Some(interface) = desc.miss.module.interface.interface() { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &final_miss_name, + stage, + io, + None, + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bit, + error: e, + })?; + + derived_immediate_size = derived_immediate_size.max(interface.immediate_size); + } + + hal::ProgrammableStage { + module: desc.miss.module.raw(), + entry_point: &final_miss_name, + constants: &desc.miss.constants, + zero_initialize_workgroup_memory: desc.miss.zero_initialize_workgroup_memory, + } + }; + + if desc.intersections.len() > self.limits.max_intersection_group_count as usize { + return Err( + pipeline::CreateRayTracingPipelineError::TooManyIntersectionGroups( + desc.intersections.len(), + self.limits.max_intersection_group_count, + ), + ); + } + + if desc.max_recursion_depth > self.limits.max_ray_recursion_depth { + return Err( + pipeline::CreateRayTracingPipelineError::TooHighRayRecursionDepth( + desc.max_recursion_depth, + self.limits.max_ray_recursion_depth, + ), + ); + } + + let mut intersections = Vec::with_capacity(desc.intersections.len()); + let mut final_intersection_names = Vec::with_capacity(desc.intersections.len()); + + for intersection in &desc.intersections { + match intersection { + pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => { + let stage = validation::ShaderStageForValidation::ClosestHit { triangle: true }; + let closest_name = closest_hit + .module + .finalize_entry_point_name( + stage.to_naga(), + closest_hit.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage.to_wgt_bit(), + error: e, + })?; + + let any_hit = match any_hit { + Some(any_hit) => { + let stage = + validation::ShaderStageForValidation::AnyHit { triangle: true }; + + Some( + any_hit + .module + .finalize_entry_point_name( + stage.to_naga(), + 
any_hit.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(|e| { + pipeline::CreateRayTracingPipelineError::Stage { + stage: stage.to_wgt_bit(), + error: e, + } + })?, + ) + } + None => None, + }; + + final_intersection_names.push((closest_name, any_hit)); + } + } + } + + for (intersection, (final_closest_name, final_any_name)) in desc + .intersections + .iter() + .zip(final_intersection_names.iter()) + { + intersections.push(match intersection { + pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => { + let closest_hit = { + let stage = + validation::ShaderStageForValidation::ClosestHit { triangle: true }; + + let stage_bits = stage.to_wgt_bit(); + if let Some(interface) = closest_hit.module.interface.interface() { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + final_closest_name, + stage, + io, + None, + ) + .map_err(|e| pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bits, + error: e, + })?; + + derived_immediate_size = + derived_immediate_size.max(interface.immediate_size); + } + + hal::ProgrammableStage { + module: closest_hit.module.raw(), + entry_point: final_closest_name, + constants: &closest_hit.constants, + zero_initialize_workgroup_memory: closest_hit + .zero_initialize_workgroup_memory, + } + }; + + let any_hit = match any_hit { + Some(any_hit) => { + let stage = + validation::ShaderStageForValidation::AnyHit { triangle: true }; + + let final_any_name = final_any_name.as_ref().unwrap(); + + let stage_bits = stage.to_wgt_bit(); + if let Some(interface) = any_hit.module.interface.interface() { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + final_any_name, + stage, + io, + None, + ) + .map_err(|e| { + pipeline::CreateRayTracingPipelineError::Stage { + stage: stage_bits, + error: e, + } + })?; + + derived_immediate_size = + derived_immediate_size.max(interface.immediate_size); + } + + 
Some(hal::ProgrammableStage { + module: any_hit.module.raw(), + entry_point: final_any_name, + constants: &any_hit.constants, + zero_initialize_workgroup_memory: any_hit + .zero_initialize_workgroup_memory, + }) + } + None => None, + }; + + hal::RayObjectIntersectionState { + closest_hit, + any_hit, + } + } + }); + } + + if !self + .downlevel + .flags + .contains(wgt::DownlevelFlags::BUFFER_BINDINGS_NOT_16_BYTE_ALIGNED) + { + for (binding, size) in shader_binding_sizes.iter() { + if size.get() % 16 != 0 { + return Err(pipeline::CreateRayTracingPipelineError::UnalignedShader { + binding: binding.binding, + group: binding.group, + size: size.get(), + }); + } + } + } + + let pipeline_layout = match binding_layout_source { + validation::BindingLayoutSource::Provided(layout) => layout, + validation::BindingLayoutSource::Derived(entries) => { + self.create_derived_pipeline_layout(entries, derived_immediate_size)? + } + }; + + let late_sized_buffer_groups = + Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); + + let cache = match desc.cache { + Some(cache) => { + cache.same_device(self)?; + Some(cache) + } + None => None, + }; + + let raw = { + let pipeline_desc = hal::RayTracingPipelineDescriptor { + label: desc.label.to_hal(self.instance_flags), + layout: pipeline_layout.raw(), + cache: cache.as_ref().map(|it| it.raw()), + ray_generation, + miss, + intersection: &intersections, + max_recursion_depth: desc.max_recursion_depth, + }; + unsafe { self.raw().create_ray_tracing_pipeline(&pipeline_desc) }.map_err(|err| { + match err { + hal::PipelineError::Device(error) => { + pipeline::CreateRayTracingPipelineError::Device( + self.handle_hal_error(error), + ) + } + hal::PipelineError::Linkage(stage, msg) => { + pipeline::CreateRayTracingPipelineError::Internal { stage, error: msg } + } + hal::PipelineError::EntryPoint(stage) => { + pipeline::CreateRayTracingPipelineError::Internal { + stage: hal::auxil::map_naga_stage(stage), + error: 
ENTRYPOINT_FAILURE_ERROR.to_string(), + } + } + hal::PipelineError::PipelineConstants(stage, error) => { + pipeline::CreateRayTracingPipelineError::PipelineConstants { stage, error } + } + } + })? + }; + + let shader_modules = { + let mut shader_modules = Vec::new(); + shader_modules.push(desc.ray_generation.module); + shader_modules.push(desc.miss.module); + shader_modules.reserve(desc.intersections.len()); + for intersection in &desc.intersections { + match intersection { + pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => { + shader_modules.push(closest_hit.module.clone()); + if let Some(any) = any_hit { + shader_modules.push(any.module.clone()); + } + } + } + } + shader_modules + }; + + let shader_binding_data = pipeline::ShaderBindingData::from_raw_pipeline( + self.clone(), + raw.as_ref(), + desc.intersections.len(), + )?; + + let pipeline = pipeline::RayTracingPipeline { + raw: ManuallyDrop::new(raw), + layout: pipeline_layout, + device: self.clone(), + _shader_modules: shader_modules, + late_sized_buffer_groups, + shader_binding_data, + label: desc.label.to_string(), + tracking_data: TrackingData::new(self.tracker_indices.ray_tracing_pipelines.clone()), + }; + + let pipeline = Arc::new(pipeline); + + if is_auto_layout { + for bgl in pipeline.layout.bind_group_layouts.iter() { + let Some(bgl) = bgl else { + continue; + }; + + // `bind_group_layouts` might contain duplicate entries, so we need to ignore the result. 
+ let _ = bgl + .exclusive_pipeline + .set(binding_model::ExclusivePipeline::RayTracing( + Arc::downgrade(&pipeline), + )); + } + } + + Ok(pipeline) + } } impl Global { diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index b8809a41e7b..dba6bdb0eb0 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1015,7 +1015,10 @@ impl Device { .usage .intersects(wgt::BufferUsages::BLAS_INPUT | wgt::BufferUsages::TLAS_INPUT) { - self.require_features(wgt::Features::EXPERIMENTAL_RAY_QUERY)?; + self.require_features(wgt::Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| { + self.require_features(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + })?; } if desc.usage.contains(wgt::BufferUsages::INDEX) @@ -2531,7 +2534,7 @@ impl Device { /// Generate information about late-validated buffer bindings for pipelines. //TODO: should this be combined with `get_introspection_bind_group_layouts` in some way? - fn make_late_sized_buffer_groups( + pub(super) fn make_late_sized_buffer_groups( shader_binding_sizes: &FastHashMap, layout: &binding_model::PipelineLayout, ) -> ArrayVec { @@ -2751,6 +2754,9 @@ impl Device { } Bt::AccelerationStructure { vertex_return } => { self.require_features(wgt::Features::EXPERIMENTAL_RAY_QUERY) + .or_else(|_| { + self.require_features(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + }) .map_err(|e| CreateBindGroupLayoutError::Entry { binding: entry.binding, error: e.into(), @@ -3806,7 +3812,7 @@ impl Device { Ok(layout) } - fn create_derived_pipeline_layout( + pub(super) fn create_derived_pipeline_layout( self: &Arc, mut derived_group_layouts: Box>, immediate_size: u32, diff --git a/wgpu-core/src/device/trace.rs b/wgpu-core/src/device/trace.rs index 8852798e232..360704142d5 100644 --- a/wgpu-core/src/device/trace.rs +++ b/wgpu-core/src/device/trace.rs @@ -168,6 +168,11 @@ pub enum Action<'a, R: ReferenceType> { pipeline: PointerId, index: u32, }, + GetRayTracingPipelineBindGroupLayout { 
+ id: PointerId, + pipeline: PointerId, + index: u32, + }, DestroyBindGroupLayout(PointerId), CreatePipelineLayout( PointerId, @@ -202,6 +207,11 @@ pub enum Action<'a, R: ReferenceType> { desc: TraceGeneralRenderPipelineDescriptor<'a>, }, DestroyRenderPipeline(PointerId), + CreateRayTracingPipeline { + id: PointerId, + desc: TraceRayTracingPipelineDescriptor<'a>, + }, + DestroyRayTracingPipeline(PointerId), CreatePipelineCache { id: PointerId, desc: crate::pipeline::PipelineCacheDescriptor<'a>, @@ -285,3 +295,14 @@ pub type TraceComputePipelineDescriptor<'a> = crate::pipeline::ComputePipelineDe PointerId, PointerId, >; + +/// Not a public API. For use by `player` only. +/// +/// cbindgen:ignore +#[doc(hidden)] +pub type TraceRayTracingPipelineDescriptor<'a> = crate::pipeline::RayTracingPipelineDescriptor< + 'a, + PointerId, + PointerId, + PointerId, +>; diff --git a/wgpu-core/src/device/trace/record.rs b/wgpu-core/src/device/trace/record.rs index 630e49affd1..137ec630db9 100644 --- a/wgpu-core/src/device/trace/record.rs +++ b/wgpu-core/src/device/trace/record.rs @@ -4,11 +4,12 @@ use std::io::Write as _; use crate::{ command::{ - ArcCommand, ArcComputeCommand, ArcPassTimestampWrites, ArcReferences, ArcRenderCommand, - BasePass, ColorAttachments, Command, ComputeCommand, PointerReferences, RenderCommand, - RenderPassColorAttachment, ResolvedRenderPassDepthStencilAttachment, + ArcCommand, ArcComputeCommand, ArcPassTimestampWrites, ArcRayTracingCommand, ArcReferences, + ArcRenderCommand, BasePass, ColorAttachments, Command, ComputeCommand, PointerReferences, + RayTracingCommand, RenderCommand, RenderPassColorAttachment, + ResolvedRenderPassDepthStencilAttachment, }, - device::trace::{Data, DataKind}, + device::trace::{Data, DataKind, TraceRayTracingPipelineDescriptor}, id::{markers, PointerId}, storage::StorageItem, }; @@ -264,6 +265,9 @@ impl IntoTrace for ArcCommand { occlusion_query_set: occlusion_query_set.map(|q| q.to_trace()), multiview_mask, }, + 
ArcCommand::RunRayTracingPass { pass } => Command::RunRayTracingPass { + pass: pass.into_trace(), + }, ArcCommand::BuildAccelerationStructures { blas, tlas } => { Command::BuildAccelerationStructures { blas: blas.into_iter().map(|b| b.into_trace()).collect(), @@ -422,6 +426,7 @@ impl IntoTrace for crate::ray_tracing::OwnedTlasInstance { transform: self.transform, custom_data: self.custom_data, mask: self.mask, + intersection_index: self.intersection_index, } } } @@ -662,6 +667,38 @@ impl IntoTrace for ArcRenderCommand { } } +impl IntoTrace for ArcRayTracingCommand { + type Output = RayTracingCommand; + fn into_trace(self) -> Self::Output { + use RayTracingCommand as C; + match self { + C::SetBindGroup { + index, + num_dynamic_offsets, + bind_group, + } => C::SetBindGroup { + index, + num_dynamic_offsets, + bind_group: bind_group.map(|bg| bg.into_trace()), + }, + C::SetPipeline(id) => C::SetPipeline(id.into_trace()), + C::SetImmediate { + offset, + size_bytes, + values_offset, + } => C::SetImmediate { + offset, + size_bytes, + values_offset, + }, + C::TraceRays(groups) => C::TraceRays(groups), + C::PushDebugGroup { color, len } => C::PushDebugGroup { color, len }, + C::PopDebugGroup => C::PopDebugGroup, + C::InsertDebugMarker { color, len } => C::InsertDebugMarker { color, len }, + } + } +} + impl IntoTrace for crate::binding_model::ResolvedPipelineLayoutDescriptor<'_> { type Output = crate::binding_model::PipelineLayoutDescriptor< 'static, @@ -783,6 +820,34 @@ impl<'a> IntoTrace for crate::pipeline::ResolvedComputePipelineDescriptor<'a> { } } +impl<'a> IntoTrace for crate::pipeline::ResolvedRayTracingPipelineDescriptor<'a> { + type Output = TraceRayTracingPipelineDescriptor<'a>; + + fn into_trace(self) -> Self::Output { + TraceRayTracingPipelineDescriptor { + label: self.label, + layout: self.layout.into_trace(), + ray_generation: self.ray_generation.into_trace(), + miss: self.miss.into_trace(), + intersections: self + .intersections + .into_iter() + 
.map(|intersection| match intersection { + crate::pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => crate::pipeline::RayTracingIntersectionDescriptor::Triangle { + closest_hit: closest_hit.into_trace(), + any_hit: any_hit.map(|a| a.into_trace()), + }, + }) + .collect(), + cache: self.cache.map(|c| c.into_trace()), + max_recursion_depth: self.max_recursion_depth, + } + } +} + impl<'a> IntoTrace for crate::pipeline::ResolvedProgrammableStageDescriptor<'a> { type Output = crate::pipeline::ProgrammableStageDescriptor<'a, PointerId>; @@ -924,11 +989,21 @@ fn action_to_owned(action: Action<'_, PointerReferences>) -> Action<'static, Poi pipeline, index, }, + A::GetRayTracingPipelineBindGroupLayout { + id, + pipeline, + index, + } => A::GetRayTracingPipelineBindGroupLayout { + id, + pipeline, + index, + }, A::DestroyPipelineLayout(layout) => A::DestroyPipelineLayout(layout), A::DestroyBindGroup(bind_group) => A::DestroyBindGroup(bind_group), A::DestroyShaderModule(shader_module) => A::DestroyShaderModule(shader_module), A::DestroyComputePipeline(pipeline) => A::DestroyComputePipeline(pipeline), A::DestroyRenderPipeline(pipeline) => A::DestroyRenderPipeline(pipeline), + A::DestroyRayTracingPipeline(pipeline) => A::DestroyRayTracingPipeline(pipeline), A::DestroyPipelineCache(cache) => A::DestroyPipelineCache(cache), A::DestroyRenderBundle(render_bundle) => A::DestroyRenderBundle(render_bundle), A::DestroyQuerySet(query_set) => A::DestroyQuerySet(query_set), @@ -980,6 +1055,7 @@ fn action_to_owned(action: Action<'_, PointerReferences>) -> Action<'static, Poi | A::CreateShaderModulePassthrough { .. } | A::CreateComputePipeline { .. } | A::CreateGeneralRenderPipeline { .. } + | A::CreateRayTracingPipeline { .. } | A::CreatePipelineCache { .. } | A::CreateRenderBundle { .. } | A::CreateQuerySet { .. 
} diff --git a/wgpu-core/src/hub.rs b/wgpu-core/src/hub.rs index d21e6565f88..1e68494820a 100644 --- a/wgpu-core/src/hub.rs +++ b/wgpu-core/src/hub.rs @@ -124,7 +124,7 @@ use crate::{ command::{CommandBuffer, CommandEncoder, RenderBundle}, device::{queue::Queue, Device}, instance::Adapter, - pipeline::{ComputePipeline, PipelineCache, RenderPipeline, ShaderModule}, + pipeline::{ComputePipeline, PipelineCache, RayTracingPipeline, RenderPipeline, ShaderModule}, registry::{Registry, RegistryReport}, resource::{ Blas, Buffer, ExternalTexture, Fallible, QuerySet, Sampler, StagingBuffer, Texture, @@ -146,6 +146,7 @@ pub struct HubReport { pub render_bundles: RegistryReport, pub render_pipelines: RegistryReport, pub compute_pipelines: RegistryReport, + pub ray_tracing_pipelines: RegistryReport, pub pipeline_caches: RegistryReport, pub query_sets: RegistryReport, pub buffers: RegistryReport, @@ -195,6 +196,7 @@ pub struct Hub { pub(crate) render_bundles: Registry>, pub(crate) render_pipelines: Registry>, pub(crate) compute_pipelines: Registry>, + pub(crate) ray_tracing_pipelines: Registry>, pub(crate) pipeline_caches: Registry>, pub(crate) query_sets: Registry>, pub(crate) buffers: Registry>, @@ -222,6 +224,7 @@ impl Hub { render_bundles: Registry::new(), render_pipelines: Registry::new(), compute_pipelines: Registry::new(), + ray_tracing_pipelines: Registry::new(), pipeline_caches: Registry::new(), query_sets: Registry::new(), buffers: Registry::new(), @@ -249,6 +252,7 @@ impl Hub { render_bundles: self.render_bundles.generate_report(), render_pipelines: self.render_pipelines.generate_report(), compute_pipelines: self.compute_pipelines.generate_report(), + ray_tracing_pipelines: self.ray_tracing_pipelines.generate_report(), pipeline_caches: self.pipeline_caches.generate_report(), query_sets: self.query_sets.generate_report(), buffers: self.buffers.generate_report(), diff --git a/wgpu-core/src/id.rs b/wgpu-core/src/id.rs index fdfff591167..7608b684258 100644 --- 
a/wgpu-core/src/id.rs +++ b/wgpu-core/src/id.rs @@ -338,6 +338,7 @@ ids! { pub type QuerySetId QuerySet; pub type BlasId Blas; pub type TlasId Tlas; + pub type RayTracingPipelineId RayTracingPipeline; } #[test] diff --git a/wgpu-core/src/lock/rank.rs b/wgpu-core/src/lock/rank.rs index 527abf684ef..aeb80f806c9 100644 --- a/wgpu-core/src/lock/rank.rs +++ b/wgpu-core/src/lock/rank.rs @@ -150,6 +150,7 @@ define_lock_ranks! { rank BLAS_COMPACTION_STATE "Blas::compaction_size" followed by { } rank TLAS_BUILT_INDEX "Tlas::built_index" followed by { } rank TLAS_DEPENDENCIES "Tlas::dependencies" followed by { } + rank TLAS_MAX_INTERSECTION_IDX "Tlas::max_intersection_index" followed by { } rank BUFFER_POOL "BufferPool::buffers" followed by { } #[cfg(test)] diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 6a8bfb96060..5a0025f2c6d 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -5,7 +5,12 @@ use alloc::{ sync::Arc, vec::Vec, }; -use core::{marker::PhantomData, mem::ManuallyDrop, num::NonZeroU32}; +use core::{ + marker::PhantomData, + mem::ManuallyDrop, + num::{NonZeroU32, NonZeroU64}, +}; +use hal::BufferDescriptor; use arrayvec::ArrayVec; use naga::error::ShaderError; @@ -883,3 +888,255 @@ impl RenderPipeline { self.layout.get_bind_group_layout(index, self.into()) } } + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum CreateRayTracingPipelineError { + #[error(transparent)] + Device(#[from] DeviceError), + #[error("Unable to derive an implicit layout")] + Implicit(#[from] ImplicitLayoutError), + #[error(transparent)] + MissingFeatures(#[from] MissingFeatures), + #[error("Error matching {stage:?} shader requirements against the pipeline")] + Stage { + stage: wgt::ShaderStages, + #[source] + error: validation::StageError, + }, + #[error("In the provided shader, the type given for group {group} binding {binding} has a size of {size}. 
As the device does not support `DownlevelFlags::BUFFER_BINDINGS_NOT_16_BYTE_ALIGNED`, the type must have a size that is a multiple of 16 bytes.")] + UnalignedShader { group: u32, binding: u32, size: u64 }, + #[error("Internal error in {stage:?} shader: {error}")] + Internal { + stage: wgt::ShaderStages, + error: String, + }, + #[error("Pipeline constant error in {stage:?} shader: {error}")] + PipelineConstants { + stage: wgt::ShaderStages, + error: String, + }, + #[error(transparent)] + InvalidResource(#[from] InvalidResourceError), + #[error("The number of intersection shaders {0} is greater than the limit Limits::max_intersection_group_count {1}")] + TooManyIntersectionGroups(usize, u32), + #[error( + "The ray recursion depth {0} is greater than the limit Limits::max_ray_recursion_depth {1}" + )] + TooHighRayRecursionDepth(u32, u32), +} + +impl WebGpuError for CreateRayTracingPipelineError { + fn webgpu_error_type(&self) -> ErrorType { + match self { + Self::Device(e) => e.webgpu_error_type(), + Self::InvalidResource(e) => e.webgpu_error_type(), + Self::MissingFeatures(e) => e.webgpu_error_type(), + + Self::Internal { .. } => ErrorType::Internal, + + Self::Implicit(_) + | Self::Stage { .. } + | Self::UnalignedShader { .. } + | Self::PipelineConstants { .. } + | Self::TooManyIntersectionGroups(..) + | Self::TooHighRayRecursionDepth(..) => ErrorType::Validation, + } + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum RayTracingIntersectionDescriptor<'a, SM = ShaderModuleId> { + Triangle { + closest_hit: ProgrammableStageDescriptor<'a, SM>, + any_hit: Option>, + }, +} + +/// Describes a ray tracing pipeline. 
+#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RayTracingPipelineDescriptor< + 'a, + PLL = PipelineLayoutId, + SM = ShaderModuleId, + PLC = PipelineCacheId, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: Option, + /// The ray generation stage descriptor + pub ray_generation: ProgrammableStageDescriptor<'a, SM>, + /// The miss stage descriptor + pub miss: ProgrammableStageDescriptor<'a, SM>, + /// All the descriptors for intersection functions + pub intersections: Vec>, + /// The maximum amount entry points are allowed to recurse + pub max_recursion_depth: u32, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option, +} + +pub type ResolvedRayTracingPipelineDescriptor<'a> = + RayTracingPipelineDescriptor<'a, Arc, Arc, Arc>; + +/// This is a semi-opaque structure because if metal gets ray tracing pipelines, +/// this will need to turn into an enum so it shouldn't have other code +/// tangled with it. 
+#[derive(Debug)] +pub struct ShaderBindingData { + pub(crate) raw: ManuallyDrop>, + pub(crate) device: Arc, + pub(crate) num_intersection_groups: u64, + pub(crate) miss_offset: wgt::BufferAddress, + pub(crate) intersection_offset: wgt::BufferAddress, +} + +impl ShaderBindingData { + pub(crate) fn from_raw_pipeline( + device: Arc, + pipeline: &dyn hal::DynRayTracingPipeline, + num_intersection_groups: usize, + ) -> Result { + let mut base_data = Vec::new(); + + let closest_hit_data = unsafe { + device + .raw() + .get_raytracing_pipeline_group_data(pipeline, 0..1) + } + .map_err(|e| CreateRayTracingPipelineError::Device(device.handle_hal_error(e)))?; + + base_data.extend_from_slice(&closest_hit_data); + + let padded_miss_offset = (base_data.len() as wgt::BufferAddress).next_multiple_of( + device.alignments.ray_tracing_pipeline_data_offset_alignment as wgt::BufferAddress, + ); + + base_data.resize(padded_miss_offset as _, 0); + + let miss_data = unsafe { + device + .raw() + .get_raytracing_pipeline_group_data(pipeline, 1..2) + } + .map_err(|e| CreateRayTracingPipelineError::Device(device.handle_hal_error(e)))?; + + base_data.extend_from_slice(&miss_data); + + let padded_intersection_offset = (base_data.len() as wgt::BufferAddress).next_multiple_of( + device.alignments.ray_tracing_pipeline_data_offset_alignment as wgt::BufferAddress, + ); + + base_data.resize(padded_intersection_offset as _, 0); + + let intersection_data = unsafe { + device.raw().get_raytracing_pipeline_group_data( + pipeline, + 2..(2 + num_intersection_groups as u32), + ) + } + .map_err(|e| CreateRayTracingPipelineError::Device(device.handle_hal_error(e)))?; + + base_data.extend_from_slice(&intersection_data); + + let buffer = unsafe { + device.raw().create_buffer(&BufferDescriptor { + label: None, + size: base_data.len() as _, + usage: wgt::BufferUses::RAY_TRACING_PIPELINE_SHADER_DATA + | wgt::BufferUses::COPY_DST, + memory_flags: hal::MemoryFlags::PREFER_COHERENT, + }) + } + .map_err(|e| 
CreateRayTracingPipelineError::Device(device.handle_hal_error(e)))?; + + // If there is no queue anymore, the ray tracing pipeline can't be accessed, so we don't have to worry about UB from uninitialized values + if let Some(queue) = device.get_queue() { + let mut staging = crate::resource::StagingBuffer::new(&device, NonZeroU64::new(base_data.len() as _).expect("The total number of groups is always greater than zero, and `ray_tracing_pipeline_group_data_size` must be too."))?; + + staging.write(&base_data); + + let staging_buf = staging.flush(); + + let mut writes = queue.pending_writes.lock(); + let encoder = writes.activate(); + unsafe { + encoder.copy_buffer_to_buffer( + staging_buf.raw(), + buffer.as_ref(), + &[hal::BufferCopy { + src_offset: 0, + dst_offset: 0, + size: NonZeroU64::new(base_data.len() as _) + .expect("Already checked size isn't zero."), + }], + ) + }; + + writes.consume(staging_buf); + } + + Ok(Self { + raw: ManuallyDrop::new(buffer), + device, + num_intersection_groups: num_intersection_groups as _, + miss_offset: padded_miss_offset, + intersection_offset: padded_intersection_offset, + }) + } +} + +impl Drop for ShaderBindingData { + fn drop(&mut self) { + // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point. + let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; + unsafe { + self.device.raw().destroy_buffer(raw); + } + } +} + +#[derive(Debug)] +pub struct RayTracingPipeline { + pub(crate) raw: ManuallyDrop>, + pub(crate) device: Arc, + pub(crate) layout: Arc, + pub(crate) _shader_modules: Vec>, + pub(crate) late_sized_buffer_groups: ArrayVec, + pub(crate) shader_binding_data: ShaderBindingData, + /// The `label` from the descriptor used to create the resource. 
+ pub(crate) label: String, + pub(crate) tracking_data: TrackingData, +} + +impl Drop for RayTracingPipeline { + fn drop(&mut self) { + resource_log!("Destroy raw {}", self.error_ident()); + // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point. + let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; + unsafe { + self.device.raw().destroy_ray_tracing_pipeline(raw); + } + } +} + +crate::impl_resource_type!(RayTracingPipeline); +crate::impl_labeled!(RayTracingPipeline); +crate::impl_parent_device!(RayTracingPipeline); +crate::impl_storage_item!(RayTracingPipeline); +crate::impl_trackable!(RayTracingPipeline); + +impl RayTracingPipeline { + pub(crate) fn raw(&self) -> &dyn hal::DynRayTracingPipeline { + self.raw.as_ref() + } + + pub fn get_bind_group_layout( + self: &Arc, + index: u32, + ) -> Result, GetBindGroupLayoutError> { + self.layout.get_bind_group_layout(index, self.into()) + } +} diff --git a/wgpu-core/src/ray_tracing.rs b/wgpu-core/src/ray_tracing.rs index 05973a917c9..ea696ac7fa0 100644 --- a/wgpu-core/src/ray_tracing.rs +++ b/wgpu-core/src/ray_tracing.rs @@ -26,7 +26,8 @@ use crate::{ id::{BlasId, BufferId, TlasId}, resource::{ Blas, BlasCompactCallback, BlasPrepareCompactResult, DestroyedResourceError, - InvalidResourceError, MissingBufferUsageError, ResourceErrorIdent, Tlas, + InvalidResourceError, MaxIntersectionIndex, MissingBufferUsageError, ResourceErrorIdent, + Tlas, }, }; @@ -180,10 +181,12 @@ pub enum BuildAccelerationStructureError { #[error("Blas {0:?} build sizes require index buffer but none was provided")] MissingIndexBuffer(ResourceErrorIdent), + #[error("Tlas {0:?} instance {1} contains an invalid custom data (more than 24bits)")] + TlasInvalidCustomData(ResourceErrorIdent, usize), #[error( - "Tlas {0:?} an associated instances contains an invalid custom index (more than 24bits)" + "Tlas {0:?} instance {1} contains an invalid query data in the `intersection_index` field (more than 24bits)" )] - 
TlasInvalidCustomIndex(ResourceErrorIdent), + TlasInvalidQueryData(ResourceErrorIdent, usize), #[error( "Tlas {0:?} has {1} active instances but only {2} are allowed as specified by the descriptor at creation" @@ -213,6 +216,12 @@ pub enum BuildAccelerationStructureError { #[error("Blas {0:?} AABB stride is invalid (must be >= {1} and a multiple of 8)")] InvalidAabbStride(ResourceErrorIdent, BufferAddress), + #[error("Tlas {0:?} instance {1} has a different Intersection")] + TlasInstancesIntersectionIndicesDiffer(ResourceErrorIdent, usize), + #[error( + "Tlas {0:?} instance {1} contains an intersection index {2} which is greater than `Limits::max_intersection_group_count` {3}" + )] + TlasInvalidIntersectionIndex(ResourceErrorIdent, usize, u32, u32), } impl WebGpuError for BuildAccelerationStructureError { @@ -241,11 +250,14 @@ impl WebGpuError for BuildAccelerationStructureError { | Self::DifferentBlasIndexFormats(..) | Self::CompactedBlas(..) | Self::MissingIndexBuffer(..) - | Self::TlasInvalidCustomIndex(..) + | Self::TlasInvalidCustomData(..) | Self::TlasInstanceCountExceeded(..) | Self::TransformMissing(..) | Self::UseTransformMissing(..) | Self::TlasDependentMissingVertexReturn(..) + | Self::TlasInstancesIntersectionIndicesDiffer(..) + | Self::TlasInvalidQueryData(..) + | Self::TlasInvalidIntersectionIndex(..) | Self::BlasGeometryKindMismatch(..) | Self::IncompatibleBlasAabbPrimitiveCount(..) | Self::UnalignedAabbPrimitiveOffset(..) 
@@ -267,15 +279,23 @@ pub enum ValidateAsActionsError { #[error("Blas {0:?} is newer than the containing Tlas {1:?}")] BlasNewerThenTlas(ResourceErrorIdent, ResourceErrorIdent), + + #[error("Tlas {0:?} has an intersection index {1:?} greater or different from {2:?}")] + TlasIntersectionInvalid( + ResourceErrorIdent, + MaxIntersectionIndex, + MaxIntersectionIndex, + ), } impl WebGpuError for ValidateAsActionsError { fn webgpu_error_type(&self) -> ErrorType { match self { Self::DestroyedResource(e) => e.webgpu_error_type(), - Self::UsedUnbuiltTlas(..) | Self::UsedUnbuiltBlas(..) | Self::BlasNewerThenTlas(..) => { - ErrorType::Validation - } + Self::UsedUnbuiltTlas(..) + | Self::UsedUnbuiltBlas(..) + | Self::BlasNewerThenTlas(..) + | Self::TlasIntersectionInvalid(..) => ErrorType::Validation, } } } @@ -324,6 +344,7 @@ pub struct TlasInstance<'a> { pub transform: &'a [f32; 12], pub custom_data: u32, pub mask: u8, + pub intersection_index: wgt::IntersectionShaderIndex, } pub struct TlasPackage<'a> { @@ -336,6 +357,7 @@ pub struct TlasPackage<'a> { pub(crate) struct TlasBuild { pub tlas: Arc, pub dependencies: Vec>, + pub max_intersection_idx: MaxIntersectionIndex, } #[derive(Debug, Clone, Default)] @@ -356,7 +378,7 @@ impl AsBuild { #[derive(Debug, Clone)] pub(crate) enum AsAction { Build(AsBuild), - UseTlas(Arc), + UseTlas(Arc, MaxIntersectionIndex), } /// Like [`BlasTriangleGeometry`], but with owned data. 
@@ -415,6 +437,7 @@ pub struct OwnedTlasInstance { pub transform: [f32; 12], pub custom_data: u32, pub mask: u8, + pub intersection_index: wgt::IntersectionShaderIndex, } pub type ArcTlasInstance = OwnedTlasInstance; diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index be5096dc052..ad671559d6b 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -2381,6 +2381,62 @@ crate::impl_parent_device!(Blas); crate::impl_storage_item!(Blas); crate::impl_trackable!(Blas); +#[derive(Debug, Copy, Clone)] +pub enum MaxIntersectionIndex { + /// No intersection indices, TLAS can be used for anything. + Unused, + /// Intersection indices are all for query data. + Query, + /// Intersection indices are all intersection function indices. Contains maximum. + Intersection(u32), +} + +impl MaxIntersectionIndex { + /// Attempt to record `new` in the max intersection index, returning `None` if `new` is of a different kind (query data vs. intersection function index) than what is already recorded + pub(crate) fn set_intersection_index( + &mut self, + new: wgt::IntersectionShaderIndex, + ) -> Option<()> { + match self { + Self::Unused => { + *self = match new { + wgt::IntersectionShaderIndex::IntersectionIndex(idx) => Self::Intersection(idx), + wgt::IntersectionShaderIndex::QueryData(_) => Self::Query, + }; + } + Self::Query => { + let wgt::IntersectionShaderIndex::QueryData(_) = new else { + return None; + }; + // don't need to assign + } + Self::Intersection(ref mut idx) => { + let wgt::IntersectionShaderIndex::IntersectionIndex(new_idx) = new else { + return None; + }; + *idx = (*idx).max(new_idx) + } + } + Some(()) + } + + /// Check that `self` is at most the max intersection index `maximum_allowed` + pub(crate) fn at_most(&self, maximum_allowed: Self) -> bool { + match *self { + // anything is allowed + Self::Unused => true, + Self::Query => matches!(maximum_allowed, Self::Query), + Self::Intersection(idx) => { + let Self::Intersection(max_idx) = maximum_allowed else { + return false; + }; + + max_idx >= idx + } + } + } +} + #[derive(Debug)] pub struct 
Tlas { pub(crate) raw: Snatchable>, @@ -2391,6 +2447,7 @@ pub struct Tlas { pub(crate) update_mode: wgt::AccelerationStructureUpdateMode, pub(crate) built_index: RwLock>, pub(crate) dependencies: RwLock>>, + pub(crate) max_intersection_index: RwLock, pub(crate) instance_buffer: ManuallyDrop>, /// The `label` from the descriptor used to create the resource. pub(crate) label: String, diff --git a/wgpu-core/src/track/mod.rs b/wgpu-core/src/track/mod.rs index 7abb8e1a64a..bf68f10e968 100644 --- a/wgpu-core/src/track/mod.rs +++ b/wgpu-core/src/track/mod.rs @@ -230,6 +230,7 @@ pub(crate) struct TrackerIndexAllocators { pub bind_groups: Arc, pub compute_pipelines: Arc, pub render_pipelines: Arc, + pub ray_tracing_pipelines: Arc, pub bundles: Arc, pub query_sets: Arc, pub blas_s: Arc, @@ -246,6 +247,7 @@ impl TrackerIndexAllocators { bind_groups: Arc::new(SharedTrackerIndexAllocator::new()), compute_pipelines: Arc::new(SharedTrackerIndexAllocator::new()), render_pipelines: Arc::new(SharedTrackerIndexAllocator::new()), + ray_tracing_pipelines: Arc::new(SharedTrackerIndexAllocator::new()), bundles: Arc::new(SharedTrackerIndexAllocator::new()), query_sets: Arc::new(SharedTrackerIndexAllocator::new()), blas_s: Arc::new(SharedTrackerIndexAllocator::new()), @@ -646,6 +648,7 @@ pub(crate) struct Tracker { pub compute_pipelines: StatelessTracker, pub render_pipelines: StatelessTracker, + pub ray_tracing_pipelines: StatelessTracker, pub bundles: StatelessTracker, pub query_sets: StatelessTracker, } @@ -664,6 +667,7 @@ impl Tracker { bind_groups: StatelessTracker::new(), compute_pipelines: StatelessTracker::new(), render_pipelines: StatelessTracker::new(), + ray_tracing_pipelines: StatelessTracker::new(), bundles: StatelessTracker::new(), query_sets: StatelessTracker::new(), } diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 44d8b0fecc0..78dffdc138a 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -2017,6 +2017,14 @@ pub enum 
ShaderStageForValidation { }, Compute, Task, + RayGeneration, + Miss, + ClosestHit { + triangle: bool, + }, + AnyHit { + triangle: bool, + }, } impl ShaderStageForValidation { @@ -2027,6 +2035,10 @@ impl ShaderStageForValidation { Self::Fragment { .. } => naga::ShaderStage::Fragment, Self::Compute => naga::ShaderStage::Compute, Self::Task => naga::ShaderStage::Task, + Self::RayGeneration => naga::ShaderStage::RayGeneration, + Self::Miss => naga::ShaderStage::Miss, + Self::AnyHit { .. } => naga::ShaderStage::AnyHit, + Self::ClosestHit { .. } => naga::ShaderStage::ClosestHit, } } @@ -2037,6 +2049,10 @@ impl ShaderStageForValidation { Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT, Self::Compute => wgt::ShaderStages::COMPUTE, Self::Task => wgt::ShaderStages::TASK, + Self::RayGeneration => wgt::ShaderStages::RAY_GENERATION, + Self::Miss => wgt::ShaderStages::MISS, + Self::AnyHit { .. } => wgt::ShaderStages::ANY_HIT, + Self::ClosestHit { .. } => wgt::ShaderStages::CLOSEST_HIT, } } } diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 31833421560..ad87655a9cc 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -1012,6 +1012,11 @@ impl super::Adapter { max_binding_array_acceleration_structure_elements_per_shader_stage: max_acceleration_structures_per_shader_stage, max_multiview_view_count, + + // not yet implemented + max_intersection_group_count: 0, + max_ray_dispatch_count: 0, + max_ray_recursion_depth: 0, }), alignments: crate::Alignments { buffer_copy_offset: wgt::BufferSize::new( @@ -1031,6 +1036,10 @@ impl super::Adapter { .unwrap(), ray_tracing_scratch_buffer_alignment: Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BYTE_ALIGNMENT, + // Not yet implemented + ray_tracing_pipeline_group_data_size: 0, + ray_tracing_pipeline_group_data_alignment: 0, + ray_tracing_pipeline_data_offset_alignment: 0, }, downlevel, cooperative_matrix_properties: Vec::new(), diff --git a/wgpu-hal/src/dx12/command.rs 
b/wgpu-hal/src/dx12/command.rs index 3d218efdf07..c4fbaf48713 100644 --- a/wgpu-hal/src/dx12/command.rs +++ b/wgpu-hal/src/dx12/command.rs @@ -1857,4 +1857,29 @@ impl crate::CommandEncoder for super::CommandEncoder { _dependencies: &[&super::AccelerationStructure], ) { } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) { + unreachable!("Ray tracing pipelines not supported") + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unreachable!("Ray tracing pipelines not supported") + } + + unsafe fn set_ray_tracing_pipeline( + &mut self, + _pipeline: &::RayTracingPipeline, + ) { + unreachable!("Ray tracing pipelines not supported") + } + + unsafe fn trace_rays( + &mut self, + _count: [u32; 3], + _ray_generation_group_data: crate::PipelineGroupData, + _miss_group_data: crate::PipelineGroupData, + _intersection_group_data: crate::PipelineGroupData, + ) { + unreachable!("Ray tracing pipelines not supported") + } } diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 359a89fc054..33dfcc9c27c 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -2208,6 +2208,29 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.sub(1); } + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor< + super::PipelineLayout, + super::ShaderModule, + super::PipelineCache, + >, + ) -> Result<::RayTracingPipeline, crate::PipelineError> { + unreachable!("ray tracing pipelines not yet implemented") + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + unreachable!("ray tracing pipelines not yet implemented") + } + + unsafe fn get_raytracing_pipeline_group_data( + &self, + _pipeline: &super::RayTracingPipeline, + _groups: core::ops::Range, + ) -> Result, crate::DeviceError> { + unimplemented!("ray tracing pipelines not yet implemented") + } + unsafe fn create_pipeline_cache( &self, _desc: 
&crate::PipelineCacheDescriptor<'_>, @@ -2591,7 +2614,7 @@ impl crate::Device for super::Device { let temp = Direct3D12::D3D12_RAYTRACING_INSTANCE_DESC { Transform: instance.transform, _bitfield1: (instance.custom_data & MAX_U24) | (u32::from(instance.mask) << 24), - _bitfield2: 0, + _bitfield2: (instance.pipeline_intersection_data_offset & MAX_U24), AccelerationStructure: instance.blas_address, }; diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index 304da8c8868..6753cfffcc1 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -478,6 +478,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline = RayTracingPipeline; type PipelineCache = PipelineCache; type AccelerationStructure = AccelerationStructure; @@ -499,6 +500,7 @@ crate::impl_dyn_resource!( PipelineLayout, QuerySet, Queue, + RayTracingPipeline, RenderPipeline, Sampler, ShaderModule, @@ -1293,6 +1295,11 @@ impl crate::DynComputePipeline for ComputePipeline {} unsafe impl Send for ComputePipeline {} unsafe impl Sync for ComputePipeline {} +#[derive(Debug)] +pub struct RayTracingPipeline {} + +impl crate::DynRayTracingPipeline for RayTracingPipeline {} + #[derive(Debug)] pub struct PipelineCache; diff --git a/wgpu-hal/src/dynamic/command.rs b/wgpu-hal/src/dynamic/command.rs index 6b4e4fdb040..aca5bbcb3e4 100644 --- a/wgpu-hal/src/dynamic/command.rs +++ b/wgpu-hal/src/dynamic/command.rs @@ -4,8 +4,9 @@ use core::ops::Range; use crate::{ AccelerationStructureBarrier, Api, Attachment, BufferBarrier, BufferBinding, BufferCopy, BufferTextureCopy, BuildAccelerationStructureDescriptor, ColorAttachment, CommandEncoder, - ComputePassDescriptor, DepthStencilAttachment, DeviceError, Label, MemoryRange, - PassTimestampWrites, Rect, RenderPassDescriptor, TextureBarrier, TextureCopy, + ComputePassDescriptor, DepthStencilAttachment, DeviceError, DynRayTracingPipeline, Label, + 
MemoryRange, PassTimestampWrites, RayTracingPassDescriptor, Rect, RenderPassDescriptor, + TextureBarrier, TextureCopy, }; use super::{ @@ -190,6 +191,19 @@ pub trait DynCommandEncoder: DynResource + core::fmt::Debug { offset: wgt::BufferAddress, ); + unsafe fn begin_ray_tracing_pass(&mut self, desc: &RayTracingPassDescriptor); + unsafe fn end_ray_tracing_pass(&mut self); + + unsafe fn trace_rays( + &mut self, + count: [u32; 3], + ray_generation_group_data: crate::PipelineGroupData, + miss_group_data: crate::PipelineGroupData, + intersection_group_data: crate::PipelineGroupData, + ); + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &dyn DynRayTracingPipeline); + unsafe fn build_acceleration_structures<'a>( &mut self, descriptors: &'a [BuildAccelerationStructureDescriptor< @@ -646,6 +660,46 @@ impl DynCommandEncoder for C { unsafe { self.set_vertex_buffer(index, binding) }; } + unsafe fn begin_ray_tracing_pass(&mut self, desc: &RayTracingPassDescriptor) { + let desc = RayTracingPassDescriptor { label: desc.label }; + unsafe { C::begin_ray_tracing_pass(self, &desc) }; + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unsafe { C::end_ray_tracing_pass(self) }; + } + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &dyn DynRayTracingPipeline) { + let pipeline = pipeline.expect_downcast_ref(); + unsafe { C::set_ray_tracing_pipeline(self, pipeline) }; + } + + unsafe fn trace_rays<'a>( + &mut self, + count: [u32; 3], + ray_generation_group_data: crate::PipelineGroupData<'a, dyn DynBuffer>, + miss_group_data: crate::PipelineGroupData<'a, dyn DynBuffer>, + intersection_group_data: crate::PipelineGroupData<'a, dyn DynBuffer>, + ) { + let downcast_group_data = + |data: crate::PipelineGroupData<'a, dyn DynBuffer>| crate::PipelineGroupData { + buffer: data.buffer.expect_downcast_ref(), + offset: data.offset, + stride: data.stride, + count: data.count, + }; + + unsafe { + C::trace_rays( + self, + count, + downcast_group_data(ray_generation_group_data), + 
downcast_group_data(miss_group_data), + downcast_group_data(intersection_group_data), + ); + } + } + unsafe fn build_acceleration_structures<'a>( &mut self, descriptors: &'a [BuildAccelerationStructureDescriptor< diff --git a/wgpu-hal/src/dynamic/device.rs b/wgpu-hal/src/dynamic/device.rs index b5ff5904520..f9b9fe682cc 100644 --- a/wgpu-hal/src/dynamic/device.rs +++ b/wgpu-hal/src/dynamic/device.rs @@ -5,16 +5,16 @@ use crate::{ BindGroupLayoutDescriptor, BufferDescriptor, BufferMapping, CommandEncoderDescriptor, ComputePipelineDescriptor, Device, DeviceError, FenceValue, GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor, - PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor, - SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor, - TextureViewDescriptor, TlasInstance, + PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RayObjectIntersectionState, + RayTracingPipelineDescriptor, RenderPipelineDescriptor, SamplerDescriptor, ShaderError, + ShaderInput, ShaderModuleDescriptor, TextureDescriptor, TextureViewDescriptor, TlasInstance, }; use super::{ DynAccelerationStructure, DynBindGroup, DynBindGroupLayout, DynBuffer, DynCommandEncoder, DynComputePipeline, DynFence, DynPipelineCache, DynPipelineLayout, DynQuerySet, DynQueue, - DynRenderPipeline, DynResource, DynResourceExt as _, DynSampler, DynShaderModule, DynTexture, - DynTextureView, + DynRayTracingPipeline, DynRenderPipeline, DynResource, DynResourceExt as _, DynSampler, + DynShaderModule, DynTexture, DynTextureView, }; pub trait DynDevice: DynResource { @@ -112,6 +112,21 @@ pub trait DynDevice: DynResource { ) -> Result, PipelineError>; unsafe fn destroy_compute_pipeline(&self, pipeline: Box); + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &RayTracingPipelineDescriptor< + dyn DynPipelineLayout, + dyn DynShaderModule, + dyn DynPipelineCache, + >, + ) -> Result, PipelineError>; + unsafe 
fn destroy_ray_tracing_pipeline(&self, pipeline: Box); + unsafe fn get_raytracing_pipeline_group_data( + &self, + pipeline: &dyn DynRayTracingPipeline, + groups: core::ops::Range, + ) -> Result, DeviceError>; + unsafe fn create_pipeline_cache( &self, desc: &PipelineCacheDescriptor<'_>, @@ -442,6 +457,55 @@ impl DynDevice for D { unsafe { D::destroy_compute_pipeline(self, pipeline.unbox()) }; } + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &RayTracingPipelineDescriptor< + dyn DynPipelineLayout, + dyn DynShaderModule, + dyn DynPipelineCache, + >, + ) -> Result, PipelineError> { + let ray_intersection: Vec<_> = desc + .intersection + .iter() + .map(|stage| RayObjectIntersectionState { + closest_hit: stage.closest_hit.clone().expect_downcast(), + any_hit: stage + .any_hit + .as_ref() + .map(|stage| stage.clone().expect_downcast()), + }) + .collect(); + + let desc = RayTracingPipelineDescriptor { + label: desc.label, + layout: desc.layout.expect_downcast_ref(), + ray_generation: desc.ray_generation.clone().expect_downcast(), + miss: desc.miss.clone().expect_downcast(), + intersection: &ray_intersection, + max_recursion_depth: desc.max_recursion_depth, + cache: desc.cache.as_ref().map(|c| c.expect_downcast_ref()), + }; + + unsafe { D::create_ray_tracing_pipeline(self, &desc) } + .map(|b| -> Box { Box::new(b) }) + } + + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: Box) { + unsafe { + D::destroy_ray_tracing_pipeline(self, pipeline.unbox()); + }; + } + unsafe fn get_raytracing_pipeline_group_data( + &self, + pipeline: &dyn DynRayTracingPipeline, + groups: core::ops::Range, + ) -> Result, DeviceError> { + unsafe { + D::get_raytracing_pipeline_group_data(self, pipeline.expect_downcast_ref(), groups) + } + } + unsafe fn create_pipeline_cache( &self, desc: &PipelineCacheDescriptor<'_>, diff --git a/wgpu-hal/src/dynamic/mod.rs b/wgpu-hal/src/dynamic/mod.rs index 85d8ca00450..8284300df92 100644 --- a/wgpu-hal/src/dynamic/mod.rs +++ 
b/wgpu-hal/src/dynamic/mod.rs @@ -116,6 +116,7 @@ pub trait DynPipelineCache: DynResource + fmt::Debug {} pub trait DynPipelineLayout: DynResource + fmt::Debug {} pub trait DynQuerySet: DynResource + fmt::Debug {} pub trait DynRenderPipeline: DynResource + fmt::Debug {} +pub trait DynRayTracingPipeline: DynResource + fmt::Debug {} pub trait DynSampler: DynResource + fmt::Debug {} pub trait DynShaderModule: DynResource + fmt::Debug {} pub trait DynSurfaceTexture: diff --git a/wgpu-hal/src/gles/adapter.rs b/wgpu-hal/src/gles/adapter.rs index d8f14d42e47..53aa43b27c2 100644 --- a/wgpu-hal/src/gles/adapter.rs +++ b/wgpu-hal/src/gles/adapter.rs @@ -846,6 +846,10 @@ impl super::Adapter { max_acceleration_structures_per_shader_stage: 0, max_multiview_view_count: 0, + + max_intersection_group_count: 0, + max_ray_dispatch_count: 0, + max_ray_recursion_depth: 0, }); let mut workarounds = super::Workarounds::empty(); @@ -920,6 +924,9 @@ impl super::Adapter { uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(), raw_tlas_instance_size: 0, ray_tracing_scratch_buffer_alignment: 0, + ray_tracing_pipeline_group_data_size: 0, + ray_tracing_pipeline_group_data_alignment: 0, + ray_tracing_pipeline_data_offset_alignment: 0, }, cooperative_matrix_properties: Vec::new(), }, diff --git a/wgpu-hal/src/gles/command.rs b/wgpu-hal/src/gles/command.rs index 409481b682a..427551cec85 100644 --- a/wgpu-hal/src/gles/command.rs +++ b/wgpu-hal/src/gles/command.rs @@ -1317,4 +1317,29 @@ impl crate::CommandEncoder for super::CommandEncoder { ) { unimplemented!() } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) { + unimplemented!() + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unimplemented!() + } + + unsafe fn set_ray_tracing_pipeline( + &mut self, + _pipeline: &::RayTracingPipeline, + ) { + unimplemented!() + } + + unsafe fn trace_rays( + &mut self, + _count: [u32; 3], + _ray_generation_group_data: crate::PipelineGroupData, + 
_miss_group_data: crate::PipelineGroupData, + _intersection_group_data: crate::PipelineGroupData, + ) { + unimplemented!() + } } diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 1f63a808100..46472f05d65 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1553,6 +1553,29 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.sub(1); } + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor< + super::PipelineLayout, + super::ShaderModule, + super::PipelineCache, + >, + ) -> Result { + unimplemented!("Ray tracing is unsupported on GL") + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + unimplemented!("Ray tracing is unsupported on GL") + } + + unsafe fn get_raytracing_pipeline_group_data( + &self, + _pipeline: &super::RayTracingPipeline, + _groups: core::ops::Range, + ) -> Result, crate::DeviceError> { + unimplemented!("Ray tracing is unsupported on GL") + } + unsafe fn create_pipeline_cache( &self, _: &crate::PipelineCacheDescriptor<'_>, diff --git a/wgpu-hal/src/gles/mod.rs b/wgpu-hal/src/gles/mod.rs index 1a0292d4a0f..dd034f3b836 100644 --- a/wgpu-hal/src/gles/mod.rs +++ b/wgpu-hal/src/gles/mod.rs @@ -168,6 +168,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline = RayTracingPipeline; } crate::impl_dyn_resource!( @@ -187,6 +188,7 @@ crate::impl_dyn_resource!( QuerySet, Queue, RenderPipeline, + RayTracingPipeline, Sampler, ShaderModule, Surface, @@ -752,6 +754,11 @@ pub struct ComputePipeline { impl crate::DynComputePipeline for ComputePipeline {} +#[derive(Debug)] +pub struct RayTracingPipeline {} + +impl crate::DynRayTracingPipeline for RayTracingPipeline {} + #[cfg(send_sync)] unsafe impl Sync for ComputePipeline {} #[cfg(send_sync)] diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs 
index 51815774dec..c92293d5223 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -288,8 +288,9 @@ pub use dynamic::{ DynAccelerationStructure, DynAcquiredSurfaceTexture, DynAdapter, DynBindGroup, DynBindGroupLayout, DynBuffer, DynCommandBuffer, DynCommandEncoder, DynComputePipeline, DynDevice, DynExposedAdapter, DynFence, DynInstance, DynOpenDevice, DynPipelineCache, - DynPipelineLayout, DynQuerySet, DynQueue, DynRenderPipeline, DynResource, DynSampler, - DynShaderModule, DynSurface, DynSurfaceTexture, DynTexture, DynTextureView, + DynPipelineLayout, DynQuerySet, DynQueue, DynRayTracingPipeline, DynRenderPipeline, + DynResource, DynSampler, DynShaderModule, DynSurface, DynSurfaceTexture, DynTexture, + DynTextureView, }; #[allow(unused)] @@ -640,6 +641,7 @@ pub trait Api: Clone + fmt::Debug + Sized + WasmNotSendSync + 'static { type ShaderModule: DynShaderModule; type RenderPipeline: DynRenderPipeline; type ComputePipeline: DynComputePipeline; + type RayTracingPipeline: DynRayTracingPipeline; type PipelineCache: DynPipelineCache; type AccelerationStructure: DynAccelerationStructure + 'static; @@ -1064,6 +1066,24 @@ pub trait Device: WasmNotSendSync { ) -> Result<::ComputePipeline, PipelineError>; unsafe fn destroy_compute_pipeline(&self, pipeline: ::ComputePipeline); + #[allow(clippy::type_complexity)] + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &RayTracingPipelineDescriptor< + ::PipelineLayout, + ::ShaderModule, + ::PipelineCache, + >, + ) -> Result<::RayTracingPipeline, PipelineError>; + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: ::RayTracingPipeline); + /// Obtain the opaque data from each group, behaves as if group 0 is the closest hit, group 1 + /// is the miss shader, and group 2.. are the intersection groups. 
+ unsafe fn get_raytracing_pipeline_group_data( + &self, + pipeline: &::RayTracingPipeline, + groups: Range, + ) -> Result, DeviceError>; + unsafe fn create_pipeline_cache( &self, desc: &PipelineCacheDescriptor<'_>, @@ -1581,10 +1601,15 @@ pub trait CommandEncoder: WasmNotSendSync + fmt::Debug { /// - All prior calls to [`begin_compute_pass`] on this [`CommandEncoder`] must have been followed /// by a call to [`end_compute_pass`]. /// + /// - All prior calls to [`begin_ray_tracing_pass`] on this [`CommandEncoder`] must have been followed + /// by a call to [`end_ray_tracing_pass`]. + /// /// [`begin_render_pass`]: CommandEncoder::begin_render_pass /// [`begin_compute_pass`]: CommandEncoder::begin_compute_pass + /// [`begin_ray_tracing_pass`]: CommandEncoder::begin_ray_tracing_pass /// [`end_render_pass`]: CommandEncoder::end_render_pass /// [`end_compute_pass`]: CommandEncoder::end_compute_pass + /// [`end_ray_tracing_pass`]: CommandEncoder::end_ray_tracing_pass unsafe fn begin_render_pass( &mut self, desc: &RenderPassDescriptor<::QuerySet, ::TextureView>, @@ -1701,10 +1726,15 @@ pub trait CommandEncoder: WasmNotSendSync + fmt::Debug { /// - All prior calls to [`begin_compute_pass`] on this [`CommandEncoder`] must have been followed /// by a call to [`end_compute_pass`]. /// + /// - All prior calls to [`begin_ray_tracing_pass`] on this [`CommandEncoder`] must have been followed + /// by a call to [`end_ray_tracing_pass`]. 
+ /// /// [`begin_render_pass`]: CommandEncoder::begin_render_pass /// [`begin_compute_pass`]: CommandEncoder::begin_compute_pass + /// [`begin_ray_tracing_pass`]: CommandEncoder::begin_ray_tracing_pass /// [`end_render_pass`]: CommandEncoder::end_render_pass /// [`end_compute_pass`]: CommandEncoder::end_compute_pass + /// [`end_ray_tracing_pass`]: CommandEncoder::end_ray_tracing_pass unsafe fn begin_compute_pass( &mut self, desc: &ComputePassDescriptor<::QuerySet>, @@ -1730,6 +1760,58 @@ pub trait CommandEncoder: WasmNotSendSync + fmt::Debug { offset: wgt::BufferAddress, ); + /// Begin a new ray tracing pass, clearing all active bindings. + /// + /// This clears any bindings established by the following calls: + /// + /// - [`set_bind_group`](CommandEncoder::set_bind_group) + /// - [`set_immediates`](CommandEncoder::set_immediates) + /// - [`begin_query`](CommandEncoder::begin_query) + /// - [`set_ray_tracing_pipeline`](CommandEncoder::set_ray_tracing_pipeline) + /// + /// # Safety + /// + /// - All prior calls to [`begin_render_pass`] on this [`CommandEncoder`] must have been followed + /// by a call to [`end_render_pass`]. + /// + /// - All prior calls to [`begin_compute_pass`] on this [`CommandEncoder`] must have been followed + /// by a call to [`end_compute_pass`]. + /// + /// - All prior calls to [`begin_ray_tracing_pass`] on this [`CommandEncoder`] must have been followed + /// by a call to [`end_ray_tracing_pass`]. + /// + /// [`begin_render_pass`]: CommandEncoder::begin_render_pass + /// [`begin_compute_pass`]: CommandEncoder::begin_compute_pass + /// [`begin_ray_tracing_pass`]: CommandEncoder::begin_ray_tracing_pass + /// [`end_render_pass`]: CommandEncoder::end_render_pass + /// [`end_compute_pass`]: CommandEncoder::end_compute_pass + /// [`end_ray_tracing_pass`]: CommandEncoder::end_ray_tracing_pass + unsafe fn begin_ray_tracing_pass(&mut self, desc: &RayTracingPassDescriptor); + + /// End the current ray tracing pass. 
+ /// + /// # Safety + /// + /// - There must have been a prior call to [`begin_ray_tracing_pass`] on this [`CommandEncoder`] + /// that has not been followed by a call to [`end_ray_tracing_pass`]. + /// + /// [`begin_ray_tracing_pass`]: CommandEncoder::begin_ray_tracing_pass + /// [`end_ray_tracing_pass`]: CommandEncoder::end_ray_tracing_pass + unsafe fn end_ray_tracing_pass(&mut self); + + /// # Safety + /// + /// - Pipeline must not be destroyed + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &::RayTracingPipeline); + + unsafe fn trace_rays<'a>( + &mut self, + count: [u32; 3], + ray_generation_group_data: PipelineGroupData<'a, ::Buffer>, + miss_group_data: PipelineGroupData<'a, ::Buffer>, + intersection_group_data: PipelineGroupData<'a, ::Buffer>, + ); + /// To get the required sizes for the buffer allocations use `get_acceleration_structure_build_sizes` per descriptor /// All buffers must be synchronized externally /// All buffer regions, which are written to may only be passed once per function call, @@ -1969,6 +2051,20 @@ pub struct Alignments { /// What the scratch buffer for building an acceleration structure must be aligned to pub ray_tracing_scratch_buffer_alignment: u32, + + /// How large a single piece of group data is. That is, how large the vector returned + /// from `device.get_raytracing_pipeline_group_data(&pipeline, n..(n+1))` is. + /// + /// If ray tracing pipelines are implemented, this must be non zero. + pub ray_tracing_pipeline_group_data_size: u32, + + /// If ray tracing pipelines are implemented, this must be a power of two (and non zero). + pub ray_tracing_pipeline_group_data_alignment: u32, + + /// If ray tracing pipelines are implemented, this must be a power of two (and non zero). 
+ /// + /// The offset within `PipelineGroupData` must be a multiple of this + pub ray_tracing_pipeline_data_offset_alignment: u32, } #[derive(Clone, Debug)] @@ -2535,6 +2631,35 @@ pub struct RenderPipelineDescriptor< pub cache: Option<&'a Pc>, } +#[derive(Clone, Debug)] +pub struct RayObjectIntersectionState<'a, M: DynShaderModule + ?Sized> { + pub closest_hit: ProgrammableStage<'a, M>, + pub any_hit: Option>, +} + +/// Describes a ray tracing pipeline. +#[derive(Clone, Debug)] +pub struct RayTracingPipelineDescriptor< + 'a, + Pl: DynPipelineLayout + ?Sized, + M: DynShaderModule + ?Sized, + Pc: DynPipelineCache + ?Sized, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: &'a Pl, + /// The ray generation stage. + pub ray_generation: ProgrammableStage<'a, M>, + /// The miss stage. + pub miss: ProgrammableStage<'a, M>, + /// All the object intersection stages. + pub intersection: &'a [RayObjectIntersectionState<'a, M>], + /// The maximum recursion depth allowed for the ray tracing (ray_generation shader counts as depth 0). + pub max_recursion_depth: u32, + /// The cache which will be used and filled when compiling this pipeline + pub cache: Option<&'a Pc>, +} + #[derive(Debug, Clone)] pub struct SurfaceConfiguration { /// Maximum number of queued frames. 
Must be in @@ -2701,6 +2826,11 @@ pub struct ComputePassDescriptor<'a, Q: DynQuerySet + ?Sized> { pub timestamp_writes: Option>, } +#[derive(Clone, Debug)] +pub struct RayTracingPassDescriptor<'a> { + pub label: Label<'a>, +} + #[test] fn test_default_limits() { let limits = wgt::Limits::default(); @@ -2860,6 +2990,7 @@ pub struct TlasInstance { pub custom_data: u32, pub mask: u8, pub blas_address: u64, + pub pipeline_intersection_data_offset: u32, } #[cfg(dx12)] @@ -2881,3 +3012,10 @@ pub struct Telemetry { result: D3D12ExposeAdapterResult, ), } + +pub struct PipelineGroupData<'a, B: DynBuffer + ?Sized> { + pub buffer: &'a B, + pub offset: wgt::BufferAddress, + pub stride: u64, + pub count: u64, +} diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index fe7ed916bd2..1fbddf71e1a 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -1398,6 +1398,10 @@ impl super::CapabilitiesQuery { max_mesh_output_primitives: 256, max_mesh_output_layers: self.max_texture_layers as u32, max_mesh_multiview_view_count: 0, + // unimplemented + max_intersection_group_count: 0, + max_ray_dispatch_count: 0, + max_ray_recursion_depth: 0, }); crate::Capabilities { @@ -1414,6 +1418,10 @@ impl super::CapabilitiesQuery { >()) .unwrap(), ray_tracing_scratch_buffer_alignment: 1, + // Not yet supported + ray_tracing_pipeline_group_data_size: 0, + ray_tracing_pipeline_group_data_alignment: 0, + ray_tracing_pipeline_data_offset_alignment: 0, }, downlevel, cooperative_matrix_properties: self.cooperative_matrix_properties(), diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 9f2560df72f..67a215f37a2 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1883,6 +1883,31 @@ impl crate::CommandEncoder for super::CommandEncoder { } residency_set.commit(); } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) { + unreachable!("Ray tracing pipelines not 
supported") + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unreachable!("Ray tracing pipelines not supported") + } + + unsafe fn set_ray_tracing_pipeline( + &mut self, + _pipeline: &::RayTracingPipeline, + ) { + unreachable!("Ray tracing pipelines not supported") + } + + unsafe fn trace_rays( + &mut self, + _count: [u32; 3], + _ray_generation_group_data: crate::PipelineGroupData, + _miss_group_data: crate::PipelineGroupData, + _intersection_group_data: crate::PipelineGroupData, + ) { + unreachable!("Ray tracing pipelines not supported") + } } impl Drop for super::CommandEncoder { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index f13bb24c3b1..525eefbabf3 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1770,6 +1770,29 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.sub(1); } + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor< + super::PipelineLayout, + super::ShaderModule, + super::PipelineCache, + >, + ) -> Result { + unimplemented!("Ray tracing pipelines are unsupported on Metal") + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + unimplemented!("Ray tracing pipelines are unsupported on Metal") + } + + unsafe fn get_raytracing_pipeline_group_data( + &self, + _pipeline: &super::RayTracingPipeline, + _groups: core::ops::Range, + ) -> Result, crate::DeviceError> { + unimplemented!("Ray tracing pipelines are unsupported on Metal") + } + unsafe fn create_pipeline_cache( &self, _desc: &crate::PipelineCacheDescriptor<'_>, @@ -2019,7 +2042,7 @@ impl crate::Device for super::Device { }, options: MTLAccelerationStructureInstanceOptions::None, mask: instance.mask as u32, - intersectionFunctionTableOffset: 0, + intersectionFunctionTableOffset: instance.pipeline_intersection_data_offset, userID: instance.custom_data, accelerationStructureID: unsafe { MTLResourceID::from_raw(instance.blas_address) 
}, }; diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index c482b9a2a48..6b9e4a20910 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -87,6 +87,7 @@ impl crate::Api for Api { type PipelineLayout = PipelineLayout; type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; + type RayTracingPipeline = RayTracingPipeline; type ComputePipeline = ComputePipeline; type PipelineCache = PipelineCache; @@ -110,6 +111,7 @@ crate::impl_dyn_resource!( QuerySet, Queue, RenderPipeline, + RayTracingPipeline, Sampler, ShaderModule, Surface, @@ -1003,6 +1005,11 @@ unsafe impl Sync for ComputePipeline {} impl crate::DynComputePipeline for ComputePipeline {} +#[derive(Debug)] +pub struct RayTracingPipeline {} + +impl crate::DynRayTracingPipeline for RayTracingPipeline {} + #[derive(Debug, Clone)] pub struct QuerySet { raw_buffer: Retained>, diff --git a/wgpu-hal/src/noop/command.rs b/wgpu-hal/src/noop/command.rs index 4150a1bc6f9..16cca58d436 100644 --- a/wgpu-hal/src/noop/command.rs +++ b/wgpu-hal/src/noop/command.rs @@ -300,6 +300,25 @@ impl crate::CommandEncoder for CommandBuffer { dependencies: &[&Resource], ) { } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) {} + + unsafe fn end_ray_tracing_pass(&mut self) {} + + unsafe fn set_ray_tracing_pipeline( + &mut self, + _pipeline: &::RayTracingPipeline, + ) { + } + + unsafe fn trace_rays( + &mut self, + _count: [u32; 3], + _ray_generation_group_data: crate::PipelineGroupData, + _miss_group_data: crate::PipelineGroupData, + _intersection_group_data: crate::PipelineGroupData, + ) { + } } impl Command { diff --git a/wgpu-hal/src/noop/mod.rs b/wgpu-hal/src/noop/mod.rs index 563b1b27de9..45491701137 100644 --- a/wgpu-hal/src/noop/mod.rs +++ b/wgpu-hal/src/noop/mod.rs @@ -60,6 +60,7 @@ impl crate::Api for Api { type PipelineLayout = Resource; type ShaderModule = Resource; type RenderPipeline = Resource; + type RayTracingPipeline = Resource; 
type ComputePipeline = Resource; } @@ -76,6 +77,7 @@ impl crate::DynPipelineCache for Resource {} impl crate::DynPipelineLayout for Resource {} impl crate::DynQuerySet for Resource {} impl crate::DynRenderPipeline for Resource {} +impl crate::DynRayTracingPipeline for Resource {} impl crate::DynSampler for Resource {} impl crate::DynShaderModule for Resource {} impl crate::DynSurfaceTexture for Resource {} @@ -178,6 +180,9 @@ pub const CAPABILITIES: crate::Capabilities = { uniform_bounds_check_alignment: wgt::BufferSize::MIN, raw_tlas_instance_size: 0, ray_tracing_scratch_buffer_alignment: 1, + ray_tracing_pipeline_group_data_size: 1, + ray_tracing_pipeline_group_data_alignment: 1, + ray_tracing_pipeline_data_offset_alignment: 1, }, downlevel: wgt::DownlevelCapabilities { flags: wgt::DownlevelFlags::all(), @@ -389,6 +394,20 @@ impl crate::Device for Context { Ok(Resource) } unsafe fn destroy_compute_pipeline(&self, pipeline: Resource) {} + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor, + ) -> Result { + Ok(Resource) + } + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: Resource) {} + unsafe fn get_raytracing_pipeline_group_data( + &self, + pipeline: &Resource, + groups: core::ops::Range, + ) -> Result, crate::DeviceError> { + Ok(vec![0; groups.count()]) + } unsafe fn create_pipeline_cache( &self, desc: &crate::PipelineCacheDescriptor<'_>, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 3a8942f7b1f..2443f647c13 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -100,6 +100,9 @@ pub struct PhysicalDeviceFeatures { /// [`Instance::expose_adapter`]: super::Instance::expose_adapter ray_query: Option>, + /// Features provided by `VK_KHR_ray_tracing_pipeline`. + ray_tracing_pipeline: Option>, + /// Features provided by `VK_KHR_zero_initialize_workgroup_memory`, promoted /// to Vulkan 1.3. 
zero_initialize_workgroup_memory: @@ -192,6 +195,9 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.ray_query { info = info.push_next(feature); } + if let Some(ref mut feature) = self.ray_tracing_pipeline { + info = info.push_next(feature); + } if let Some(ref mut feature) = self.shader_atomic_int64 { info = info.push_next(feature); } @@ -474,6 +480,14 @@ impl PhysicalDeviceFeatures { } else { None }, + ray_tracing_pipeline: if enabled_extensions.contains(&khr::ray_tracing_pipeline::NAME) { + Some( + vk::PhysicalDeviceRayTracingPipelineFeaturesKHR::default() + .ray_tracing_pipeline(true), + ) + } else { + None + }, zero_initialize_workgroup_memory: if device_api_version >= vk::API_VERSION_1_3 || enabled_extensions.contains(&khr::zero_initialize_workgroup_memory::NAME) { @@ -950,6 +964,14 @@ impl PhysicalDeviceFeatures { supports_acceleration_structure_binding_array, ); + features.set( + F::EXPERIMENTAL_RAY_TRACING_PIPELINES + // Ditto. + | F::EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS, + supports_acceleration_structures + && caps.supports_extension(khr::ray_tracing_pipeline::NAME), + ); + let rg11b10ufloat_renderable = supports_format( instance, phd, @@ -1115,6 +1137,10 @@ pub struct PhysicalDeviceProperties { /// `VK_KHR_acceleration_structure` extension. acceleration_structure: Option>, + /// Additional `vk::PhysicalDevice` properties from the + /// `VK_KHR_ray_tracing_pipeline` extension. + ray_tracing_pipeline: Option>, + /// Additional `vk::PhysicalDevice` properties from the /// `VK_KHR_driver_properties` extension, promoted to Vulkan 1.2. 
driver: Option>, @@ -1320,14 +1346,26 @@ impl PhysicalDeviceProperties { extensions.push(khr::draw_indirect_count::NAME); } - // Require `VK_KHR_deferred_host_operations`, `VK_KHR_acceleration_structure` `VK_KHR_buffer_device_address` (for acceleration structures) and`VK_KHR_ray_query` if `EXPERIMENTAL_RAY_QUERY` was requested - if requested_features.contains(wgt::Features::EXPERIMENTAL_RAY_QUERY) { + // Require `VK_KHR_deferred_host_operations`, `VK_KHR_acceleration_structure` `VK_KHR_buffer_device_address` (for acceleration structures) if either `EXPERIMENTAL_RAY_QUERY` or `EXPERIMENTAL_RAY_TRACING_PIPELINES` were requested. + if requested_features.intersects( + wgt::Features::EXPERIMENTAL_RAY_QUERY + | wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES, + ) { extensions.push(khr::deferred_host_operations::NAME); extensions.push(khr::acceleration_structure::NAME); extensions.push(khr::buffer_device_address::NAME); + } + + // Require `VK_KHR_ray_query` if `EXPERIMENTAL_RAY_QUERY` was requested + if requested_features.contains(wgt::Features::EXPERIMENTAL_RAY_QUERY) { extensions.push(khr::ray_query::NAME); } + // Require `VK_KHR_ray_tracing_pipeline` if `EXPERIMENTAL_RAY_TRACING_PIPELINES` was requested + if requested_features.contains(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) { + extensions.push(khr::ray_tracing_pipeline::NAME); + } + if requested_features.contains(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN) { extensions.push(khr::ray_tracing_position_fetch::NAME) } @@ -1615,6 +1653,16 @@ impl PhysicalDeviceProperties { let max_color_attachment_bytes_per_sample = max_color_attachments * wgt::TextureFormat::MAX_TARGET_PIXEL_BYTE_COST; + let mut max_intersection_group_count = 0; + let mut max_ray_dispatch_count = 0; + let mut max_ray_recursion_depth = 0; + + if let Some(properties) = self.ray_tracing_pipeline { + max_intersection_group_count = (1 << 24) / properties.max_ray_hit_attribute_size; + max_ray_dispatch_count = 
properties.max_ray_dispatch_invocation_count; + max_ray_recursion_depth = properties.max_ray_recursion_depth; + } + let max_multiview_view_count = self .multiview .map(|a| a.max_multiview_view_count.min(32)) @@ -1715,6 +1763,10 @@ impl PhysicalDeviceProperties { max_acceleration_structures_per_shader_stage, max_multiview_view_count, + + max_intersection_group_count, + max_ray_dispatch_count, + max_ray_recursion_depth, }) } @@ -1757,6 +1809,21 @@ impl PhysicalDeviceProperties { acceleration_structure.min_acceleration_structure_scratch_offset_alignment }, ), + ray_tracing_pipeline_group_data_size: self + .ray_tracing_pipeline + .map_or(0, |ray_tracing_pipeline| { + ray_tracing_pipeline.shader_group_handle_size + }), + ray_tracing_pipeline_group_data_alignment: self + .ray_tracing_pipeline + .map_or(0, |ray_tracing_pipeline| { + ray_tracing_pipeline.shader_group_handle_alignment + }), + ray_tracing_pipeline_data_offset_alignment: self + .ray_tracing_pipeline + .map_or(0, |ray_tracing_pipeline| { + ray_tracing_pipeline.shader_group_base_alignment + }), } } } @@ -1798,6 +1865,9 @@ impl super::InstanceShared { let supports_acceleration_structure = capabilities.supports_extension(khr::acceleration_structure::NAME); + let supports_ray_tracing_pipeline = + capabilities.supports_extension(khr::ray_tracing_pipeline::NAME); + let supports_mesh_shader = capabilities.supports_extension(ext::mesh_shader::NAME); let mut properties2 = vk::PhysicalDeviceProperties2KHR::default(); @@ -1829,6 +1899,13 @@ impl super::InstanceShared { properties2 = properties2.push_next(next); } + if supports_ray_tracing_pipeline { + let next = capabilities + .ray_tracing_pipeline + .insert(vk::PhysicalDeviceRayTracingPipelinePropertiesKHR::default()); + properties2 = properties2.push_next(next); + } + if supports_driver_properties { let next = capabilities .driver @@ -2336,6 +2413,7 @@ impl super::Instance { .map(|a| a.max_multiview_instance_index) .unwrap_or(0), scratch_buffer_alignment: 
alignments.ray_tracing_scratch_buffer_alignment, + ray_tracing_pipeline_group_data_size: alignments.ray_tracing_pipeline_group_data_size, }; let capabilities = crate::Capabilities { limits: phd_capabilities.to_wgpu_limits(), @@ -2516,6 +2594,15 @@ impl super::Adapter { } else { None }; + let ray_tracing_pipeline_fns = + if enabled_extensions.contains(&khr::ray_tracing_pipeline::NAME) { + Some(khr::ray_tracing_pipeline::Device::new( + &self.instance.raw, + &raw_device, + )) + } else { + None + }; let mesh_shading_fns = if enabled_extensions.contains(&ext::mesh_shader::NAME) { Some(ext::mesh_shader::Device::new( &self.instance.raw, @@ -2651,7 +2738,8 @@ impl super::Adapter { true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS` ); flags.set( - spv::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL, + spv::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL + | spv::WriterFlags::PRINT_ON_TRACE_RAYS_FAIL, self.instance.flags.contains(wgt::InstanceFlags::DEBUG) && (self.instance.instance_api_version >= vk::API_VERSION_1_3 || enabled_extensions.contains(&khr::shader_non_semantic_info::NAME)), @@ -2659,6 +2747,9 @@ impl super::Adapter { if features.contains(wgt::Features::EXPERIMENTAL_RAY_QUERY) { capabilities.push(spv::Capability::RayQueryKHR); } + if features.contains(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) { + capabilities.push(spv::Capability::RayTracingKHR); + } if features.contains(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN) { capabilities.push(spv::Capability::RayQueryPositionFetchKHR) } @@ -2775,6 +2866,7 @@ impl super::Adapter { draw_indirect_count: indirect_count_fn, timeline_semaphore: timeline_semaphore_fn, ray_tracing: ray_tracing_fns, + ray_tracing_pipelines: ray_tracing_pipeline_fns, mesh_shading: mesh_shading_fns, external_memory_fd: external_memory_fd_fn, }, diff --git a/wgpu-hal/src/vulkan/command.rs b/wgpu-hal/src/vulkan/command.rs index 0bdc5eaedfe..77db8a88743 100644 --- a/wgpu-hal/src/vulkan/command.rs +++ 
b/wgpu-hal/src/vulkan/command.rs @@ -1359,6 +1359,94 @@ impl crate::CommandEncoder for super::CommandEncoder { } } + // ray tracing + + unsafe fn begin_ray_tracing_pass(&mut self, desc: &crate::RayTracingPassDescriptor<'_>) { + self.bind_point = vk::PipelineBindPoint::RAY_TRACING_KHR; + if let Some(label) = desc.label { + unsafe { self.begin_debug_marker(label) }; + self.rpass_debug_marker_active = true; + } + } + unsafe fn end_ray_tracing_pass(&mut self) { + if self.rpass_debug_marker_active { + unsafe { self.end_debug_marker() }; + self.rpass_debug_marker_active = false + } + } + + unsafe fn trace_rays( + &mut self, + count: [u32; 3], + ray_generation_group_data: crate::PipelineGroupData, + miss_group_data: crate::PipelineGroupData, + intersection_group_data: crate::PipelineGroupData, + ) { + let ray_tracing_functions = self + .device + .extension_fns + .ray_tracing + .as_ref() + .expect("Feature `RAY_TRACING` not enabled"); + + let ray_tracing_pipeline_functions = self + .device + .extension_fns + .ray_tracing_pipelines + .as_ref() + .expect("Feature `RAY_TRACING_PIPELINES` not enabled"); + + let get_device_address = |buffer: &super::Buffer| unsafe { + ray_tracing_functions + .buffer_device_address + .get_buffer_device_address( + &vk::BufferDeviceAddressInfo::default().buffer(buffer.raw), + ) + }; + + unsafe { + ray_tracing_pipeline_functions.cmd_trace_rays( + self.raw_handle(), + &vk::StridedDeviceAddressRegionKHR { + device_address: get_device_address(ray_generation_group_data.buffer) + + ray_generation_group_data.offset, + stride: ray_generation_group_data.stride, + size: ray_generation_group_data.stride /* no need to multiply by count: Vulkan requires the ray generation region to hold exactly one group (size == stride) */, + }, + &vk::StridedDeviceAddressRegionKHR { + device_address: get_device_address(miss_group_data.buffer) + + miss_group_data.offset, + stride: miss_group_data.stride, + size: miss_group_data.stride * miss_group_data.count, + }, + &vk::StridedDeviceAddressRegionKHR { +
device_address: get_device_address(intersection_group_data.buffer) + + intersection_group_data.offset, + stride: intersection_group_data.stride, + size: intersection_group_data.stride * intersection_group_data.count, + }, + &vk::StridedDeviceAddressRegionKHR { + device_address: 0, + stride: 0, + size: 0, + }, + count[0], + count[1], + count[2], + ) + }; + } + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &super::RayTracingPipeline) { + unsafe { + self.device.raw.cmd_bind_pipeline( + self.active, + vk::PipelineBindPoint::RAY_TRACING_KHR, + pipeline.raw, + ) + }; + } + unsafe fn copy_acceleration_structure_to_acceleration_structure( &mut self, src: &super::AccelerationStructure, diff --git a/wgpu-hal/src/vulkan/conv.rs b/wgpu-hal/src/vulkan/conv.rs index 24c3a5ba408..e47c01d5ba8 100644 --- a/wgpu-hal/src/vulkan/conv.rs +++ b/wgpu-hal/src/vulkan/conv.rs @@ -567,6 +567,10 @@ pub fn map_buffer_usage(usage: wgt::BufferUses) -> vk::BufferUsageFlags { if usage.intersects(wgt::BufferUses::ACCELERATION_STRUCTURE_QUERY) { flags |= vk::BufferUsageFlags::TRANSFER_DST; } + if usage.intersects(wgt::BufferUses::RAY_TRACING_PIPELINE_SHADER_DATA) { + flags |= vk::BufferUsageFlags::SHADER_BINDING_TABLE_KHR + | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS; + } flags } @@ -767,6 +771,18 @@ pub fn map_shader_stage(stage: wgt::ShaderStages) -> vk::ShaderStageFlags { if stage.contains(wgt::ShaderStages::MESH) { flags |= vk::ShaderStageFlags::MESH_EXT; } + if stage.contains(wgt::ShaderStages::RAY_GENERATION) { + flags |= vk::ShaderStageFlags::RAYGEN_KHR; + } + if stage.contains(wgt::ShaderStages::MISS) { + flags |= vk::ShaderStageFlags::MISS_KHR; + } + if stage.contains(wgt::ShaderStages::ANY_HIT) { + flags |= vk::ShaderStageFlags::ANY_HIT_KHR; + } + if stage.contains(wgt::ShaderStages::CLOSEST_HIT) { + flags |= vk::ShaderStageFlags::CLOSEST_HIT_KHR; + } flags } @@ -1024,6 +1040,12 @@ pub fn map_acceleration_structure_usage_to_barrier( | vk::PipelineStageFlags::COMPUTE_SHADER; 
access |= vk::AccessFlags::ACCELERATION_STRUCTURE_READ_KHR; } + if usage.contains(crate::AccelerationStructureUses::SHADER_INPUT) + && features.contains(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES) + { + stages |= vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR; + access |= vk::AccessFlags::ACCELERATION_STRUCTURE_READ_KHR; + } if usage.contains(crate::AccelerationStructureUses::COPY_SRC) { stages |= vk::PipelineStageFlags::ACCELERATION_STRUCTURE_BUILD_KHR; access |= vk::AccessFlags::ACCELERATION_STRUCTURE_READ_KHR; diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 6f7d8b3c8fa..151e97a4b50 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -2324,6 +2324,178 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.sub(1); } + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor< + super::PipelineLayout, + super::ShaderModule, + super::PipelineCache, + >, + ) -> Result { + let mut stages = Vec::new(); + let mut groups = Vec::new(); + + let compiled_ray_gen = self.compile_stage( + &desc.ray_generation, + naga::ShaderStage::RayGeneration, + &desc.layout.binding_map, + )?; + + groups.push( + vk::RayTracingShaderGroupCreateInfoKHR::default() + .closest_hit_shader(vk::SHADER_UNUSED_KHR) + .any_hit_shader(vk::SHADER_UNUSED_KHR) + .intersection_shader(vk::SHADER_UNUSED_KHR) + .general_shader(stages.len() as _) + .ty(vk::RayTracingShaderGroupTypeKHR::GENERAL), + ); + + stages.push(compiled_ray_gen.create_info); + + let compiled_miss = self.compile_stage( + &desc.miss, + naga::ShaderStage::Miss, + &desc.layout.binding_map, + )?; + + groups.push( + vk::RayTracingShaderGroupCreateInfoKHR::default() + .closest_hit_shader(vk::SHADER_UNUSED_KHR) + .any_hit_shader(vk::SHADER_UNUSED_KHR) + .intersection_shader(vk::SHADER_UNUSED_KHR) + .general_shader(stages.len() as _) + .ty(vk::RayTracingShaderGroupTypeKHR::GENERAL), + ); + + 
stages.push(compiled_miss.create_info); + + // This is to keep alive the CStrings, as the ones in the loop would be deallocated + // causing UB otherwise. + let mut compiled_stages = Vec::new(); + + for group in desc.intersection { + let compiled_closest_hits = self.compile_stage( + &group.closest_hit, + naga::ShaderStage::ClosestHit, + &desc.layout.binding_map, + )?; + + let closest_idx = stages.len(); + + stages.push(compiled_closest_hits.create_info); + + compiled_stages.push(compiled_closest_hits); + + let mut raw_hit: vk::RayTracingShaderGroupCreateInfoKHR<'_> = + vk::RayTracingShaderGroupCreateInfoKHR::default() + .closest_hit_shader(closest_idx as _) + .any_hit_shader(vk::SHADER_UNUSED_KHR) + .intersection_shader(vk::SHADER_UNUSED_KHR) + .general_shader(vk::SHADER_UNUSED_KHR) + .ty(vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP); + + if let Some(any_hit) = &group.any_hit { + let compiled_any_hit = self.compile_stage( + any_hit, + naga::ShaderStage::AnyHit, + &desc.layout.binding_map, + )?; + + let any_idx = stages.len(); + + stages.push(compiled_any_hit.create_info); + + compiled_stages.push(compiled_any_hit); + + raw_hit = raw_hit.any_hit_shader(any_idx as _); + } + + groups.push(raw_hit); + } + + let create_infos = [{ + vk::RayTracingPipelineCreateInfoKHR::default() + .layout(desc.layout.raw) + .max_pipeline_ray_recursion_depth(desc.max_recursion_depth) + .stages(&stages) + .groups(&groups) + }]; + + let pipeline_cache = desc + .cache + .map(|it| it.raw) + .unwrap_or(vk::PipelineCache::null()); + + let fns = self + .shared + .extension_fns + .ray_tracing_pipelines + .as_ref() + .unwrap(); + let pipelines = unsafe { + fns.create_ray_tracing_pipelines( + vk::DeferredOperationKHR::null(), + pipeline_cache, + &create_infos, + None, + ) + .map_err(|(_, e)| super::map_pipeline_err(e)) + }?; + + if let Some(raw_module) = compiled_ray_gen.temp_raw_module { + unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; + } + + if let Some(raw_module) 
= compiled_miss.temp_raw_module { + unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; + } + + for raw_module in compiled_stages + .into_iter() + .flat_map(|stage| stage.temp_raw_module) + { + unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; + } + + self.counters.ray_tracing_pipelines.add(1); + + Ok(super::RayTracingPipeline { raw: pipelines[0] }) + } + + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: super::RayTracingPipeline) { + unsafe { self.shared.raw.destroy_pipeline(pipeline.raw, None) }; + + self.counters.ray_tracing_pipelines.sub(1); + } + + unsafe fn get_raytracing_pipeline_group_data( + &self, + pipeline: &super::RayTracingPipeline, + groups: core::ops::Range, + ) -> Result, crate::DeviceError> { + let fns = self + .shared + .extension_fns + .ray_tracing_pipelines + .as_ref() + .unwrap(); + + let num = groups.end - groups.start; + + unsafe { + fns.get_ray_tracing_shader_group_handles( + pipeline.raw, + groups.start, + num, + (num * self + .shared + .private_caps + .ray_tracing_pipeline_group_data_size) as usize, + ) + } + .map_err(super::map_host_device_oom_err) + } + unsafe fn create_pipeline_cache( &self, desc: &crate::PipelineCacheDescriptor<'_>, @@ -2847,7 +3019,7 @@ impl crate::Device for super::Device { transform: instance.transform, custom_data_and_mask: (instance.custom_data & MAX_U24) | (u32::from(instance.mask) << 24), - shader_binding_table_record_offset_and_flags: 0, + shader_binding_table_record_offset_and_flags: (instance.pipeline_intersection_data_offset & MAX_U24), acceleration_structure_reference: instance.blas_address, }; bytemuck::bytes_of(&temp).to_vec() diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index f90ad45d4b4..71ac1e8ac24 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -92,6 +92,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline =
RayTracingPipeline; } crate::impl_dyn_resource!( @@ -111,6 +112,7 @@ crate::impl_dyn_resource!( QuerySet, Queue, RenderPipeline, + RayTracingPipeline, Sampler, ShaderModule, Surface, @@ -298,6 +300,7 @@ struct DeviceExtensionFunctions { draw_indirect_count: Option, timeline_semaphore: Option>, ray_tracing: Option, + ray_tracing_pipelines: Option, mesh_shading: Option, #[cfg_attr(not(unix), allow(dead_code))] external_memory_fd: Option, @@ -396,6 +399,11 @@ struct PrivateCapabilities { /// these usages do not have as high of an alignment requirement using the buffer as /// a scratch buffer when building acceleration structures. scratch_buffer_alignment: u32, + + /// `get_raytracing_pipeline_group_data` requires both a group count and a data size. + /// The data size parameter is just this * the group count, so we store this to not + /// require an unnecessary parameter. + ray_tracing_pipeline_group_data_size: u32, } bitflags::bitflags!( @@ -1077,6 +1085,13 @@ pub struct ComputePipeline { impl crate::DynComputePipeline for ComputePipeline {} +#[derive(Debug)] +pub struct RayTracingPipeline { + raw: vk::Pipeline, +} + +impl crate::DynRayTracingPipeline for RayTracingPipeline {} + #[derive(Debug)] pub struct PipelineCache { raw: vk::PipelineCache, diff --git a/wgpu-info/src/human.rs b/wgpu-info/src/human.rs index 1e642c49ce4..00de404f576 100644 --- a/wgpu-info/src/human.rs +++ b/wgpu-info/src/human.rs @@ -203,6 +203,10 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize max_acceleration_structures_per_shader_stage, max_multiview_view_count, + + max_intersection_group_count, + max_ray_dispatch_count, + max_ray_recursion_depth, } = limits; writeln!(output, "\t\t Max Texture Dimension 1d: {max_texture_dimension_1d}")?; writeln!(output, "\t\t Max Texture Dimension 2d: {max_texture_dimension_2d}")?; @@ -261,6 +265,9 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize writeln!(output, "\t\t Max Acceleration 
Structures Per Shader Stage: {max_acceleration_structures_per_shader_stage}")?; writeln!(output, "\t\t Max Multiview View Count: {max_multiview_view_count}")?; + writeln!(output, "\t\t Max Intersection Group Count: {max_intersection_group_count}")?; + writeln!(output, "\t\t Max Ray Dispatch Count: {max_ray_dispatch_count}")?; + writeln!(output, "\t\t Max Ray Recursion Depth: {max_ray_recursion_depth}")?; // This one reflects more of a wgpu implementation limitations than a hardware limit // so don't show it here. let _ = max_non_sampler_bindings; diff --git a/wgpu-naga-bridge/src/lib.rs b/wgpu-naga-bridge/src/lib.rs index 03c77a4ed78..39d14ed88a3 100644 --- a/wgpu-naga-bridge/src/lib.rs +++ b/wgpu-naga-bridge/src/lib.rs @@ -175,6 +175,10 @@ pub fn features_to_naga_capabilities( Caps::MEMORY_DECORATION_VOLATILE, features.contains(wgt::Features::MEMORY_DECORATION_VOLATILE), ); + caps.set( + Caps::RAY_TRACING_PIPELINE, + features.intersects(wgt::Features::EXPERIMENTAL_RAY_TRACING_PIPELINES), + ); caps } diff --git a/wgpu-types/src/buffer.rs b/wgpu-types/src/buffer.rs index 46d4c044c0e..b1d856f720f 100644 --- a/wgpu-types/src/buffer.rs +++ b/wgpu-types/src/buffer.rs @@ -138,10 +138,12 @@ bitflags::bitflags! { const TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT = 1 << 13; /// A buffer used to store the compacted size of an acceleration structure const ACCELERATION_STRUCTURE_QUERY = 1 << 14; + /// Buffer used for storing opaque shader data + const RAY_TRACING_PIPELINE_SHADER_DATA = 1 << 15; /// The combination of states that a buffer may be in _at the same time_. 
const INCLUSIVE = Self::MAP_READ.bits() | Self::COPY_SRC.bits() | Self::INDEX.bits() | Self::VERTEX.bits() | Self::UNIFORM.bits() | - Self::STORAGE_READ_ONLY.bits() | Self::INDIRECT.bits() | Self::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits() | Self::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits(); + Self::STORAGE_READ_ONLY.bits() | Self::INDIRECT.bits() | Self::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits() | Self::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits() | Self::RAY_TRACING_PIPELINE_SHADER_DATA.bits(); /// The combination of states that a buffer must exclusively be in. const EXCLUSIVE = Self::MAP_WRITE.bits() | Self::COPY_DST.bits() | Self::STORAGE_READ_WRITE.bits() | Self::ACCELERATION_STRUCTURE_SCRATCH.bits(); } diff --git a/wgpu-types/src/counters.rs b/wgpu-types/src/counters.rs index 147a96c45ee..8302a8e22b4 100644 --- a/wgpu-types/src/counters.rs +++ b/wgpu-types/src/counters.rs @@ -114,6 +114,7 @@ pub struct HalCounters { pub bind_group_layouts: InternalCounter, pub render_pipelines: InternalCounter, pub compute_pipelines: InternalCounter, + pub ray_tracing_pipelines: InternalCounter, pub pipeline_layouts: InternalCounter, pub samplers: InternalCounter, pub command_encoders: InternalCounter, diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index 1db5f5d0488..0c8fc737900 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1470,6 +1470,10 @@ bitflags_array! { #[name("wgpu-memory-decoration-volatile")] const MEMORY_DECORATION_VOLATILE = 1 << 62; + /// Allows for constructing ray tracing pipelines. + #[name("wgpu-ray-tracing-pipelines")] + const EXPERIMENTAL_RAY_TRACING_PIPELINES = 1 << 24; + // Adding a new feature? All bits in the first u64 are used. Use the second u64 (bits 64+). 
} @@ -1842,7 +1846,8 @@ impl Features { | FeaturesWGPU::EXPERIMENTAL_MESH_SHADER_POINTS.bits() | FeaturesWGPU::EXPERIMENTAL_RAY_QUERY.bits() | FeaturesWGPU::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN.bits() - | FeaturesWGPU::EXPERIMENTAL_COOPERATIVE_MATRIX.bits(), + | FeaturesWGPU::EXPERIMENTAL_COOPERATIVE_MATRIX.bits() + | FeaturesWGPU::EXPERIMENTAL_RAY_TRACING_PIPELINES.bits(), FeaturesWebGPU::empty().bits(), ])) } @@ -1851,7 +1856,8 @@ impl Features { #[must_use] pub fn allowed_vertex_formats_for_blas(&self) -> Vec { let mut formats = Vec::new(); - if self.intersects(Self::EXPERIMENTAL_RAY_QUERY) { + if self.intersects(Self::EXPERIMENTAL_RAY_QUERY | Self::EXPERIMENTAL_RAY_TRACING_PIPELINES) + { formats.push(VertexFormat::Float32x3); } if self.contains(Self::EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS) { diff --git a/wgpu-types/src/limits.rs b/wgpu-types/src/limits.rs index 7a994571f24..406674a351c 100644 --- a/wgpu-types/src/limits.rs +++ b/wgpu-types/src/limits.rs @@ -89,6 +89,10 @@ macro_rules! with_limits { $macro_name!(max_acceleration_structures_per_shader_stage, Ordering::Less); $macro_name!(max_multiview_view_count, Ordering::Less); + + $macro_name!(max_intersection_group_count, Ordering::Less); + $macro_name!(max_ray_dispatch_count, Ordering::Less); + $macro_name!(max_ray_recursion_depth, Ordering::Less); }; } @@ -316,6 +320,21 @@ pub struct Limits { /// The maximum number of views that can be used in multiview rendering pub max_multiview_view_count: u32, + + /// The maximum number of intersection groups in a ray tracing pipeline. + /// When a tlas instance is built with [`crate::IntersectionShaderIndex::IntersectionIndex`], + /// the value in there must be limited to this limit minus one. Requesting + /// more than 0 during device creation only makes sense if [`Features::EXPERIMENTAL_RAY_TRACING_PIPELINES`] + /// is enabled. 
+ pub max_intersection_group_count: u32, + /// The maximum total number (`x*y*z`) of rays able to be dispatched by a trace rays call in a ray + /// tracing pass. Requesting more than 0 during device creation only makes sense if [`Features::EXPERIMENTAL_RAY_TRACING_PIPELINES`] + /// is enabled. + pub max_ray_dispatch_count: u32, + /// The maximum number that one can pass into a ray tracing pipeline creation to be the maximum ray + /// recursion depth. (the maximum of the max ray recursion depth) Requesting more than 0 during device + /// creation only makes sense if [`Features::EXPERIMENTAL_RAY_TRACING_PIPELINES`] is enabled. + pub max_ray_recursion_depth: u32, } impl Default for Limits { @@ -386,6 +405,9 @@ impl Limits { /// max_tlas_instance_count: 0, /// max_acceleration_structures_per_shader_stage: 0, /// max_multiview_view_count: 0, + /// max_intersection_group_count: 0, + /// max_ray_dispatch_count: 0, + /// max_ray_recursion_depth: 0, /// }); /// ``` /// @@ -451,6 +473,10 @@ impl Limits { max_acceleration_structures_per_shader_stage: 0, max_multiview_view_count: 0, + + max_intersection_group_count: 0, + max_ray_dispatch_count: 0, + max_ray_recursion_depth: 0, } } @@ -517,6 +543,10 @@ impl Limits { /// max_acceleration_structures_per_shader_stage: 0, /// /// max_multiview_view_count: 0, + /// + /// max_intersection_group_count: 0, + /// max_ray_dispatch_count: 0, + /// max_ray_recursion_depth: 0, /// }); /// ``` #[must_use] @@ -599,6 +629,10 @@ impl Limits { /// max_acceleration_structures_per_shader_stage: 0, /// /// max_multiview_view_count: 0, + /// + /// max_intersection_group_count: 0, + /// max_ray_dispatch_count: 0, + /// max_ray_recursion_depth: 0, /// }); /// ``` #[must_use] @@ -697,6 +731,9 @@ impl Limits { max_acceleration_structures_per_shader_stage: ALLOC_MAX_U32, max_multiview_view_count: ALLOC_MAX_U32, + max_intersection_group_count: ALLOC_MAX_U32, + max_ray_dispatch_count: ALLOC_MAX_U32, + max_ray_recursion_depth: ALLOC_MAX_U32, } } @@ -754,6 
+791,18 @@ impl Limits { } } + /// The minimum guaranteed limits for ray tracing pipelines if you enable [`Features::EXPERIMENTAL_RAY_TRACING_PIPELINES`]. + /// These may change in the future (including downwards). + #[must_use] + pub const fn using_minimum_supported_ray_tracing_pipeline_values(self) -> Self { + Self { + max_intersection_group_count: 524288, // Vulkan has an exact size of each intersection group being 32, (2 ^ 24 - intersection bytes) / 32 = 524288 + max_ray_dispatch_count: 1 << 30, + max_ray_recursion_depth: 1, + ..self + } + } + /// The recommended minimum limits for mesh shaders if you enable [`Features::EXPERIMENTAL_MESH_SHADER`] /// /// These are chosen somewhat arbitrarily. They are small enough that they should cover all physical devices, diff --git a/wgpu-types/src/ray_tracing.rs b/wgpu-types/src/ray_tracing.rs index 0ccf23339a0..cc3c3bf1d71 100644 --- a/wgpu-types/src/ray_tracing.rs +++ b/wgpu-types/src/ray_tracing.rs @@ -195,3 +195,13 @@ pub const TRANSFORM_BUFFER_ALIGNMENT: crate::BufferAddress = 16; /// Alignment requirement for instance buffers used in acceleration structure builds (`build_acceleration_structures_unsafe_tlas`) pub const INSTANCE_BUFFER_ALIGNMENT: crate::BufferAddress = 16; + +/// An option of either an index into a ray tracing pipeline or some data for ray queries. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IntersectionShaderIndex { + /// An index into the intersection groups in a shader. + IntersectionIndex(u32), + /// Data returned to ray queries.
+ QueryData(u32), +} diff --git a/wgpu/src/api/blas.rs b/wgpu/src/api/blas.rs index 358902b06b3..4ff15bd0500 100644 --- a/wgpu/src/api/blas.rs +++ b/wgpu/src/api/blas.rs @@ -8,6 +8,8 @@ use wgt::{WasmNotSend, WasmNotSendSync}; use crate::dispatch; use crate::{Buffer, Label}; +pub use wgpu_types::IntersectionShaderIndex; + /// Descriptor for the size defining attributes of a triangle geometry, for a bottom level acceleration structure. pub type BlasTriangleGeometrySizeDescriptor = wgt::BlasTriangleGeometrySizeDescriptor; static_assertions::assert_impl_all!(BlasTriangleGeometrySizeDescriptor: Send, Sync); @@ -65,6 +67,9 @@ pub struct TlasInstance { /// Mask for the instance used inside the shader to filter instances. /// Reports hit only if `(shader_cull_mask & tlas_instance.mask) != 0u`. pub mask: u8, + /// Intersection group index into a ray tracing pipeline. Must be [`wgt::IntersectionShaderIndex::IntersectionIndex`] if used in a ray tracing pipeline trace, + /// and [`wgt::IntersectionShaderIndex::QueryData`] if used for ray queries. (Note: this means that all instances in any one TLAS must use the same variant.) + pub intersection_index: wgt::IntersectionShaderIndex, } impl TlasInstance { @@ -73,18 +78,26 @@ impl TlasInstance { /// - transform: Transform buffer offset in bytes (optional, required if transform buffer is present) /// - custom_data: Custom index for the instance used inside the shader (max 24 bits) /// - mask: Mask for the instance used inside the shader to filter instances + /// - intersection_index: Either an index into an intersection group in a ray tracing pipeline or data for ray queries /// /// Note: while one of these contains a reference to a BLAS that BLAS will not be dropped, /// but it can still be destroyed. Destroying a BLAS that is referenced by one or more /// TlasInstance(s) will immediately make them invalid.
If one or more of those invalid /// TlasInstances is inside a TlasPackage that is attempted to be built, the build will /// generate a validation error. - pub fn new(blas: &Blas, transform: [f32; 12], custom_data: u32, mask: u8) -> Self { + pub fn new( + blas: &Blas, + transform: [f32; 12], + custom_data: u32, + mask: u8, + intersection_index: wgt::IntersectionShaderIndex, + ) -> Self { Self { blas: blas.inner.clone(), transform, custom_data, mask, + intersection_index, } } diff --git a/wgpu/src/api/command_encoder.rs b/wgpu/src/api/command_encoder.rs index 77648d40f27..9f66dfc8da6 100644 --- a/wgpu/src/api/command_encoder.rs +++ b/wgpu/src/api/command_encoder.rs @@ -107,6 +107,28 @@ impl CommandEncoder { } } + /// Begins recording of a ray tracing pass. + /// + /// This function returns a [`RayTracingPass`] object which records a single ray tracing pass. + /// + /// As long as the returned [`RayTracingPass`] has not ended, + /// any mutating operation on this command encoder causes an error and invalidates it. + /// Note that the `'encoder` lifetime relationship protects against this, + /// but it is possible to opt out of it by calling [`RayTracingPass::forget_lifetime`]. + /// This can be useful for runtime handling of the encoder->pass + /// dependency e.g. when pass and encoder are stored in the same data structure. + pub fn begin_ray_tracing_pass<'encoder>( + &'encoder mut self, + desc: &RayTracingPassDescriptor<'_>, + ) -> RayTracingPass<'encoder> { + let rtpass = self.inner.begin_ray_tracing_pass(desc); + RayTracingPass { + inner: rtpass, + actions: Arc::clone(&self.actions), + _encoder_guard: api::PhantomDrop::default(), + } + } + /// Copy data from one buffer to another. /// /// # Panics diff --git a/wgpu/src/api/compute_pipeline.rs b/wgpu/src/api/compute_pipeline.rs index 499f967b540..b3c8c2ad536 100644 --- a/wgpu/src/api/compute_pipeline.rs +++ b/wgpu/src/api/compute_pipeline.rs @@ -75,6 +75,7 @@ pub struct ComputePipelineDescriptor<'a> { /// selected.
// NOTE: keep phrasing in sync. with `FragmentState::entry_point` // NOTE: keep phrasing in sync. with `VertexState::entry_point` + // NOTE: keep phrasing in sync. with `RayTracingStage::entry_point` pub entry_point: Option<&'a str>, /// Advanced options for when this pipeline is compiled /// diff --git a/wgpu/src/api/device.rs b/wgpu/src/api/device.rs index 6f533aa26b8..91f90903833 100644 --- a/wgpu/src/api/device.rs +++ b/wgpu/src/api/device.rs @@ -4,6 +4,7 @@ use core::ops::Deref; use core::{error, fmt, future::Future, marker::PhantomData}; use crate::api::blas::{Blas, BlasGeometrySizeDescriptors, CreateBlasDescriptor}; +use crate::api::ray_tracing_pipeline::{RayTracingPipeline, RayTracingPipelineDescriptor}; use crate::api::tlas::{CreateTlasDescriptor, Tlas}; use crate::util::Mutex; use crate::*; @@ -272,6 +273,16 @@ impl Device { ComputePipeline { inner: pipeline } } + /// Creates a [`RayTracingPipeline`]. + #[must_use] + pub fn create_ray_tracing_pipeline( + &self, + desc: &RayTracingPipelineDescriptor<'_>, + ) -> RayTracingPipeline { + let pipeline = self.inner.create_ray_tracing_pipeline(desc); + RayTracingPipeline { inner: pipeline } + } + /// Creates a [`Buffer`]. 
#[must_use] pub fn create_buffer(&self, desc: &BufferDescriptor<'_>) -> Buffer { diff --git a/wgpu/src/api/mod.rs b/wgpu/src/api/mod.rs index b564aa75f16..50a9956b705 100644 --- a/wgpu/src/api/mod.rs +++ b/wgpu/src/api/mod.rs @@ -40,6 +40,8 @@ mod pipeline_cache; mod pipeline_layout; mod query_set; mod queue; +mod ray_tracing_pass; +mod ray_tracing_pipeline; mod render_bundle; mod render_bundle_encoder; mod render_pass; @@ -70,6 +72,8 @@ pub use pipeline_cache::*; pub use pipeline_layout::*; pub use query_set::*; pub use queue::*; +pub use ray_tracing_pass::*; +pub use ray_tracing_pipeline::*; pub use render_bundle::*; pub use render_bundle_encoder::*; pub use render_pass::*; diff --git a/wgpu/src/api/ray_tracing_pass.rs b/wgpu/src/api/ray_tracing_pass.rs new file mode 100644 index 00000000000..988b325b1c0 --- /dev/null +++ b/wgpu/src/api/ray_tracing_pass.rs @@ -0,0 +1,112 @@ +use wgt::DynamicOffset; + +use crate::{api::SharedDeferredCommandBufferActions, *}; + +/// In-progress recording of a ray tracing pass. +/// +/// It can be created with [`CommandEncoder::begin_ray_tracing_pass`]. +#[derive(Debug)] +pub struct RayTracingPass<'encoder> { + pub(crate) inner: dispatch::DispatchRayTracingPass, + + /// Shared with CommandEncoder to enqueue deferred actions from within a pass. + pub(crate) actions: SharedDeferredCommandBufferActions, + + /// This lifetime is used to protect the [`CommandEncoder`] from being used + /// while the pass is alive. This needs to be PhantomDrop to prevent the lifetime + /// from being shortened. + pub(crate) _encoder_guard: crate::api::PhantomDrop<&'encoder ()>, +} + +#[cfg(send_sync)] +static_assertions::assert_impl_all!(RayTracingPass<'_>: Send, Sync); + +crate::cmp::impl_eq_ord_hash_proxy!(RayTracingPass<'_> => .inner); + +impl RayTracingPass<'_> { + /// Drops the lifetime relationship to the parent command encoder, making usage of + /// the encoder while this pass is recorded a run-time error instead. 
+ /// + /// Attention: As long as the ray tracing pass has not been ended, any mutating operation on the parent + /// command encoder will cause a run-time error and invalidate it! + /// By default, the lifetime constraint prevents this, but it can be useful + /// to handle this at run time, such as when storing the pass and encoder in the same + /// data structure. + /// + /// This operation has no effect on pass recording. + /// It's a safe operation, since [`CommandEncoder`] is in a locked state as long as the pass is active + /// regardless of the lifetime constraint or its absence. + pub fn forget_lifetime(self) -> RayTracingPass<'static> { + RayTracingPass { + inner: self.inner, + actions: self.actions, + _encoder_guard: crate::api::PhantomDrop::default(), + } + } + + /// Sets the active ray tracing pipeline. + pub fn set_pipeline(&mut self, pipeline: &RayTracingPipeline) { + self.inner.set_pipeline(&pipeline.inner); + } + + /// Sets the active bind group for a given bind group index. The bind group layout + /// in the active pipeline when the `trace_rays()` function is called must match the layout of this bind group. + /// + /// If the bind group has dynamic offsets, provide them in the binding order. + /// These offsets have to be aligned to [`Limits::min_uniform_buffer_offset_alignment`] + /// or [`Limits::min_storage_buffer_offset_alignment`] appropriately. + pub fn set_bind_group<'a, BG>(&mut self, index: u32, bind_group: BG, offsets: &[DynamicOffset]) + where + Option<&'a BindGroup>: From, + { + let bg: Option<&BindGroup> = bind_group.into(); + let bg = bg.map(|bg| &bg.inner); + self.inner.set_bind_group(index, bg, offsets); + } + + /// Inserts debug marker. + pub fn insert_debug_marker(&mut self, label: &str) { + self.inner.insert_debug_marker(label); + } + + /// Starts recording commands and groups them into a debug marker group.
+ pub fn push_debug_group(&mut self, label: &str) { + self.inner.push_debug_group(label); + } + + /// Stops command recording and creates debug group. + pub fn pop_debug_group(&mut self) { + self.inner.pop_debug_group(); + } + + /// Dispatches rays in the current ray tracing pipeline. + pub fn trace_rays(&mut self, x: u32, y: u32, z: u32) { + self.inner.trace_rays(x, y, z); + } +} + +/// [`Features::IMMEDIATES`] must be enabled on the device in order to call these functions. +impl RayTracingPass<'_> { + /// Set immediate data for subsequent `trace_rays` calls. + /// + /// Write the bytes in `data` at offset `offset` within immediate data + /// storage. Both `offset` and the length of `data` must be + /// multiples of [`crate::IMMEDIATE_DATA_ALIGNMENT`], which is always 4. + /// + /// For example, if `offset` is `4` and `data` is eight bytes long, this + /// call will write `data` to bytes `4..12` of immediate data storage. + pub fn set_immediates(&mut self, offset: u32, data: &[u8]) { + self.inner.set_immediates(offset, data); + } +} + +/// Describes the attachments of a ray tracing pass. +/// +/// For use with [`CommandEncoder::begin_ray_tracing_pass`]. +#[derive(Clone, Default, Debug)] +pub struct RayTracingPassDescriptor<'a> { + /// Debug label of the ray tracing pass. This will show up in graphics debuggers for easy identification. + pub label: Label<'a>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(RayTracingPassDescriptor<'_>: Send, Sync); diff --git a/wgpu/src/api/ray_tracing_pipeline.rs b/wgpu/src/api/ray_tracing_pipeline.rs new file mode 100644 index 00000000000..f8915a0768e --- /dev/null +++ b/wgpu/src/api/ray_tracing_pipeline.rs @@ -0,0 +1,113 @@ +use crate::*; + +/// Handle to a ray tracing pipeline. +/// +/// A `RayTracingPipeline` object represents a ray tracing pipeline and its shader stages and +/// bindings. It can be created with [`Device::create_ray_tracing_pipeline`].
+#[derive(Debug, Clone)] +pub struct RayTracingPipeline { + pub(crate) inner: dispatch::DispatchRayTracingPipeline, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(RayTracingPipeline: Send, Sync); + +crate::cmp::impl_eq_ord_hash_proxy!(RayTracingPipeline => .inner); + +impl RayTracingPipeline { + /// Get an object representing the bind group layout at a given index. + /// + /// If this pipeline was created with a [default layout][RayTracingPipelineDescriptor::layout], then + /// bind groups created with the returned `BindGroupLayout` can only be used with this pipeline. + /// + /// This method will raise a validation error if there is no bind group layout at `index`. + pub fn get_bind_group_layout(&self, index: u32) -> BindGroupLayout { + let layout = self.inner.get_bind_group_layout(index); + BindGroupLayout { inner: layout } + } + + #[cfg(custom)] + /// Returns custom implementation of RayTracingPipeline (if custom backend and is internally T) + pub fn as_custom(&self) -> Option<&T> { + self.inner.as_custom() + } +} + +/// Describes a stage in a ray tracing pipeline +/// +/// For use in [`RayTracingPipelineDescriptor`] +#[derive(Clone, Debug)] +pub struct RayTracingStage<'a> { + /// The compiled shader module for this stage. + pub module: &'a ShaderModule, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a shader entry point with this name in `module` of the stage required. + /// Otherwise, expect exactly one entry point in `module` of the stage required, which will be + /// selected. + // NOTE: keep phrasing in sync. with `ComputePipelineDescriptor::entry_point` + // NOTE: keep phrasing in sync. with `VertexState::entry_point` + // NOTE: keep phrasing in sync. 
with `FragmentState::entry_point` + pub entry_point: Option<&'a str>, + /// Advanced options for when this pipeline is compiled + /// + /// This implements `Default`, and for most users can be set to `Default::default()` + pub compilation_options: PipelineCompilationOptions<'a>, +} + +/// Describes a group of stages to be called for an intersection in a ray tracing pipeline +/// +/// For use in [`RayTracingPipelineDescriptor`] +#[derive(Clone, Debug)] +pub enum RayTracingIntersectionDescriptor<'a> { + /// This group of shaders may only be used when + /// a BLAS with triangle geometry is intersected. + Triangle { + /// Stage to call if, after the entire intersection process is complete, a triangle within an instance bound to this + /// descriptor is the closest hit. + closest_hit: RayTracingStage<'a>, + /// Optional stage to call when a triangle within an instance bound to this descriptor is hit at any point during the + /// intersection process. + any_hit: Option>, + }, +} + +/// Describes a ray tracing pipeline. +/// +/// For use with [`Device::create_ray_tracing_pipeline`]. +#[derive(Clone, Debug)] +pub struct RayTracingPipelineDescriptor<'a> { + /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification. + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + /// + /// If this is set, then [`Device::create_ray_tracing_pipeline`] will raise a validation error if + /// the layout doesn't match what the shader module(s) expect. + /// + /// Using the same [`PipelineLayout`] for many [`RayTracingPipeline`] or [`RenderPipeline`] or [`ComputePipeline`] + /// pipelines guarantees that you don't have to rebind any resources when switching between + /// those pipelines. + /// + /// ## Default pipeline layout + /// + /// If `layout` is `None`, then the pipeline has a [default layout] created and used instead. + /// The default layout is deduced from the shader modules.
+ /// + /// You can use [`RayTracingPipeline::get_bind_group_layout`] to create bind groups for use with the + /// default layout. However, these bind groups cannot be used with any other pipelines. This is + /// convenient for simple pipelines, but using an explicit layout is recommended in most cases. + /// + /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout + pub layout: Option<&'a PipelineLayout>, + /// The ray generation stage. The shader stage invoked by command encoder trace rays. + pub ray_generation: RayTracingStage<'a>, + /// The miss stage. Called if a ray does not hit any object. + pub miss: RayTracingStage<'a>, + /// The list of intersection descriptors + pub intersection_descs: &'a [RayTracingIntersectionDescriptor<'a>], + /// The maximum depth of entry points able to be recursed into, discounting the ray generation stage. + pub max_recersion_depth: u32, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option<&'a PipelineCache>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(RayTracingPipelineDescriptor<'_>: Send, Sync); diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index b7312f9e150..9ec5adaa967 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -102,6 +102,7 @@ pub struct VertexState<'a> { /// selected. // NOTE: keep phrasing in sync. with `ComputePipelineDescriptor::entry_point` // NOTE: keep phrasing in sync. with `FragmentState::entry_point` + // NOTE: keep phrasing in sync. 
with `RayTracingStage::entry_point` pub entry_point: Option<&'a str>, /// Advanced options for when this pipeline is compiled /// diff --git a/wgpu/src/backend/custom.rs b/wgpu/src/backend/custom.rs index 4d205e0012d..3a2f735e67f 100644 --- a/wgpu/src/backend/custom.rs +++ b/wgpu/src/backend/custom.rs @@ -86,10 +86,12 @@ dyn_type!(pub ref struct DynQuerySet(dyn QuerySetInterface)); dyn_type!(pub ref struct DynPipelineLayout(dyn PipelineLayoutInterface)); dyn_type!(pub ref struct DynRenderPipeline(dyn RenderPipelineInterface)); dyn_type!(pub ref struct DynComputePipeline(dyn ComputePipelineInterface)); +dyn_type!(pub ref struct DynRayTracingPipeline(dyn RayTracingPipelineInterface)); dyn_type!(pub ref struct DynPipelineCache(dyn PipelineCacheInterface)); dyn_type!(pub mut struct DynCommandEncoder(dyn CommandEncoderInterface)); dyn_type!(pub mut struct DynComputePass(dyn ComputePassInterface)); dyn_type!(pub mut struct DynRenderPass(dyn RenderPassInterface)); +dyn_type!(pub mut struct DynRayTracingPass(dyn RayTracingPassInterface)); dyn_type!(pub mut struct DynCommandBuffer(dyn CommandBufferInterface)); dyn_type!(pub mut struct DynRenderBundleEncoder(dyn RenderBundleEncoderInterface)); dyn_type!(pub ref struct DynRenderBundle(dyn RenderBundleInterface)); diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 364158a09ac..14edd02d35b 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -877,6 +877,9 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits { .max_acceleration_structures_per_shader_stage, max_multiview_view_count: wgt::Limits::default().max_multiview_view_count, + max_intersection_group_count: wgt::Limits::default().max_intersection_group_count, + max_ray_dispatch_count: wgt::Limits::default().max_ray_dispatch_count, + max_ray_recursion_depth: wgt::Limits::default().max_ray_recursion_depth, } } @@ -1392,6 +1395,12 @@ pub struct WebComputePipeline { ident: crate::cmp::Identifier, } 
+#[derive(Debug, Clone)] +pub struct WebRayTracingPipeline { + /// Unique identifier for this RayTracingPipeline. + ident: crate::cmp::Identifier, +} // no ray tracing pipelines on web + #[derive(Debug, Clone)] pub struct WebPipelineCache { /// Unique identifier for this PipelineCache. @@ -1419,6 +1428,12 @@ pub struct WebRenderPassEncoder { ident: crate::cmp::Identifier, } +#[derive(Debug, Clone)] +pub struct WebRayTracingPassEncoder { + /// Unique identifier for this RayTracingPassEncoder. + ident: crate::cmp::Identifier, +} // no ray tracing pipelines on web + #[derive(Debug)] pub struct WebCommandBuffer { pub(crate) inner: webgpu_sys::GpuCommandBuffer, @@ -1492,10 +1507,12 @@ impl_send_sync!(WebQuerySet); impl_send_sync!(WebPipelineLayout); impl_send_sync!(WebRenderPipeline); impl_send_sync!(WebComputePipeline); +impl_send_sync!(WebRayTracingPipeline); impl_send_sync!(WebPipelineCache); impl_send_sync!(WebCommandEncoder); impl_send_sync!(WebComputePassEncoder); impl_send_sync!(WebRenderPassEncoder); +impl_send_sync!(WebRayTracingPassEncoder); impl_send_sync!(WebCommandBuffer); impl_send_sync!(WebRenderBundleEncoder); impl_send_sync!(WebRenderBundle); @@ -1522,11 +1539,13 @@ crate::cmp::impl_eq_ord_hash_proxy!(WebQuerySet => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebPipelineLayout => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebRenderPipeline => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebComputePipeline => .ident); +crate::cmp::impl_eq_ord_hash_proxy!(WebRayTracingPipeline => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebPipelineCache => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebCommandEncoder => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebComputePassEncoder => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebRenderPassEncoder => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebCommandBuffer => .ident); +crate::cmp::impl_eq_ord_hash_proxy!(WebRayTracingPassEncoder => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebRenderBundleEncoder => 
.ident); crate::cmp::impl_eq_ord_hash_proxy!(WebRenderBundle => .ident); crate::cmp::impl_eq_ord_hash_proxy!(WebSurface => .ident); @@ -2355,6 +2374,13 @@ impl dispatch::DeviceInterface for WebDevice { .into() } + fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor<'_>, + ) -> dispatch::DispatchRayTracingPipeline { + unreachable!("ray tracing is not web.") + } + unsafe fn create_pipeline_cache( &self, _desc: &crate::PipelineCacheDescriptor<'_>, @@ -3018,6 +3044,17 @@ impl Drop for WebComputePipeline { } } +impl dispatch::RayTracingPipelineInterface for WebRayTracingPipeline { + fn get_bind_group_layout(&self, _index: u32) -> dispatch::DispatchBindGroupLayout { + unreachable!("ray tracing is not web.") + } +} +impl Drop for WebRayTracingPipeline { + fn drop(&mut self) { + // no-op + } +} + impl dispatch::CommandEncoderInterface for WebCommandEncoder { fn copy_buffer_to_buffer( &self, @@ -3245,6 +3282,13 @@ impl dispatch::CommandEncoderInterface for WebCommandEncoder { .into() } + fn begin_ray_tracing_pass( + &self, + _desc: &crate::RayTracingPassDescriptor<'_>, + ) -> dispatch::DispatchRayTracingPass { + unreachable!("ray tracing is not web.") + } + fn finish(&mut self) -> dispatch::DispatchCommandBuffer { let label = self.inner.label(); let buffer = if label.is_empty() { @@ -3753,6 +3797,43 @@ impl Drop for WebRenderPassEncoder { } } +impl dispatch::RayTracingPassInterface for WebRayTracingPassEncoder { + fn set_pipeline(&mut self, _pipeline: &dispatch::DispatchRayTracingPipeline) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + fn set_bind_group( + &mut self, + _index: u32, + _bind_group: Option<&dispatch::DispatchBindGroup>, + _offsets: &[crate::DynamicOffset], + ) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + fn set_immediates(&mut self, _offset: u32, _data: &[u8]) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + + fn insert_debug_marker(&mut 
self, _label: &str) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + fn push_debug_group(&mut self, _group_label: &str) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + fn pop_debug_group(&mut self) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } + + fn trace_rays(&mut self, _x: u32, _y: u32, _z: u32) { + unreachable!("Ray tracing pipelines are unavailable on the web.") + } +} + +impl Drop for WebRayTracingPassEncoder { + fn drop(&mut self) { + // no-op + } +} + impl dispatch::CommandBufferInterface for WebCommandBuffer {} impl Drop for WebCommandBuffer { fn drop(&mut self) { diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 9bf68c12978..e0c63d3c726 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -573,6 +573,13 @@ pub struct CoreRenderPipeline { error_sink: ErrorSink, } +#[derive(Debug)] +pub struct CoreRayTracingPipeline { + pub(crate) context: ContextWgpuCore, + id: wgc::id::RayTracingPipelineId, + error_sink: ErrorSink, +} + #[derive(Debug)] pub struct CoreComputePass { pub(crate) context: ContextWgpuCore, @@ -589,6 +596,14 @@ pub struct CoreRenderPass { id: crate::cmp::Identifier, } +#[derive(Debug)] +pub struct CoreRayTracingPass { + pub(crate) context: ContextWgpuCore, + pass: wgc::command::RayTracingPass, + error_sink: ErrorSink, + id: crate::cmp::Identifier, +} + #[derive(Debug)] pub struct CoreCommandEncoder { pub(crate) context: ContextWgpuCore, @@ -759,10 +774,12 @@ crate::cmp::impl_eq_ord_hash_proxy!(CoreQuerySet => .id); crate::cmp::impl_eq_ord_hash_proxy!(CorePipelineLayout => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreRenderPipeline => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreComputePipeline => .id); +crate::cmp::impl_eq_ord_hash_proxy!(CoreRayTracingPipeline => .id); crate::cmp::impl_eq_ord_hash_proxy!(CorePipelineCache => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreCommandEncoder => .id); 
crate::cmp::impl_eq_ord_hash_proxy!(CoreComputePass => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreRenderPass => .id); +crate::cmp::impl_eq_ord_hash_proxy!(CoreRayTracingPass => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreCommandBuffer => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreRenderBundleEncoder => .id); crate::cmp::impl_eq_ord_hash_proxy!(CoreRenderBundle => .id); @@ -1571,6 +1588,78 @@ impl dispatch::DeviceInterface for CoreDevice { .into() } + fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor<'_>, + ) -> dispatch::DispatchRayTracingPipeline { + use wgc::pipeline as pipe; + + fn downcast_rt_stage<'a>( + stage: &'a crate::RayTracingStage<'a>, + ) -> pipe::ProgrammableStageDescriptor<'a> { + pipe::ProgrammableStageDescriptor { + module: stage.module.inner.as_core().id, + entry_point: stage.entry_point.map(Borrowed), + constants: stage + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(), + zero_initialize_workgroup_memory: stage + .compilation_options + .zero_initialize_workgroup_memory, + } + } + + let descriptor = pipe::RayTracingPipelineDescriptor { + label: desc.label.map(Borrowed), + layout: desc.layout.map(|pll| pll.inner.as_core().id), + ray_generation: downcast_rt_stage(&desc.ray_generation), + miss: downcast_rt_stage(&desc.miss), + intersections: desc + .intersection_descs + .iter() + .map(|intersection_desc| match intersection_desc { + crate::RayTracingIntersectionDescriptor::Triangle { + closest_hit, + any_hit, + } => pipe::RayTracingIntersectionDescriptor::Triangle { + closest_hit: downcast_rt_stage(closest_hit), + any_hit: any_hit.as_ref().map(|stage| downcast_rt_stage(stage)), + }, + }) + .collect::>(), + max_recursion_depth: desc.max_recersion_depth, + cache: desc.cache.map(|cache| cache.inner.as_core().id), + }; + + let (id, error) = + self.context + .0 + .device_create_ray_tracing_pipeline(self.id, &descriptor, None); + if let Some(cause) = 
error { + if let wgc::pipeline::CreateRayTracingPipelineError::Internal { stage, ref error } = + cause + { + log::error!("Shader translation error for stage {:?}: {}", stage, error); + log::error!("Please report it to https://github.com/gfx-rs/wgpu"); + } + self.context.handle_error( + &self.error_sink, + cause, + desc.label, + "Device::create_ray_tracing_pipeline", + ); + } + CoreRayTracingPipeline { + context: self.context.clone(), + id, + error_sink: Arc::clone(&self.error_sink), + } + .into() + } + unsafe fn create_pipeline_cache( &self, desc: &crate::PipelineCacheDescriptor<'_>, @@ -2425,6 +2514,33 @@ impl Drop for CoreRenderPipeline { } } +impl dispatch::RayTracingPipelineInterface for CoreRayTracingPipeline { + fn get_bind_group_layout(&self, index: u32) -> dispatch::DispatchBindGroupLayout { + let (id, error) = self + .context + .0 + .ray_tracing_pipeline_get_bind_group_layout(self.id, index, None); + if let Some(err) = error { + self.context.handle_error_nolabel( + &self.error_sink, + err, + "RayTracingPipeline::get_bind_group_layout", + ) + } + CoreBindGroupLayout { + context: self.context.clone(), + id, + } + .into() + } +} + +impl Drop for CoreRayTracingPipeline { + fn drop(&mut self) { + self.context.0.ray_tracing_pipeline_drop(self.id) + } +} + impl dispatch::ComputePipelineInterface for CoreComputePipeline { fn get_bind_group_layout(&self, index: u32) -> dispatch::DispatchBindGroupLayout { let (id, error) = self @@ -2657,6 +2773,35 @@ impl dispatch::CommandEncoderInterface for CoreCommandEncoder { .into() } + fn begin_ray_tracing_pass( + &self, + desc: &crate::RayTracingPassDescriptor<'_>, + ) -> dispatch::DispatchRayTracingPass { + let (pass, err) = self.context.0.command_encoder_begin_ray_tracing_pass( + self.id, + &wgc::command::RayTracingPassDescriptor { + label: desc.label.map(Borrowed), + }, + ); + + if let Some(cause) = err { + self.context.handle_error( + &self.error_sink, + cause, + desc.label, + "CommandEncoder::begin_ray_tracing_pass", + 
); + } + + CoreRayTracingPass { + context: self.context.clone(), + pass, + error_sink: self.error_sink.clone(), + id: crate::cmp::Identifier::create(), + } + .into() + } + fn finish(&mut self) -> dispatch::DispatchCommandBuffer { let descriptor = wgt::CommandBufferDescriptor::default(); let (id, opt_label_and_error) = @@ -2873,6 +3018,7 @@ impl dispatch::CommandEncoderInterface for CoreCommandEncoder { transform: &instance.transform, custom_data: instance.custom_data, mask: instance.mask, + intersection_index: instance.intersection_index, }) }); wgc::ray_tracing::TlasPackage { @@ -3781,6 +3927,135 @@ impl Drop for CoreRenderPass { } } +impl dispatch::RayTracingPassInterface for CoreRayTracingPass { + fn set_pipeline(&mut self, pipeline: &dispatch::DispatchRayTracingPipeline) { + let pipeline = pipeline.as_core(); + + if let Err(cause) = self + .context + .0 + .ray_tracing_pass_set_pipeline(&mut self.pass, pipeline.id) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::set_pipeline", + ); + } + } + + fn set_bind_group( + &mut self, + index: u32, + bind_group: Option<&dispatch::DispatchBindGroup>, + offsets: &[crate::DynamicOffset], + ) { + let bg = bind_group.map(|bg| bg.as_core().id); + + if let Err(cause) = + self.context + .0 + .ray_tracing_pass_set_bind_group(&mut self.pass, index, bg, offsets) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::set_bind_group", + ); + } + } + + fn set_immediates(&mut self, offset: u32, data: &[u8]) { + if let Err(cause) = + self.context + .0 + .ray_tracing_pass_set_immediates(&mut self.pass, offset, data) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::set_immediates", + ); + } + } + + fn insert_debug_marker(&mut self, label: &str) { + if let Err(cause) = + self.context + .0 + .ray_tracing_pass_insert_debug_marker(&mut self.pass, label, 0) + { + self.context.handle_error( + 
&self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::insert_debug_marker", + ); + } + } + + fn push_debug_group(&mut self, group_label: &str) { + if let Err(cause) = + self.context + .0 + .ray_tracing_pass_push_debug_group(&mut self.pass, group_label, 0) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::push_debug_group", + ); + } + } + + fn pop_debug_group(&mut self) { + if let Err(cause) = self + .context + .0 + .ray_tracing_pass_pop_debug_group(&mut self.pass) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::pop_debug_group", + ); + } + } + + fn trace_rays(&mut self, x: u32, y: u32, z: u32) { + if let Err(cause) = self + .context + .0 + .ray_tracing_pass_trace_rays(&mut self.pass, x, y, z) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::trace_rays", + ); + } + } +} + +impl Drop for CoreRayTracingPass { + fn drop(&mut self) { + if let Err(cause) = self.context.0.ray_tracing_pass_end(&mut self.pass) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RayTracingPass::end", + ); + } + } +} + impl dispatch::RenderBundleEncoderInterface for CoreRenderBundleEncoder { fn set_pipeline(&mut self, pipeline: &dispatch::DispatchRenderPipeline) { let pipeline = pipeline.as_core(); diff --git a/wgpu/src/dispatch.rs b/wgpu/src/dispatch.rs index 2f8031ee80b..9fa9b70e95b 100644 --- a/wgpu/src/dispatch.rs +++ b/wgpu/src/dispatch.rs @@ -176,6 +176,10 @@ pub trait DeviceInterface: CommonTraits { &self, desc: &crate::ComputePipelineDescriptor<'_>, ) -> DispatchComputePipeline; + fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor<'_>, + ) -> DispatchRayTracingPipeline; unsafe fn create_pipeline_cache( &self, desc: &crate::PipelineCacheDescriptor<'_>, @@ -309,6 +313,9 @@ pub trait RenderPipelineInterface: CommonTraits { pub trait 
ComputePipelineInterface: CommonTraits { fn get_bind_group_layout(&self, index: u32) -> DispatchBindGroupLayout; } +pub trait RayTracingPipelineInterface: CommonTraits { + fn get_bind_group_layout(&self, index: u32) -> DispatchBindGroupLayout; +} pub trait PipelineCacheInterface: CommonTraits { fn get_data(&self) -> Option>; } @@ -342,6 +349,10 @@ pub trait CommandEncoderInterface: CommonTraits { fn begin_compute_pass(&self, desc: &crate::ComputePassDescriptor<'_>) -> DispatchComputePass; fn begin_render_pass(&self, desc: &crate::RenderPassDescriptor<'_>) -> DispatchRenderPass; + fn begin_ray_tracing_pass( + &self, + desc: &crate::RayTracingPassDescriptor<'_>, + ) -> DispatchRayTracingPass; fn finish(&mut self) -> DispatchCommandBuffer; fn clear_texture( @@ -530,6 +541,22 @@ pub trait RenderPassInterface: CommonTraits + Drop { fn execute_bundles(&mut self, render_bundles: &mut dyn Iterator); } +pub trait RayTracingPassInterface: CommonTraits + Drop { + fn set_pipeline(&mut self, pipeline: &DispatchRayTracingPipeline); + fn set_bind_group( + &mut self, + index: u32, + bind_group: Option<&DispatchBindGroup>, + offsets: &[crate::DynamicOffset], + ); + fn set_immediates(&mut self, offset: u32, data: &[u8]); + + fn insert_debug_marker(&mut self, label: &str); + fn push_debug_group(&mut self, group_label: &str); + fn pop_debug_group(&mut self); + + fn trace_rays(&mut self, x: u32, y: u32, z: u32); +} pub trait RenderBundleEncoderInterface: CommonTraits { fn set_pipeline(&mut self, pipeline: &DispatchRenderPipeline); @@ -983,10 +1010,12 @@ dispatch_types! {ref type DispatchQuerySet: QuerySetInterface = CoreQuerySet, We dispatch_types! {ref type DispatchPipelineLayout: PipelineLayoutInterface = CorePipelineLayout, WebPipelineLayout, DynPipelineLayout} dispatch_types! {ref type DispatchRenderPipeline: RenderPipelineInterface = CoreRenderPipeline, WebRenderPipeline, DynRenderPipeline} dispatch_types! 
{ref type DispatchComputePipeline: ComputePipelineInterface = CoreComputePipeline, WebComputePipeline, DynComputePipeline} +dispatch_types! {ref type DispatchRayTracingPipeline: RayTracingPipelineInterface = CoreRayTracingPipeline, WebRayTracingPipeline, DynRayTracingPipeline} dispatch_types! {ref type DispatchPipelineCache: PipelineCacheInterface = CorePipelineCache, WebPipelineCache, DynPipelineCache} dispatch_types! {mut type DispatchCommandEncoder: CommandEncoderInterface = CoreCommandEncoder, WebCommandEncoder, DynCommandEncoder} dispatch_types! {mut type DispatchComputePass: ComputePassInterface = CoreComputePass, WebComputePassEncoder, DynComputePass} dispatch_types! {mut type DispatchRenderPass: RenderPassInterface = CoreRenderPass, WebRenderPassEncoder, DynRenderPass} +dispatch_types! {mut type DispatchRayTracingPass: RayTracingPassInterface = CoreRayTracingPass, WebRayTracingPassEncoder, DynRayTracingPass} dispatch_types! {mut type DispatchCommandBuffer: CommandBufferInterface = CoreCommandBuffer, WebCommandBuffer, DynCommandBuffer} dispatch_types! {mut type DispatchRenderBundleEncoder: RenderBundleEncoderInterface = CoreRenderBundleEncoder, WebRenderBundleEncoder, DynRenderBundleEncoder} dispatch_types! {ref type DispatchRenderBundle: RenderBundleInterface = CoreRenderBundle, WebRenderBundle, DynRenderBundle}