Skip to content

Fix SOGS rendering scenes with SH 1 or 2 #7703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCommon.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,33 @@ struct SplatCorner {
#endif
};

#if SH_BANDS == 1
#define SH_COEFFS 3
#elif SH_BANDS == 2
#define SH_COEFFS 8
#elif SH_BANDS == 3
#define SH_COEFFS 15
#else
#define SH_COEFFS 0
#endif

#if GSPLAT_COMPRESSED_DATA == true
#include "gsplatCompressedDataVS"
#include "gsplatCompressedSHVS"
#if SH_COEFFS > 0
#include "gsplatCompressedSHVS"
#endif
#elif GSPLAT_SOGS_DATA == true
#include "gsplatSogsDataVS"
#include "gsplatSogsColorVS"
#include "gsplatSogsSHVS"
#if SH_COEFFS > 0
#include "gsplatSogsSHVS"
#endif
#else
#include "gsplatDataVS"
#include "gsplatColorVS"
#include "gsplatSHVS"
#if SH_COEFFS > 0
#include "gsplatSHVS"
#endif
#endif

#include "gsplatSourceVS"
Expand Down Expand Up @@ -97,15 +113,7 @@ void clipCorner(inout SplatCorner corner, float alpha) {
// see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py
vec3 evalSH(in SplatSource source, in vec3 dir) {

#if SH_BANDS > 0
#if SH_BANDS == 1
vec3 sh[3];
#elif SH_BANDS == 2
vec3 sh[8];
#elif SH_BANDS == 3
vec3 sh[15];
#endif
#endif
vec3 sh[SH_COEFFS];

// read sh coefficients
float scale;
Expand Down
10 changes: 6 additions & 4 deletions src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatSogsSH.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ uniform highp sampler2D sh_centroids;
uniform float shN_mins;
uniform float shN_maxs;

void readSHData(in SplatSource source, out vec3 sh[15], out float scale) {
// To support each SH degree, readSHData is overloaded based on the SH vector depth

void readSHData(in SplatSource source, out vec3 sh[SH_COEFFS], out float scale) {
// extract spherical harmonics palette index
ivec2 t = ivec2(texelFetch(sh_labels, source.uv, 0).xy * 255.0);
int n = t.x + t.y * 256;
int u = (n % 64) * 15;
int u = (n % 64) * SH_COEFFS;
int v = n / 64;

// calculate offset into the centroids texture and read 15 consecutive texels
for (int i = 0; i < 15; i++) {
// calculate offset into the centroids texture and read consecutive texels
for (int i = 0; i < SH_COEFFS; i++) {
sh[i] = mix(vec3(shN_mins), vec3(shN_maxs), texelFetch(sh_centroids, ivec2(u + i, v), 0).xyz);
}

Expand Down
36 changes: 22 additions & 14 deletions src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCommon.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,33 @@ fn quatToMat3(R: vec4<f32>) -> mat3x3<f32> {
);
}

#if GSPLAT_COMPRESSED_DATA == true
#if SH_BANDS == 1
const SH_COEFFS: i32 = 3;
#elif SH_BANDS == 2
const SH_COEFFS: i32 = 8;
#elif SH_BANDS == 3
const SH_COEFFS: i32 = 15;
#else
const SH_COEFFS: i32 = 0;
#endif

#if GSPLAT_COMPRESSED_DATA
#include "gsplatCompressedDataVS"
#include "gsplatCompressedSHVS"
#elif GSPLAT_SOGS_DATA == true
#if SH_BANDS > 0
#include "gsplatCompressedSHVS"
#endif
#elif GSPLAT_SOGS_DATA
#include "gsplatSogsDataVS"
#include "gsplatSogsColorVS"
#include "gsplatSogsSHVS"
#if SH_BANDS > 0
#include "gsplatSogsSHVS"
#endif
#else
#include "gsplatDataVS"
#include "gsplatColorVS"
#include "gsplatSHVS"
#if SH_BANDS > 0
#include "gsplatSHVS"
#endif
#endif

#include "gsplatSourceVS"
Expand Down Expand Up @@ -90,15 +106,7 @@ fn clipCorner(corner: ptr<function, SplatCorner>, alpha: f32) {
// see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py
fn evalSH(source: ptr<function, SplatSource>, dir: vec3f) -> vec3f {

#if SH_BANDS > 0
#if SH_BANDS == 1
var sh: array<vec3f, 3>;
#elif SH_BANDS == 2
var sh: array<vec3f, 8>;
#elif SH_BANDS == 3
var sh: array<vec3f, 15>;
#endif
#endif
var sh: array<vec3f, SH_COEFFS>;

var scale: f32;
readSHData(source, &sh, &scale);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ uniform sh0_maxs: vec4f;
const SH_C0: f32 = 0.28209479177387814;

fn readColor(source: ptr<function, SplatSource>) -> vec4f {
let clr: vec4f = mix(sh0_mins, sh0_maxs, textureLoad(sh0, source.uv, 0));
let clr: vec4f = mix(uniform.sh0_mins, uniform.sh0_maxs, textureLoad(sh0, source.uv, 0));
return vec4f(vec3f(0.5) + clr.xyz * SH_C0, 1.0 / (1.0 + exp(-clr.w)));
}
`;
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ var scales: texture_2d<f32>;
uniform means_mins: vec3f;
uniform means_maxs: vec3f;

uniform quats_mins: vec3f;
uniform quats_maxs: vec3f;

uniform scales_mins: vec3f;
uniform scales_maxs: vec3f;

Expand All @@ -19,7 +16,7 @@ fn readCenter(source: ptr<function, SplatSource>) -> vec3f {
let l: vec3f = textureLoad(means_l, source.uv, 0).xyz;
let n: vec3f = (l * 255.0 + u * 255.0 * 256.0) / 65535.0;

let v: vec3f = mix(means_mins, means_maxs, n);
let v: vec3f = mix(uniform.means_mins, uniform.means_maxs, n);
return sign(v) * (exp(abs(v)) - 1.0);
}

Expand All @@ -46,7 +43,7 @@ fn readCovariance(source: ptr<function, SplatSource>, covA_ptr: ptr<function, ve


let rot: mat3x3f = quatToMat3(quat);
let scale: vec3f = exp(mix(scales_mins, scales_maxs, textureLoad(scales, source.uv, 0).xyz));
let scale: vec3f = exp(mix(uniform.scales_mins, uniform.scales_maxs, textureLoad(scales, source.uv, 0).xyz));

// M = S * R
let M: mat3x3f = transpose(mat3x3f(
Expand Down
10 changes: 5 additions & 5 deletions src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatSogsSH.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ var sh_centroids: texture_2d<f32>;
uniform shN_mins: f32;
uniform shN_maxs: f32;

fn readSHData(source: ptr<function, SplatSource>, sh: ptr<function, array<vec3f, 15>>, scale: ptr<function, f32>) {
fn readSHData(source: ptr<function, SplatSource>, sh: ptr<function, array<vec3f, SH_COEFFS>>, scale: ptr<function, f32>) {
// extract spherical harmonics palette index
let t: vec2<i32> = vec2<i32>(textureLoad(sh_labels, source.uv, 0).xy * 255.0);
let n: i32 = t.x + t.y * 256;
let u: i32 = (n % 64) * 15;
let u: i32 = (n % 64) * SH_COEFFS;
let v: i32 = n / 64;

// calculate offset into the centroids texture and read 15 consecutive texels
for (var i: i32 = 0; i < 15; i = i + 1) {
sh[i] = mix(vec3f(shN_mins), vec3f(shN_maxs), textureLoad(sh_centroids, vec2<i32>(u + i, v), 0).xyz);
// calculate offset into the centroids texture and read consecutive texels
for (var i: i32 = 0; i < SH_COEFFS; i = i + 1) {
sh[i] = mix(vec3f(uniform.shN_mins), vec3f(uniform.shN_maxs), textureLoad(sh_centroids, vec2<i32>(u + i, v), 0).xyz);
}

*scale = 1.0;
Expand Down
6 changes: 6 additions & 0 deletions src/scene/shader-lib/wgsl/collections/shader-chunks-wgsl.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ import gsplatColorVS from '../chunks/gsplat/vert/gsplatColor.js';
import gsplatCommonVS from '../chunks/gsplat/vert/gsplatCommon.js';
import gsplatCompressedDataVS from '../chunks/gsplat/vert/gsplatCompressedData.js';
import gsplatCompressedSHVS from '../chunks/gsplat/vert/gsplatCompressedSH.js';
import gsplatSogsColorVS from '../chunks/gsplat/vert/gsplatSogsColor.js';
import gsplatSogsDataVS from '../chunks/gsplat/vert/gsplatSogsData.js';
import gsplatSogsSHVS from '../chunks/gsplat/vert/gsplatSogsSH.js';
import gsplatCornerVS from '../chunks/gsplat/vert/gsplatCorner.js';
import gsplatDataVS from '../chunks/gsplat/vert/gsplatData.js';
import gsplatOutputVS from '../chunks/gsplat/vert/gsplatOutput.js';
Expand Down Expand Up @@ -221,6 +224,9 @@ const shaderChunksWGSL = {
gsplatCommonVS,
gsplatCompressedDataVS,
gsplatCompressedSHVS,
gsplatSogsColorVS,
gsplatSogsDataVS,
gsplatSogsSHVS,
gsplatDataVS,
gsplatOutputVS,
gsplatPS,
Expand Down