diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCommon.js b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCommon.js index 54bd741ef32..3171b6a0dec 100644 --- a/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCommon.js +++ b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCommon.js @@ -45,17 +45,33 @@ struct SplatCorner { #endif }; +#if SH_BANDS == 1 + #define SH_COEFFS 3 +#elif SH_BANDS == 2 + #define SH_COEFFS 8 +#elif SH_BANDS == 3 + #define SH_COEFFS 15 +#else + #define SH_COEFFS 0 +#endif + #if GSPLAT_COMPRESSED_DATA == true #include "gsplatCompressedDataVS" - #include "gsplatCompressedSHVS" + #if SH_COEFFS > 0 + #include "gsplatCompressedSHVS" + #endif #elif GSPLAT_SOGS_DATA == true #include "gsplatSogsDataVS" #include "gsplatSogsColorVS" - #include "gsplatSogsSHVS" + #if SH_COEFFS > 0 + #include "gsplatSogsSHVS" + #endif #else #include "gsplatDataVS" #include "gsplatColorVS" - #include "gsplatSHVS" + #if SH_COEFFS > 0 + #include "gsplatSHVS" + #endif #endif #include "gsplatSourceVS" @@ -97,15 +113,7 @@ void clipCorner(inout SplatCorner corner, float alpha) { // see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py vec3 evalSH(in SplatSource source, in vec3 dir) { - #if SH_BANDS > 0 - #if SH_BANDS == 1 - vec3 sh[3]; - #elif SH_BANDS == 2 - vec3 sh[8]; - #elif SH_BANDS == 3 - vec3 sh[15]; - #endif - #endif + vec3 sh[SH_COEFFS]; // read sh coefficients float scale; diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatSogsSH.js b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatSogsSH.js index cb7383a9e85..c9f584bdf70 100644 --- a/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatSogsSH.js +++ b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatSogsSH.js @@ -5,15 +5,17 @@ uniform highp sampler2D sh_centroids; uniform float shN_mins; uniform float shN_maxs; -void readSHData(in SplatSource source, out vec3 sh[15], out float scale) { +// To support each SH degree, readSHData is overloaded based on the SH vector depth + +void readSHData(in SplatSource source, out vec3 sh[SH_COEFFS], out float scale) { // extract spherical harmonics palette index ivec2 t = ivec2(texelFetch(sh_labels, source.uv, 0).xy * 255.0); int n = t.x + t.y * 256; - int u = (n % 64) * 15; + int u = (n % 64) * SH_COEFFS; int v = n / 64; - // calculate offset into the centroids texture and read 15 consecutive texels - for (int i = 0; i < 15; i++) { + // calculate offset into the centroids texture and read consecutive texels + for (int i = 0; i < SH_COEFFS; i++) { sh[i] = mix(vec3(shN_mins), vec3(shN_maxs), texelFetch(sh_centroids, ivec2(u + i, v), 0).xyz); } diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCommon.js b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCommon.js index 5fedb04e20a..6b99e0fac68 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCommon.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCommon.js @@ -39,17 +39,33 @@ fn quatToMat3(R: vec4) -> mat3x3 { ); } -#if GSPLAT_COMPRESSED_DATA == true +#if SH_BANDS == 1 + const SH_COEFFS: i32 = 3; +#elif SH_BANDS == 2 + const SH_COEFFS: i32 = 8; +#elif SH_BANDS == 3 + const SH_COEFFS: i32 = 15; +#else + const SH_COEFFS: i32 = 0; +#endif + +#if GSPLAT_COMPRESSED_DATA #include "gsplatCompressedDataVS" - #include "gsplatCompressedSHVS" -#elif GSPLAT_SOGS_DATA == true + #if SH_BANDS > 0 + #include "gsplatCompressedSHVS" + #endif +#elif GSPLAT_SOGS_DATA #include "gsplatSogsDataVS" #include "gsplatSogsColorVS" - #include "gsplatSogsSHVS" + #if SH_BANDS > 0 + #include "gsplatSogsSHVS" + #endif #else #include "gsplatDataVS" #include "gsplatColorVS" - #include "gsplatSHVS" + #if SH_BANDS > 0 + #include "gsplatSHVS" + #endif #endif #include "gsplatSourceVS" @@ -90,15 +106,7 @@ fn clipCorner(corner: ptr, alpha: f32) { // see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py fn evalSH(source: ptr, dir: vec3f) -> vec3f { - #if SH_BANDS > 0 - #if SH_BANDS == 1 - var sh: array; - #elif SH_BANDS == 2 - var sh: array; - #elif SH_BANDS == 3 - var sh: array; - #endif - #endif + var sh: array; var scale: f32; readSHData(source, &sh, &scale); diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsColor.js b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsColor.js index 2528fe4ce63..669ce7d382c 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsColor.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsColor.js @@ -7,7 +7,7 @@ uniform sh0_maxs: vec4f; const SH_C0: f32 = 0.28209479177387814; fn readColor(source: ptr) -> vec4f { - let clr: vec4f = mix(sh0_mins, sh0_maxs, textureLoad(sh0, source.uv, 0)); + let clr: vec4f = mix(uniform.sh0_mins, uniform.sh0_maxs, textureLoad(sh0, source.uv, 0)); return vec4f(vec3f(0.5) + clr.xyz * SH_C0, 1.0 / (1.0 + exp(-clr.w))); } `; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsData.js b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsData.js index 354645ee708..30adcb4adf2 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsData.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsData.js @@ -7,9 +7,6 @@ var scales: texture_2d; uniform means_mins: vec3f; uniform means_maxs: vec3f; -uniform quats_mins: vec3f; -uniform quats_maxs: vec3f; - uniform scales_mins: vec3f; uniform scales_maxs: vec3f; @@ -19,7 +16,7 @@ fn readCenter(source: ptr) -> vec3f { let l: vec3f = textureLoad(means_l, source.uv, 0).xyz; let n: vec3f = (l * 255.0 + u * 255.0 * 256.0) / 65535.0; - let v: vec3f = mix(means_mins, means_maxs, n); + let v: vec3f = mix(uniform.means_mins, uniform.means_maxs, n); return sign(v) * (exp(abs(v)) - 1.0); } @@ -46,7 +43,7 @@ fn readCovariance(source: ptr, covA_ptr: ptr; uniform shN_mins: f32; uniform shN_maxs: f32; -fn readSHData(source: ptr, sh: ptr>, scale: ptr) { +fn readSHData(source: ptr, sh: ptr>, scale: ptr) { // extract spherical harmonics palette index let t: vec2 = vec2(textureLoad(sh_labels, source.uv, 0).xy * 255.0); let n: i32 = t.x + t.y * 256; - let u: i32 = (n % 64) * 15; + let u: i32 = (n % 64) * SH_COEFFS; let v: i32 = n / 64; - // calculate offset into the centroids texture and read 15 consecutive texels - for (var i: i32 = 0; i < 15; i = i + 1) { - sh[i] = mix(vec3f(shN_mins), vec3f(shN_maxs), textureLoad(sh_centroids, vec2(u + i, v), 0).xyz); + // calculate offset into the centroids texture and read consecutive texels + for (var i: i32 = 0; i < SH_COEFFS; i = i + 1) { + sh[i] = mix(vec3f(uniform.shN_mins), vec3f(uniform.shN_maxs), textureLoad(sh_centroids, vec2(u + i, v), 0).xyz); } *scale = 1.0; diff --git a/src/scene/shader-lib/wgsl/collections/shader-chunks-wgsl.js b/src/scene/shader-lib/wgsl/collections/shader-chunks-wgsl.js index e415d33f2ff..58d774cc3c2 100644 --- a/src/scene/shader-lib/wgsl/collections/shader-chunks-wgsl.js +++ b/src/scene/shader-lib/wgsl/collections/shader-chunks-wgsl.js @@ -48,6 +48,9 @@ import gsplatColorVS from '../chunks/gsplat/vert/gsplatColor.js'; import gsplatCommonVS from '../chunks/gsplat/vert/gsplatCommon.js'; import gsplatCompressedDataVS from '../chunks/gsplat/vert/gsplatCompressedData.js'; import gsplatCompressedSHVS from '../chunks/gsplat/vert/gsplatCompressedSH.js'; +import gsplatSogsColorVS from '../chunks/gsplat/vert/gsplatSogsColor.js'; +import gsplatSogsDataVS from '../chunks/gsplat/vert/gsplatSogsData.js'; +import gsplatSogsSHVS from '../chunks/gsplat/vert/gsplatSogsSH.js'; import gsplatCornerVS from '../chunks/gsplat/vert/gsplatCorner.js'; import gsplatDataVS from '../chunks/gsplat/vert/gsplatData.js'; import gsplatOutputVS from '../chunks/gsplat/vert/gsplatOutput.js'; @@ -221,6 +224,9 @@ const shaderChunksWGSL = { gsplatCommonVS, gsplatCompressedDataVS, gsplatCompressedSHVS, + gsplatSogsColorVS, + gsplatSogsDataVS, + gsplatSogsSHVS, gsplatDataVS, gsplatOutputVS, gsplatPS,