Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from gsplat.optimizers import SelectiveAdam
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat.utils import xyz_to_polar
from gsplat_viewer import GsplatViewer, GsplatRenderTabState
from nerfview import CameraState, RenderTabState, apply_float_colormap

Expand Down Expand Up @@ -146,6 +147,9 @@ class Config:
# Scale regularization
scale_reg: float = 0.0

# Use homogeneous coordinates; 50k max_steps and densification until 30k steps are recommended!
use_hom_coords: bool = False

# Enable camera optimization.
pose_opt: bool = False
# Learning rate for camera optimization
Expand Down Expand Up @@ -225,6 +229,7 @@ def create_splats_with_optimizers(
visible_adam: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
use_hom_coords: bool = False,
device: str = "cuda",
world_rank: int = 0,
world_size: int = 1,
Expand All @@ -238,6 +243,12 @@ def create_splats_with_optimizers(
else:
raise ValueError("Please specify a correct init_type: sfm or random")

if use_hom_coords:
w, _ = xyz_to_polar(points)
points *= w.unsqueeze(1)
w = torch.log(w)
w = w[world_rank::world_size]

# Initialize the GS size to be the average dist of the 3 nearest neighbors
dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
dist_avg = torch.sqrt(dist2_avg)
Expand All @@ -260,6 +271,9 @@ def create_splats_with_optimizers(
("opacities", torch.nn.Parameter(opacities), opacities_lr),
]

if use_hom_coords:
params.append(("w", torch.nn.Parameter(w), 0.0002 * scene_scale))

if feature_dim is None:
# color is SH coefficients.
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
Expand Down Expand Up @@ -366,6 +380,7 @@ def __init__(
visible_adam=cfg.visible_adam,
batch_size=cfg.batch_size,
feature_dim=feature_dim,
use_hom_coords=cfg.use_hom_coords,
device=self.device,
world_rank=world_rank,
world_size=world_size,
Expand Down Expand Up @@ -495,6 +510,11 @@ def rasterize_splats(
scales = torch.exp(self.splats["scales"]) # [N, 3]
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]

if cfg.use_hom_coords:
w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1)
means = means * w_inv
scales = scales * w_inv

image_ids = kwargs.pop("image_ids", None)
if self.cfg.app_opt:
colors = self.app_module(
Expand Down Expand Up @@ -560,6 +580,12 @@ def train(self):
self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
),
]
if cfg.use_hom_coords:
schedulers.append(
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["w"], gamma=0.1 ** (0.005 / max_steps)
)
)
if cfg.pose_opt:
# pose optimization has a learning rate schedule
schedulers.append(
Expand Down Expand Up @@ -801,6 +827,11 @@ def train(self):

means = self.splats["means"]
scales = self.splats["scales"]
if cfg.use_hom_coords:
w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1)
means = means * w_inv
scales = torch.log(torch.exp(scales) * w_inv)

quats = self.splats["quats"]
opacities = self.splats["opacities"]
export_splats(
Expand Down Expand Up @@ -1211,7 +1242,23 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
),
),
}

cfg = tyro.extras.overridable_config_cli(configs)
if cfg.use_hom_coords:
    # Homogeneous coordinates need a longer schedule (see the config comment).
    cfg.max_steps = 50_000
    if isinstance(cfg.strategy, DefaultStrategy):
        cfg.strategy.refine_stop_iter = 30_000
        cfg.strategy.refine_every = 200
        cfg.strategy.reset_every = 6_000
        cfg.strategy.refine_start_iter = 1_500
        cfg.strategy.prune_too_big = False
    elif isinstance(cfg.strategy, MCMCStrategy):
        # BUG FIX: the original used annotation syntax (`attr: value`),
        # which is an annotation-only statement and assigns NOTHING, so
        # the MCMC overrides were silently dropped. Use plain assignments.
        cfg.strategy.refine_start_iter = 1_500
        cfg.strategy.refine_stop_iter = 40_000
        cfg.strategy.refine_every = 200
    else:
        assert_never(cfg.strategy)

cfg.adjust_steps(cfg.steps_scaler)

