Skip to content

Commit 6dd3c87

Browse files
authored
Merge pull request #183 from TimTheBig/main
Make resolver generated constants strongly typed
2 parents 0747880 + b925155 commit 6dd3c87

4 files changed

Lines changed: 151 additions & 32 deletions

File tree

crates/wesl-test/bevy/tonemapping_test_patterns.wgsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ import bevy::render::maths::PI;
88
@if(TONEMAP_IN_SHADER)
99
import bevy::core_pipeline::tonemapping::tone_mapping;
1010

11-
// Sweep across hues on y axis with value from 0.0 to +15EV across x axis
11+
// Sweep across hues on y axis with value from 0.0 to +15EV across x axis
1212
// quantized into 24 steps for both axis.
1313
fn color_sweep(uv_input: vec2<f32>) -> vec3<f32> {
1414
var uv = uv_input;
1515
let steps = 24.0;
1616
uv.y = uv.y * (1.0 + 1.0 / steps);
1717
let ratio = 2.0;
18-
18+
1919
let h = PI * 2.0 * floor(1.0 + steps * uv.y) / steps;
2020
let L = floor(uv.x * steps * ratio) / (steps * ratio) - 0.5;
21-
21+
2222
var color = vec3(0.0);
23-
if uv.y < 1.0 {
23+
if uv.y < 1.0 {
2424
color = cos(h + vec3(0.0, 1.0, 2.0) * PI * 2.0 / 3.0);
2525
let maxRGB = max(color.r, max(color.g, color.b));
2626
let minRGB = min(color.r, min(color.g, color.b));

crates/wesl-test/tests/testsuite.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,11 @@ pub fn bevy_case(path: PathBuf) -> Result<(), libtest_mimic::Failed> {
432432
compiler
433433
.add_package(&bevy_wgsl::PACKAGE)
434434
.add_constants([
435-
("MAX_CASCADES_PER_LIGHT", 10.0),
436-
("MAX_DIRECTIONAL_LIGHTS", 10.0),
437-
("PER_OBJECT_BUFFER_BATCH_SIZE", 10.0),
438-
("TONEMAPPING_LUT_TEXTURE_BINDING_INDEX", 10.0),
439-
("TONEMAPPING_LUT_SAMPLER_BINDING_INDEX", 10.0),
435+
("MAX_CASCADES_PER_LIGHT", 10u32.into()),
436+
("MAX_DIRECTIONAL_LIGHTS", 10.into()),
437+
("PER_OBJECT_BUFFER_BATCH_SIZE", 10.into()),
438+
("TONEMAPPING_LUT_TEXTURE_BINDING_INDEX", 10.into()),
439+
("TONEMAPPING_LUT_SAMPLER_BINDING_INDEX", 10.into()),
440440
])
441441
.set_options(CompileOptions {
442442
strip: false,

crates/wesl/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ use wgsl_parse::{
5757
SyntaxNode,
5858
syntax::{Ident, TranslationUnit},
5959
};
60+
use wgsl_types::inst::LiteralInstance;
6061

6162
/// Compilation options. Used in [`compile`] and [`Wesl::set_options`].
6263
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -301,7 +302,7 @@ impl Wesl<StandardResolver> {
301302
/// Add a const-declaration to the special `constants` module.
302303
///
303304
/// See [`StandardResolver::add_constant`].
304-
pub fn add_constant(&mut self, name: impl ToString, value: f64) -> &mut Self {
305+
pub fn add_constant(&mut self, name: impl ToString, value: LiteralInstance) -> &mut Self {
305306
self.resolver.add_constant(name, value);
306307
self
307308
}
@@ -311,7 +312,7 @@ impl Wesl<StandardResolver> {
311312
/// See [`StandardResolver::add_constant`].
312313
pub fn add_constants(
313314
&mut self,
314-
constants: impl IntoIterator<Item = (impl ToString, f64)>,
315+
constants: impl IntoIterator<Item = (impl ToString, LiteralInstance)>,
315316
) -> &mut Self {
316317
for (name, value) in constants {
317318
self.resolver.add_constant(name, value);

crates/wesl/src/resolve.rs

Lines changed: 139 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::{Diagnostic, Error};
22

33
use itertools::Itertools;
44
use wgsl_parse::syntax::{ModulePath, PathOrigin, TranslationUnit};
5+
use wgsl_types::inst::LiteralInstance;
56

67
use std::{
78
borrow::Cow,
@@ -275,10 +276,12 @@ impl<R: Resolver, F: ResolveFn> Resolver for Preprocessor<R, F> {
275276
/// This resolver is not thread-safe (not [`Send`] or [`Sync`]).
276277
pub struct Router {
277278
mount_points: Vec<(ModulePath, Box<dyn Resolver>)>,
278-
fallback: Option<(ModulePath, Box<dyn Resolver>)>,
279+
fallback: Option<Box<dyn Resolver>>,
279280
}
280281

281282
/// Dispatches resolution of a module path to sub-resolvers.
283+
///
284+
/// See documentation in [`Self::mount_resolver`]
282285
impl Router {
283286
/// Create a new resolver.
284287
pub fn new() -> Self {
@@ -293,32 +296,41 @@ impl Router {
293296
/// All import paths starting with `prefix` will be dispatched to the resolver with
294297
/// the suffix of the path. The prefix path must have an `Absolute` or `Package`
295298
/// origin and the suffix path will be given an `Absolute` origin.
299+
///
300+
/// If none of the `prefix`es match, the fallback resolver will be used.
296301
pub fn mount_resolver(&mut self, prefix: ModulePath, resolver: impl Resolver + 'static) {
297302
self.mount_points.push((prefix, Box::new(resolver)));
298303
}
299304

300305
/// Mount a fallback resolver that is used when no other prefix match.
301306
pub fn mount_fallback_resolver(&mut self, resolver: impl Resolver + 'static) {
302-
self.fallback = Some((ModulePath::new_root(), Box::new(resolver)));
307+
self.fallback = Some(Box::new(resolver));
303308
}
304309

305310
fn route(&self, path: &ModulePath) -> Result<(&dyn Resolver, ModulePath), ResolveError> {
306-
let (mount_path, resolver) = self
311+
if let Some((mount_path, resolver)) = self
307312
.mount_points
308313
.iter()
309314
.filter(|(prefix, _)| path.starts_with(prefix))
310315
.max_by_key(|(prefix, _)| prefix.components.len())
311-
.or(self.fallback.as_ref())
312-
.ok_or_else(|| E::ModuleNotFound(path.clone(), "no mount point".to_string()))?;
313-
314-
let components = path
315-
.components
316-
.iter()
317-
.skip(mount_path.components.len())
318-
.cloned()
319-
.collect_vec();
320-
let suffix = ModulePath::new(PathOrigin::Absolute, components);
321-
Ok((resolver, suffix))
316+
{
317+
let components = path
318+
.components
319+
.iter()
320+
.skip(mount_path.components.len())
321+
.cloned()
322+
.collect_vec();
323+
324+
let suffix = ModulePath::new(PathOrigin::Absolute, components);
325+
Ok((resolver, suffix))
326+
} else if let Some(resolver) = &self.fallback {
327+
Ok((resolver, path.clone()))
328+
} else {
329+
Err(E::ModuleNotFound(
330+
path.clone(),
331+
"no mount point".to_string(),
332+
))
333+
}
322334
}
323335
}
324336

@@ -463,7 +475,7 @@ impl Resolver for PkgResolver {
463475
pub struct StandardResolver {
464476
pkg: PkgResolver,
465477
files: FileResolver,
466-
constants: HashMap<String, f64>,
478+
constants: HashMap<String, LiteralInstance>,
467479
}
468480

469481
impl StandardResolver {
@@ -487,18 +499,24 @@ impl StandardResolver {
487499
///
488500
/// Numeric constants live WESL's special package named `constants`. This package is
489501
/// *virtual*, meaning it doesn't exist on the filesystem. Constants can be accessed
490-
/// by importing them: `import constants::MY_CONSTANT;`. All constants are of type
491-
/// AbstractFloat, which can be implicitly converted to all scalar types.
492-
pub fn add_constant(&mut self, name: impl ToString, value: f64) {
502+
/// by importing them: `import constants::MY_CONSTANT;`.
503+
///
504+
/// The type is specified by the variant of [`LiteralInstance`].\
505+
/// If specifying a constant that is used with multiple different types or
506+
/// a constant that benefits from precision, like π, use AbstractFloat,
507+
/// which can be implicitly converted to all scalar types.
508+
///
509+
/// Note: [`LiteralInstance`] implements [`From`] for all standard numeric types
510+
pub fn add_constant(&mut self, name: impl ToString, value: LiteralInstance) {
493511
self.constants.insert(name.to_string(), value);
494512
}
495513

514+
/// Generate a module with all declared virtual constants in the resolver
496515
fn generate_constant_module(&self) -> String {
497516
self.constants
498517
.iter()
499518
.map(|(name, value)| format!("const {name} = {value};"))
500-
.format("\n")
501-
.to_string()
519+
.join("\n")
502520
}
503521
}
504522

@@ -577,7 +595,7 @@ mod test {
577595
r.mount_resolver("package::bar".parse().unwrap(), v2);
578596

579597
let mut v3 = VirtualResolver::new();
580-
v3.add_module("package::bar".parse().unwrap(), "m6".into());
598+
v3.add_module("foo::bar".parse().unwrap(), "m6".into());
581599
r.mount_fallback_resolver(v3);
582600

583601
assert_eq!(r.resolve_source(&"package".parse().unwrap()).unwrap(), "m1");
@@ -599,4 +617,104 @@ mod test {
599617
"m6"
600618
);
601619
}
620+
621+
#[test]
622+
/// Test WGSL type casting of virtual constants
623+
fn type_virtual_constants() {
624+
// standard resolver to register some constants
625+
let mut std = StandardResolver::new(".");
626+
// AbstractFloat
627+
std.add_constant("TAU", std::f64::consts::TAU.into());
628+
// f32
629+
std.add_constant("LIGHTING_ANGLE", 10.0f32.into());
630+
// i32
631+
std.add_constant("Z_ROTATION", (-10i32).into());
632+
// u32
633+
std.add_constant("H", (12u32).into());
634+
// bool
635+
std.add_constant("BRIGHTEN", (false).into());
636+
637+
// use virtual resolver for the main module
638+
let mut v = VirtualResolver::new();
639+
v.add_module(
640+
"package::color_math".parse().unwrap(),
641+
// the main module imports constants::TAU and uses it in a context that requires f32,
642+
// therfor it will be cast from AbstractFloat
643+
r#"
644+
import constants::{TAU, H, BRIGHTEN};
645+
646+
fn color_sweep(h: u32) -> f32 {
647+
let color = cos(h + vec3(0.0, 1.0, 2.0) * TAU / 3.0);
648+
if (BRIGHTEN) {
649+
color += 0.1;
650+
}
651+
652+
return color;
653+
}
654+
655+
@fragment
656+
fn fragment() -> @location(0) vec4<f32> {
657+
return vec4(color_sweep(H), color_sweep(H + 0.1), color_sweep(H + 0.2), 1.0);
658+
}
659+
"#
660+
.into(),
661+
);
662+
663+
// route package imports whose prefix is "constants" to the StandardResolver
664+
// and absolute module paths to the VirtualResolver.
665+
let mut r = Router::new();
666+
r.mount_resolver(ModulePath::new_root(), v);
667+
r.mount_fallback_resolver(std);
668+
669+
// compile to test imports and casting
670+
crate::Wesl::new(".")
671+
.set_custom_resolver(r)
672+
.compile(&"package::color_math".parse().unwrap())
673+
.unwrap();
674+
}
675+
676+
#[test]
677+
/// Test resolving virtual constants from `add_constant`
678+
fn resolve_virtual_constants() {
679+
// todo impl `add_constant` for VirtualResolver then use that
680+
let mut sr = StandardResolver::new(".");
681+
682+
// add math constants
683+
sr.add_constant("PI", LiteralInstance::from(std::f64::consts::PI));
684+
sr.add_constant("E", LiteralInstance::from(std::f64::consts::E));
685+
// add misc constants
686+
sr.add_constant("NEG_2", LiteralInstance::from(-2i32));
687+
sr.add_constant("ONE", LiteralInstance::from(1u32));
688+
sr.add_constant("F32_MAX", LiteralInstance::from(f32::MAX));
689+
sr.add_constant("IS_HEAVY", LiteralInstance::from(false));
690+
sr.add_constant(
691+
"NUM_CONSTS",
692+
LiteralInstance::from(sr.constants.len() as i64),
693+
);
694+
695+
// generate the virtual module
696+
let generated = sr.generate_constant_module();
697+
// test that it contains the consts with correct values
698+
assert!(generated.contains(&format!("const PI = {:?};", std::f64::consts::PI)));
699+
assert!(generated.contains(&format!("const E = {:?};", std::f64::consts::E)));
700+
assert!(generated.contains("const NEG_2 = -2i;"));
701+
assert!(generated.contains("const ONE = 1u;"));
702+
assert!(generated.contains(&format!("const F32_MAX = {}f;", f32::MAX)));
703+
assert!(generated.contains("const IS_HEAVY = false;"));
704+
assert!(generated.contains(&format!(
705+
"const NUM_CONSTS = {};",
706+
(sr.constants.len() as i64) - 1
707+
)));
708+
709+
// resolve the package path with the origin `constants`,
710+
// the source of which should be the same as the generated module
711+
let src_root = sr.resolve_source(&"constants".parse().unwrap()).unwrap();
712+
assert_eq!(src_root, generated);
713+
714+
// resolving a path with components should return the same
715+
let src_comp = sr
716+
.resolve_source(&"constants::PI".parse().unwrap())
717+
.unwrap();
718+
assert_eq!(src_comp, generated);
719+
}
602720
}

0 commit comments

Comments
 (0)