@@ -2,6 +2,7 @@ use crate::{Diagnostic, Error};
22
33use itertools:: Itertools ;
44use wgsl_parse:: syntax:: { ModulePath , PathOrigin , TranslationUnit } ;
5+ use wgsl_types:: inst:: LiteralInstance ;
56
67use std:: {
78 borrow:: Cow ,
@@ -275,10 +276,12 @@ impl<R: Resolver, F: ResolveFn> Resolver for Preprocessor<R, F> {
275276/// This resolver is not thread-safe (not [`Send`] or [`Sync`]).
276277pub struct Router {
277278 mount_points : Vec < ( ModulePath , Box < dyn Resolver > ) > ,
278- fallback : Option < ( ModulePath , Box < dyn Resolver > ) > ,
279+ fallback : Option < Box < dyn Resolver > > ,
279280}
280281
281282/// Dispatches resolution of a module path to sub-resolvers.
283+ ///
284+ /// See documentation in [`Self::mount_resolver`]
282285impl Router {
283286 /// Create a new resolver.
284287 pub fn new ( ) -> Self {
@@ -293,32 +296,41 @@ impl Router {
293296 /// All import paths starting with `prefix` will be dispatched to the resolver with
294297 /// the suffix of the path. The prefix path must have an `Absolute` or `Package`
295298 /// origin and the suffix path will be given an `Absolute` origin.
299+ ///
300+ /// If none of the `prefix`es match, the fallback resolver will be used.
296301 pub fn mount_resolver ( & mut self , prefix : ModulePath , resolver : impl Resolver + ' static ) {
297302 self . mount_points . push ( ( prefix, Box :: new ( resolver) ) ) ;
298303 }
299304
300305 /// Mount a fallback resolver that is used when no other prefix match.
301306 pub fn mount_fallback_resolver ( & mut self , resolver : impl Resolver + ' static ) {
302- self . fallback = Some ( ( ModulePath :: new_root ( ) , Box :: new ( resolver) ) ) ;
307+ self . fallback = Some ( Box :: new ( resolver) ) ;
303308 }
304309
305310 fn route ( & self , path : & ModulePath ) -> Result < ( & dyn Resolver , ModulePath ) , ResolveError > {
306- let ( mount_path, resolver) = self
311+ if let Some ( ( mount_path, resolver) ) = self
307312 . mount_points
308313 . iter ( )
309314 . filter ( |( prefix, _) | path. starts_with ( prefix) )
310315 . max_by_key ( |( prefix, _) | prefix. components . len ( ) )
311- . or ( self . fallback . as_ref ( ) )
312- . ok_or_else ( || E :: ModuleNotFound ( path. clone ( ) , "no mount point" . to_string ( ) ) ) ?;
313-
314- let components = path
315- . components
316- . iter ( )
317- . skip ( mount_path. components . len ( ) )
318- . cloned ( )
319- . collect_vec ( ) ;
320- let suffix = ModulePath :: new ( PathOrigin :: Absolute , components) ;
321- Ok ( ( resolver, suffix) )
316+ {
317+ let components = path
318+ . components
319+ . iter ( )
320+ . skip ( mount_path. components . len ( ) )
321+ . cloned ( )
322+ . collect_vec ( ) ;
323+
324+ let suffix = ModulePath :: new ( PathOrigin :: Absolute , components) ;
325+ Ok ( ( resolver, suffix) )
326+ } else if let Some ( resolver) = & self . fallback {
327+ Ok ( ( resolver, path. clone ( ) ) )
328+ } else {
329+ Err ( E :: ModuleNotFound (
330+ path. clone ( ) ,
331+ "no mount point" . to_string ( ) ,
332+ ) )
333+ }
322334 }
323335}
324336
@@ -463,7 +475,7 @@ impl Resolver for PkgResolver {
463475pub struct StandardResolver {
464476 pkg : PkgResolver ,
465477 files : FileResolver ,
466- constants : HashMap < String , f64 > ,
478+ constants : HashMap < String , LiteralInstance > ,
467479}
468480
469481impl StandardResolver {
@@ -487,18 +499,24 @@ impl StandardResolver {
487499 ///
488500 /// Numeric constants live WESL's special package named `constants`. This package is
489501 /// *virtual*, meaning it doesn't exist on the filesystem. Constants can be accessed
490- /// by importing them: `import constants::MY_CONSTANT;`. All constants are of type
491- /// AbstractFloat, which can be implicitly converted to all scalar types.
492- pub fn add_constant ( & mut self , name : impl ToString , value : f64 ) {
502+ /// by importing them: `import constants::MY_CONSTANT;`.
503+ ///
504+ /// The type is specified by the variant of [`LiteralInstance`].\
505+ /// If specifying a constant that is used with multiple different types or
506+ /// a constant that benefits from precision, like π, use AbstractFloat,
507+ /// which can be implicitly converted to all scalar types.
508+ ///
509+ /// Note: [`LiteralInstance`] implements [`From`] for all standard numeric types
510+ pub fn add_constant ( & mut self , name : impl ToString , value : LiteralInstance ) {
493511 self . constants . insert ( name. to_string ( ) , value) ;
494512 }
495513
514+ /// Generate a module with all declared virtual constants in the resolver
496515 fn generate_constant_module ( & self ) -> String {
497516 self . constants
498517 . iter ( )
499518 . map ( |( name, value) | format ! ( "const {name} = {value};" ) )
500- . format ( "\n " )
501- . to_string ( )
519+ . join ( "\n " )
502520 }
503521}
504522
@@ -577,7 +595,7 @@ mod test {
577595 r. mount_resolver ( "package::bar" . parse ( ) . unwrap ( ) , v2) ;
578596
579597 let mut v3 = VirtualResolver :: new ( ) ;
580- v3. add_module ( "package ::bar" . parse ( ) . unwrap ( ) , "m6" . into ( ) ) ;
598+ v3. add_module ( "foo ::bar" . parse ( ) . unwrap ( ) , "m6" . into ( ) ) ;
581599 r. mount_fallback_resolver ( v3) ;
582600
583601 assert_eq ! ( r. resolve_source( & "package" . parse( ) . unwrap( ) ) . unwrap( ) , "m1" ) ;
@@ -599,4 +617,104 @@ mod test {
599617 "m6"
600618 ) ;
601619 }
620+
621+ #[ test]
622+ /// Test WGSL type casting of virtual constants
623+ fn type_virtual_constants ( ) {
624+ // standard resolver to register some constants
625+ let mut std = StandardResolver :: new ( "." ) ;
626+ // AbstractFloat
627+ std. add_constant ( "TAU" , std:: f64:: consts:: TAU . into ( ) ) ;
628+ // f32
629+ std. add_constant ( "LIGHTING_ANGLE" , 10.0f32 . into ( ) ) ;
630+ // i32
631+ std. add_constant ( "Z_ROTATION" , ( -10i32 ) . into ( ) ) ;
632+ // u32
633+ std. add_constant ( "H" , ( 12u32 ) . into ( ) ) ;
634+ // bool
635+ std. add_constant ( "BRIGHTEN" , ( false ) . into ( ) ) ;
636+
637+ // use virtual resolver for the main module
638+ let mut v = VirtualResolver :: new ( ) ;
639+ v. add_module (
640+ "package::color_math" . parse ( ) . unwrap ( ) ,
641+ // the main module imports constants::TAU and uses it in a context that requires f32,
642+ // therfor it will be cast from AbstractFloat
643+ r#"
644+ import constants::{TAU, H, BRIGHTEN};
645+
646+ fn color_sweep(h: u32) -> f32 {
647+ let color = cos(h + vec3(0.0, 1.0, 2.0) * TAU / 3.0);
648+ if (BRIGHTEN) {
649+ color += 0.1;
650+ }
651+
652+ return color;
653+ }
654+
655+ @fragment
656+ fn fragment() -> @location(0) vec4<f32> {
657+ return vec4(color_sweep(H), color_sweep(H + 0.1), color_sweep(H + 0.2), 1.0);
658+ }
659+ "#
660+ . into ( ) ,
661+ ) ;
662+
663+ // route package imports whose prefix is "constants" to the StandardResolver
664+ // and absolute module paths to the VirtualResolver.
665+ let mut r = Router :: new ( ) ;
666+ r. mount_resolver ( ModulePath :: new_root ( ) , v) ;
667+ r. mount_fallback_resolver ( std) ;
668+
669+ // compile to test imports and casting
670+ crate :: Wesl :: new ( "." )
671+ . set_custom_resolver ( r)
672+ . compile ( & "package::color_math" . parse ( ) . unwrap ( ) )
673+ . unwrap ( ) ;
674+ }
675+
676+ #[ test]
677+ /// Test resolving virtual constants from `add_constant`
678+ fn resolve_virtual_constants ( ) {
679+ // todo impl `add_constant` for VirtualResolver then use that
680+ let mut sr = StandardResolver :: new ( "." ) ;
681+
682+ // add math constants
683+ sr. add_constant ( "PI" , LiteralInstance :: from ( std:: f64:: consts:: PI ) ) ;
684+ sr. add_constant ( "E" , LiteralInstance :: from ( std:: f64:: consts:: E ) ) ;
685+ // add misc constants
686+ sr. add_constant ( "NEG_2" , LiteralInstance :: from ( -2i32 ) ) ;
687+ sr. add_constant ( "ONE" , LiteralInstance :: from ( 1u32 ) ) ;
688+ sr. add_constant ( "F32_MAX" , LiteralInstance :: from ( f32:: MAX ) ) ;
689+ sr. add_constant ( "IS_HEAVY" , LiteralInstance :: from ( false ) ) ;
690+ sr. add_constant (
691+ "NUM_CONSTS" ,
692+ LiteralInstance :: from ( sr. constants . len ( ) as i64 ) ,
693+ ) ;
694+
695+ // generate the virtual module
696+ let generated = sr. generate_constant_module ( ) ;
697+ // test that it contains the consts with correct values
698+ assert ! ( generated. contains( & format!( "const PI = {:?};" , std:: f64 :: consts:: PI ) ) ) ;
699+ assert ! ( generated. contains( & format!( "const E = {:?};" , std:: f64 :: consts:: E ) ) ) ;
700+ assert ! ( generated. contains( "const NEG_2 = -2i;" ) ) ;
701+ assert ! ( generated. contains( "const ONE = 1u;" ) ) ;
702+ assert ! ( generated. contains( & format!( "const F32_MAX = {}f;" , f32 :: MAX ) ) ) ;
703+ assert ! ( generated. contains( "const IS_HEAVY = false;" ) ) ;
704+ assert ! ( generated. contains( & format!(
705+ "const NUM_CONSTS = {};" ,
706+ ( sr. constants. len( ) as i64 ) - 1
707+ ) ) ) ;
708+
709+ // resolve the package path with the origin `constants`,
710+ // the source of which should be the same as the generated module
711+ let src_root = sr. resolve_source ( & "constants" . parse ( ) . unwrap ( ) ) . unwrap ( ) ;
712+ assert_eq ! ( src_root, generated) ;
713+
714+ // resolving a path with components should return the same
715+ let src_comp = sr
716+ . resolve_source ( & "constants::PI" . parse ( ) . unwrap ( ) )
717+ . unwrap ( ) ;
718+ assert_eq ! ( src_comp, generated) ;
719+ }
602720}
0 commit comments