# try import extra dependencies
Expand Down
10 changes: 8 additions & 2 deletions examples/simple_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ def main(local_rank: int, world_rank, world_size: int, args):
means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
for ckpt_path in args.ckpt:
ckpt = torch.load(ckpt_path, map_location=device)["splats"]
means.append(ckpt["means"])
if "w" in ckpt:
w_inv = 1.0 / torch.exp(ckpt["w"]).unsqueeze(1)
means.append(ckpt["means"] * w_inv)
scales.append(torch.exp(ckpt["scales"]) * w_inv)
else:
means.append(ckpt["means"])
scales.append(torch.exp(ckpt["scales"]))
quats.append(F.normalize(ckpt["quats"], p=2, dim=-1))
scales.append(torch.exp(ckpt["scales"]))
opacities.append(torch.sigmoid(ckpt["opacities"]))
sh0.append(ckpt["sh0"])
shN.append(ckpt["shN"])

means = torch.cat(means, dim=0)
quats = torch.cat(quats, dim=0)
scales = torch.cat(scales, dim=0)
Expand Down
20 changes: 13 additions & 7 deletions gsplat/strategy/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class DefaultStrategy(Strategy):
prune_scale3d: float = 0.1
prune_scale2d: float = 0.15
refine_scale2d_stop_iter: int = 0
prune_too_big: bool = True
refine_start_iter: int = 500
refine_stop_iter: int = 15_000
reset_every: int = 3000
Expand Down Expand Up @@ -272,10 +273,12 @@ def _grow_gs(
device = grads.device

is_grad_high = grads > self.grow_grad2d
is_small = (
torch.exp(params["scales"]).max(dim=-1).values
<= self.grow_scale3d * state["scene_scale"]
)
scales = torch.exp(params["scales"])
if "w" in params:
w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1)
scales = scales * w_inv

is_small = scales.max(dim=-1).values <= self.grow_scale3d * state["scene_scale"]
is_dupli = is_grad_high & is_small
n_dupli = is_dupli.sum().item()

Expand Down Expand Up @@ -317,10 +320,13 @@ def _prune_gs(
step: int,
) -> int:
is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa
if step > self.reset_every:
if step > self.reset_every and self.prune_too_big:
scales = torch.exp(params["scales"])
if "w" in params:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello author, I have a question: when `use_hom_coords` is false and `prune_too_big` is true, this logic will be executed, but in that case "w" will not be in `params` — is this check redundant? Looking forward to your reply.

w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1)
scales = scales * w_inv
is_too_big = (
torch.exp(params["scales"]).max(dim=-1).values
> self.prune_scale3d * state["scene_scale"]
scales.max(dim=-1).values > self.prune_scale3d * state["scene_scale"]
)
# The official code also implements screen-size pruning but
# it's actually not being used due to a bug:
Expand Down
59 changes: 51 additions & 8 deletions gsplat/strategy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from gsplat import quat_scale_to_covar_preci
from gsplat.relocation import compute_relocation
from gsplat.utils import normalized_quat_to_rotmat
from gsplat.utils import normalized_quat_to_rotmat, xyz_to_polar


@torch.no_grad()
Expand Down Expand Up @@ -142,6 +142,11 @@ def split(
rest = torch.where(~mask)[0]

scales = torch.exp(params["scales"][sel])
means = params["means"][sel]
if "w" in params:
w_inv = 1.0 / torch.exp(params["w"][sel]).unsqueeze(1)
scales = scales * w_inv
means = means * w_inv
quats = F.normalize(params["quats"][sel], dim=-1)
rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3]
samples = torch.einsum(
Expand All @@ -151,12 +156,21 @@ def split(
torch.randn(2, len(scales), 3, device=device),
) # [2, N, 3]

means = (means + samples).reshape(-1, 3)
if "w" in params:
w, _ = xyz_to_polar(means) # [2N]
means = means * w.unsqueeze(1) # [2N, 3]

def param_fn(name: str, p: Tensor) -> Tensor:
repeats = [2] + [1] * (p.dim() - 1)
if name == "means":
p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3]
p_split = means # [2N, 3]
elif name == "w":
p_split = torch.log(w)
elif name == "scales":
p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3]
p_split = torch.log(torch.exp(params["scales"][sel]) / 1.6).repeat(
2, 1
) # [2N, 3]
elif name == "opacities" and revised_opacity:
new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel]))
p_split = torch.logit(new_opacities).repeat(repeats) # [2N]
Expand Down Expand Up @@ -269,9 +283,14 @@ def relocate(
probs = opacities[alive_indices].flatten() # ensure its shape is [N,]
sampled_idxs = _multinomial_sample(probs, n, replacement=True)
sampled_idxs = alive_indices[sampled_idxs]

scales = torch.exp(params["scales"])[sampled_idxs]
if "w" in params:
w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1)
scales = scales / w
new_opacities, new_scales = compute_relocation(
opacities=opacities[sampled_idxs],
scales=torch.exp(params["scales"])[sampled_idxs],
scales=scales,
ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
binoms=binoms,
)
Expand All @@ -281,7 +300,10 @@ def param_fn(name: str, p: Tensor) -> Tensor:
if name == "opacities":
p[sampled_idxs] = torch.logit(new_opacities)
elif name == "scales":
p[sampled_idxs] = torch.log(new_scales)
if "w" in params:
p[sampled_idxs] = torch.log(new_scales * w)
else:
p[sampled_idxs] = torch.log(new_scales)
p[dead_indices] = p[sampled_idxs]
return torch.nn.Parameter(p, requires_grad=p.requires_grad)

