forked from graphdeco-inria/gaussian-splatting
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGaussianScoreScaler.py
More file actions
44 lines (35 loc) · 1.95 KB
/
GaussianScoreScaler.py
File metadata and controls
44 lines (35 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
class GaussianScoreScaler:
    """Scale a gradient threshold per Gaussian according to importance scores.

    Each public method normalises ``gaussian_scores`` by its maximum (after an
    optional ``sqrt`` / ``log1p`` transform applied to both the scores and the
    maximum) and returns ``grad_threshold * (1 + alpha * (1 - normalized))``.
    Assuming non-negative scores, ``normalized`` lies in ``[0, 1]``: the
    highest-scoring entry keeps ``grad_threshold`` unchanged, while a
    zero-score entry gets ``grad_threshold * (1 + alpha)``.
    NOTE(review): ``grad_threshold`` is presumably the densification gradient
    threshold (scalar or broadcastable tensor) — confirm against callers.
    """

    # Transform applied to both the scores and their maximum before dividing,
    # keyed by the ``method`` name accepted by ``_scale``.
    _TRANSFORMS = {
        "linear": lambda scores: scores,
        "sqrt": torch.sqrt,
        "log": torch.log1p,
    }

    @staticmethod
    def linear(grad_threshold, gaussian_scores, n_init_points, alpha=1.0):
        """Scale ``grad_threshold`` using ``scores / max(scores)``."""
        return GaussianScoreScaler._scale(
            grad_threshold, gaussian_scores, n_init_points, alpha, "linear"
        )

    @staticmethod
    def sqrt(grad_threshold, gaussian_scores, n_init_points, alpha=1.0):
        """Scale ``grad_threshold`` using ``sqrt(scores) / sqrt(max(scores))``."""
        return GaussianScoreScaler._scale(
            grad_threshold, gaussian_scores, n_init_points, alpha, "sqrt"
        )

    @staticmethod
    def log(grad_threshold, gaussian_scores, n_init_points, alpha=1.0):
        """Scale ``grad_threshold`` using ``log1p(scores) / log1p(max(scores))``."""
        return GaussianScoreScaler._scale(
            grad_threshold, gaussian_scores, n_init_points, alpha, "log"
        )

    @staticmethod
    def _scale(grad_threshold, gaussian_scores: torch.Tensor, n_init_points: int,
               alpha: float, method: str):
        """Apply the selected scaling method and return per-entry thresholds.

        Args:
            grad_threshold: Base threshold multiplied by the per-entry factor.
            gaussian_scores: Score tensor; it is detached and never tracked by
                autograd. If it has fewer than ``n_init_points`` rows, it is
                zero-padded (on its own device/dtype) up to that length. Longer
                inputs are left untruncated.
            n_init_points: Minimum number of entries in the result.
            alpha: Strength of the adjustment; 0 returns ``grad_threshold``
                for every entry.
            method: One of ``"linear"``, ``"sqrt"``, ``"log"``.

        Returns:
            ``grad_threshold * (1 + alpha * (1 - normalized))``, where
            ``normalized = f(scores) / f(max(scores))``. If ``f(max)`` is 0
            (e.g. all scores are zero), every ``normalized`` entry is taken as
            0 instead of the NaN a 0/0 division would produce.

        Raises:
            ValueError: If ``method`` is not a known scaling method.
        """
        transform = GaussianScoreScaler._TRANSFORMS.get(method)
        if transform is None:
            raise ValueError(f"Unknown scaling method: {method}")

        with torch.no_grad():  # Ensures no gradient tracking
            scores = gaussian_scores.detach()

            missing = n_init_points - scores.shape[0]
            if missing > 0:
                # Pad on the input's own device/dtype so CPU (or non-default
                # GPU) tensors can be concatenated.
                padding = scores.new_zeros((missing,) + tuple(scores.shape[1:]))
                scores = torch.cat([scores, padding])

            # .max() raises on an empty tensor; a zero placeholder lets the
            # empty case flow through to an empty result.
            if scores.numel() > 0:
                max_score = scores.max()
            else:
                max_score = scores.new_zeros(())

            numerator = transform(scores)
            denominator = transform(max_score)
            ratio = numerator / denominator
            # A zero denominator means 0/0 = NaN for every entry; treat those
            # entries as the lowest normalized score (0) instead.
            normalized = torch.where(denominator != 0, ratio, torch.zeros_like(ratio))

            # Adjust the gradient threshold using the normalized scores.
            multiplier = 1.0 + alpha * (1.0 - normalized)
            return grad_threshold * multiplier