@@ -6,6 +6,35 @@ use naga::{
66} ;
77use std:: { cell:: RefCell , rc:: Rc } ;
88
9+ #[ derive( Clone , Copy , Debug , Default ) ]
10+ pub struct RequiredSpecialTypes {
11+ pub ray_query : bool ,
12+ pub ray_intersection : bool ,
13+ }
14+
15+ impl std:: ops:: BitOrAssign for RequiredSpecialTypes {
16+ fn bitor_assign ( & mut self , rhs : Self ) {
17+ self . ray_query |= rhs. ray_query ;
18+ self . ray_intersection |= rhs. ray_intersection ;
19+ }
20+ }
21+
22+ impl RequiredSpecialTypes {
23+ pub fn generate ( & self , target : & mut naga:: Module ) {
24+ if self . ray_query {
25+ target. generate_ray_desc_type ( ) ;
26+ }
27+
28+ if self . ray_intersection {
29+ target. generate_ray_intersection_type ( ) ;
30+ }
31+ }
32+
33+ pub fn is_empty ( & self ) -> bool {
34+ !self . ray_intersection && !self . ray_query
35+ }
36+ }
37+
938#[ derive( Debug , Default ) ]
1039pub struct DerivedModule < ' a > {
1140 shader : Option < & ' a Module > ,
@@ -30,6 +59,7 @@ pub struct DerivedModule<'a> {
3059 globals : Arena < GlobalVariable > ,
3160 functions : Arena < Function > ,
3261 pipeline_overrides : Arena < Override > ,
62+ required_special_types : RequiredSpecialTypes ,
3363}
3464
3565impl < ' a > DerivedModule < ' a > {
@@ -397,10 +427,14 @@ impl<'a> DerivedModule<'a> {
397427 naga:: RayQueryFunction :: Initialize {
398428 acceleration_structure,
399429 descriptor,
400- } => naga:: RayQueryFunction :: Initialize {
401- acceleration_structure : map_expr ! ( acceleration_structure) ,
402- descriptor : map_expr ! ( descriptor) ,
403- } ,
430+ } => {
431+ // record the use of ray queries, to later add to the final module
432+ self . required_special_types . ray_query = true ;
433+ naga:: RayQueryFunction :: Initialize {
434+ acceleration_structure : map_expr ! ( acceleration_structure) ,
435+ descriptor : map_expr ! ( descriptor) ,
436+ }
437+ }
404438 naga:: RayQueryFunction :: Proceed { result } => {
405439 naga:: RayQueryFunction :: Proceed {
406440 result : map_expr ! ( result) ,
@@ -689,6 +723,8 @@ impl<'a> DerivedModule<'a> {
689723 }
690724 Expression :: RayQueryProceedResult => expr. clone ( ) ,
691725 Expression :: RayQueryGetIntersection { query, committed } => {
726+ // record use of the intersection type
727+ self . required_special_types . ray_intersection = true ;
692728 Expression :: RayQueryGetIntersection {
693729 query : map_expr ! ( query) ,
694730 committed : * committed,
@@ -831,6 +867,16 @@ impl<'a> DerivedModule<'a> {
831867 self . import_function ( func, span)
832868 }
833869
870+ /// get any required special types for this module
871+ pub fn get_required_special_types ( & self ) -> RequiredSpecialTypes {
872+ self . required_special_types
873+ }
874+
875+ /// add required special types for this module
876+ pub fn add_required_special_types ( & mut self , types : RequiredSpecialTypes ) {
877+ self . required_special_types |= types;
878+ }
879+
834880 pub fn into_module_with_entrypoints ( mut self ) -> naga:: Module {
835881 let entry_points = self
836882 . shader
0 commit comments