Skip to content

Commit d432113

Browse files
robtfmJMS55
authored andcommitted
add raycast special types (backport to 0.17)
1 parent f8e7d53 commit d432113

File tree

5 files changed

+134
-9
lines changed

5 files changed

+134
-9
lines changed

src/compose/mod.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ use tracing::{debug, trace};
134134

135135
use crate::{
136136
compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
137-
derive::DerivedModule,
137+
derive::{DerivedModule, RequiredSpecialTypes},
138138
redirect::Redirector,
139139
};
140140

@@ -246,6 +246,8 @@ pub struct ComposableModule {
246246
header_ir: naga::Module,
247247
// character offset of the start of the owned module string
248248
start_offset: usize,
249+
// any required special types for this module
250+
required_special_types: RequiredSpecialTypes,
249251
}
250252

251253
// data used to build a ComposableModule
@@ -1095,12 +1097,22 @@ impl Composer {
10951097
module_builder.set_shader_source(&source_ir, 0);
10961098
header_builder.set_shader_source(&source_ir, 0);
10971099

1100+
let mut special_types: HashSet<&naga::Handle<naga::Type>> = HashSet::new();
1101+
special_types.extend(source_ir.special_types.predeclared_types.values());
1102+
special_types.extend(
1103+
[
1104+
source_ir.special_types.ray_desc.as_ref(),
1105+
source_ir.special_types.ray_intersection.as_ref(),
1106+
]
1107+
.iter()
1108+
.flatten(),
1109+
);
1110+
10981111
let mut owned_types = HashSet::new();
10991112
for (h, ty) in source_ir.types.iter() {
11001113
if let Some(name) = &ty.name {
1101-
// we need to exclude autogenerated struct names, i.e. those that begin with "__"
1102-
// "__" is a reserved prefix for naga so user variables cannot use it.
1103-
if !name.contains(DECORATION_PRE) && !name.starts_with("__") {
1114+
// we exclude any special types, these are added back later
1115+
if !name.contains(DECORATION_PRE) && !special_types.contains(&h) {
11041116
let name = format!("{name}{module_decoration}");
11051117
owned_types.insert(name.clone());
11061118
// copy and rename types
@@ -1165,10 +1177,12 @@ impl Composer {
11651177
}
11661178
}
11671179

1180+
let required_special_types = module_builder.get_required_special_types();
11681181
let module_ir = module_builder.into_module_with_entrypoints();
11691182
let mut header_ir: naga::Module = header_builder.into();
11701183

1171-
if self.validate && create_headers {
1184+
// note: we cannot validate when special types are used, as writeback isn't supported
1185+
if self.validate && create_headers && required_special_types.is_empty() {
11721186
// check that identifiers haven't been renamed
11731187
#[allow(clippy::single_element_loop)]
11741188
for language in [
@@ -1202,6 +1216,7 @@ impl Composer {
12021216
module_ir,
12031217
header_ir,
12041218
start_offset,
1219+
required_special_types,
12051220
};
12061221

12071222
Ok(composable_module)
@@ -1285,6 +1300,8 @@ impl Composer {
12851300
}
12861301
}
12871302

1303+
derived.add_required_special_types(composable.required_special_types);
1304+
12881305
derived.clear_shader_source();
12891306
}
12901307

@@ -1792,11 +1809,15 @@ impl Composer {
17921809
});
17931810
}
17941811

1812+
let required_special_types = derived.get_required_special_types();
1813+
17951814
let mut naga_module = naga::Module {
17961815
entry_points,
17971816
..derived.into()
17981817
};
17991818

1819+
required_special_types.generate(&mut naga_module);
1820+
18001821
// apply overrides
18011822
if !composable.override_functions.is_empty() {
18021823
let mut redirect = Redirector::new(naga_module);

src/compose/test.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,45 @@ mod test {
12631263
output_eq!(wgsl, "tests/expected/atomics.txt");
12641264
}
12651265

1266+
#[test]
1267+
fn test_raycasts() {
1268+
let mut composer =
1269+
Composer::default().with_capabilities(naga::valid::Capabilities::RAY_QUERY);
1270+
1271+
composer
1272+
.add_composable_module(ComposableModuleDescriptor {
1273+
source: include_str!("tests/raycast/mod.wgsl"),
1274+
file_path: "tests/raycast/mod.wgsl",
1275+
..Default::default()
1276+
})
1277+
.unwrap();
1278+
1279+
let _module = composer
1280+
.make_naga_module(NagaModuleDescriptor {
1281+
source: include_str!("tests/raycast/top.wgsl"),
1282+
file_path: "tests/raycast/top.wgsl",
1283+
..Default::default()
1284+
})
1285+
.unwrap();
1286+
1287+
// writeback doesn't work for raycast structures
1288+
// at least we can test that it compiles
1289+
1290+
// let info = composer.create_validator().validate(&module).unwrap();
1291+
// let wgsl = naga::back::wgsl::write_string(
1292+
// &module,
1293+
// &info,
1294+
// naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
1295+
// )
1296+
// .unwrap();
1297+
1298+
// let mut f = std::fs::File::create("raycast.txt").unwrap();
1299+
// f.write_all(wgsl.as_bytes()).unwrap();
1300+
// drop(f);
1301+
1302+
// output_eq!(wgsl, "tests/expected/raycast.txt");
1303+
}
1304+
12661305
#[test]
12671306
#[should_panic] // Diagnostic filters not yet supported
12681307
fn test_diagnostic_filters() {

src/compose/tests/raycast/mod.wgsl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#define_import_path test_module
2+
3+
@group(0) @binding(0) var tlas: acceleration_structure;
4+
5+
const RAY_NO_CULL = 0xFFu;
6+
7+
fn ray_func() -> RayIntersection {
8+
let ray = RayDesc(0u, RAY_NO_CULL, 0.0001, 100000.0, vec3<f32>(0.0, 0.0, 0.0), vec3<f32>(1.0, 0.0, 0.0));
9+
var rq: ray_query;
10+
rayQueryInitialize(&rq, tlas, ray);
11+
rayQueryProceed(&rq);
12+
return rayQueryGetCommittedIntersection(&rq);
13+
}

src/compose/tests/raycast/top.wgsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#import test_module
2+
3+
fn main() -> f32 {
4+
let ray = test_module::ray_func();
5+
return ray.t;
6+
}

src/derive.rs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,35 @@ use naga::{
66
};
77
use std::{cell::RefCell, rc::Rc};
88

9+
#[derive(Clone, Copy, Debug, Default)]
10+
pub struct RequiredSpecialTypes {
11+
pub ray_query: bool,
12+
pub ray_intersection: bool,
13+
}
14+
15+
impl std::ops::BitOrAssign for RequiredSpecialTypes {
16+
fn bitor_assign(&mut self, rhs: Self) {
17+
self.ray_query |= rhs.ray_query;
18+
self.ray_intersection |= rhs.ray_intersection;
19+
}
20+
}
21+
22+
impl RequiredSpecialTypes {
23+
pub fn generate(&self, target: &mut naga::Module) {
24+
if self.ray_query {
25+
target.generate_ray_desc_type();
26+
}
27+
28+
if self.ray_intersection {
29+
target.generate_ray_intersection_type();
30+
}
31+
}
32+
33+
pub fn is_empty(&self) -> bool {
34+
!self.ray_intersection && !self.ray_query
35+
}
36+
}
37+
938
#[derive(Debug, Default)]
1039
pub struct DerivedModule<'a> {
1140
shader: Option<&'a Module>,
@@ -30,6 +59,7 @@ pub struct DerivedModule<'a> {
3059
globals: Arena<GlobalVariable>,
3160
functions: Arena<Function>,
3261
pipeline_overrides: Arena<Override>,
62+
required_special_types: RequiredSpecialTypes,
3363
}
3464

3565
impl<'a> DerivedModule<'a> {
@@ -397,10 +427,14 @@ impl<'a> DerivedModule<'a> {
397427
naga::RayQueryFunction::Initialize {
398428
acceleration_structure,
399429
descriptor,
400-
} => naga::RayQueryFunction::Initialize {
401-
acceleration_structure: map_expr!(acceleration_structure),
402-
descriptor: map_expr!(descriptor),
403-
},
430+
} => {
431+
// record the use of ray queries, to later add to the final module
432+
self.required_special_types.ray_query = true;
433+
naga::RayQueryFunction::Initialize {
434+
acceleration_structure: map_expr!(acceleration_structure),
435+
descriptor: map_expr!(descriptor),
436+
}
437+
}
404438
naga::RayQueryFunction::Proceed { result } => {
405439
naga::RayQueryFunction::Proceed {
406440
result: map_expr!(result),
@@ -689,6 +723,8 @@ impl<'a> DerivedModule<'a> {
689723
}
690724
Expression::RayQueryProceedResult => expr.clone(),
691725
Expression::RayQueryGetIntersection { query, committed } => {
726+
// record use of the intersection type
727+
self.required_special_types.ray_intersection = true;
692728
Expression::RayQueryGetIntersection {
693729
query: map_expr!(query),
694730
committed: *committed,
@@ -831,6 +867,16 @@ impl<'a> DerivedModule<'a> {
831867
self.import_function(func, span)
832868
}
833869

870+
/// get any required special types for this module
871+
pub fn get_required_special_types(&self) -> RequiredSpecialTypes {
872+
self.required_special_types
873+
}
874+
875+
/// add required special types for this module
876+
pub fn add_required_special_types(&mut self, types: RequiredSpecialTypes) {
877+
self.required_special_types |= types;
878+
}
879+
834880
pub fn into_module_with_entrypoints(mut self) -> naga::Module {
835881
let entry_points = self
836882
.shader

0 commit comments

Comments
 (0)