Expand Down Expand Up @@ -311,9 +333,14 @@ def sample_add(
eps = torch.finfo(torch.float32).eps
probs = opacities.flatten()
sampled_idxs = _multinomial_sample(probs, n, replacement=True)

scales = torch.exp(params["scales"])[sampled_idxs]
if "w" in params:
w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1)
scales = scales / w
new_opacities, new_scales = compute_relocation(
opacities=opacities[sampled_idxs],
scales=torch.exp(params["scales"])[sampled_idxs],
scales=scales,
ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
binoms=binoms,
)
Expand All @@ -323,7 +350,10 @@ def param_fn(name: str, p: Tensor) -> Tensor:
if name == "opacities":
p[sampled_idxs] = torch.logit(new_opacities)
elif name == "scales":
p[sampled_idxs] = torch.log(new_scales)
if "w" in params:
p[sampled_idxs] = torch.log(new_scales * w)
else:
p[sampled_idxs] = torch.log(new_scales)
p_new = torch.cat([p, p[sampled_idxs]])
return torch.nn.Parameter(p_new, requires_grad=p.requires_grad)

Expand All @@ -348,7 +378,13 @@ def inject_noise_to_position(
scaler: float,
):
opacities = torch.sigmoid(params["opacities"].flatten())
means = params["means"]
scales = torch.exp(params["scales"])
if "w" in params:
w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1)
means = means * w_inv
scales = scales * w_inv

covars, _ = quat_scale_to_covar_preci(
params["quats"],
scales,
Expand All @@ -366,4 +402,11 @@ def op_sigmoid(x, k=100, x0=0.995):
* scaler
)
noise = torch.einsum("bij,bj->bi", covars, noise)
params["means"].add_(noise)
means = means + noise

if "w" in params:
w, _ = xyz_to_polar(means)
means = means * w.unsqueeze(1)
params["w"].data = torch.log(w.clamp_min(1e-8))

params["means"].data = means
6 changes: 6 additions & 0 deletions gsplat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
return mat.reshape(quat.shape[:-1] + (3, 3))


def xyz_to_polar(means):
    """Return the inverse radius and radius of each point w.r.t. the origin.

    Despite the name, this does not return polar angles: it computes the
    Euclidean distance r = ||(x, y, z)|| per point and yields (1/r, r).
    The 1/r term is used as the homogeneous weight `w` by callers.

    Args:
        means: Tensor of shape [N, 3] holding xyz coordinates.

    Returns:
        Tuple ``(w, r)`` of tensors of shape [N], where ``w = 1 / r``.

    Note:
        r is clamped to the dtype's epsilon so a point exactly at the
        origin does not yield inf/NaN — callers take ``torch.log(w)``
        downstream, which would otherwise blow up.
    """
    r = torch.linalg.norm(means, dim=1)
    r = r.clamp_min(torch.finfo(means.dtype).eps)
    return 1.0 / r, r


def log_transform(x):
    """Symmetric log transform: maps x to sign(x) * log(1 + |x|).

    Compresses large magnitudes while preserving sign; the identity-like
    behavior near zero keeps small values roughly unchanged.
    """
    # Branch on sign instead of multiplying by torch.sign(x); the two
    # arms are mirror images, and x == 0 lands in the first arm as 0.
    return torch.where(x >= 0, torch.log1p(x), -torch.log1p(-x))

Expand Down