diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 53ffe7080..4fcf2ac3f 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -37,6 +37,7 @@ from gsplat.optimizers import SelectiveAdam from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.utils import xyz_to_polar from gsplat_viewer import GsplatViewer, GsplatRenderTabState from nerfview import CameraState, RenderTabState, apply_float_colormap @@ -146,6 +147,9 @@ class Config: # Scale regularization scale_reg: float = 0.0 + # Use homogeneous coordinates, 50k max_steps and 30k steps densifications recommended! + use_hom_coords: bool = False + # Enable camera optimization. pose_opt: bool = False # Learning rate for camera optimization @@ -225,6 +229,7 @@ def create_splats_with_optimizers( visible_adam: bool = False, batch_size: int = 1, feature_dim: Optional[int] = None, + use_hom_coords: bool = False, device: str = "cuda", world_rank: int = 0, world_size: int = 1, @@ -238,6 +243,12 @@ def create_splats_with_optimizers( else: raise ValueError("Please specify a correct init_type: sfm or random") + if use_hom_coords: + w, _ = xyz_to_polar(points) + points *= w.unsqueeze(1) + w = torch.log(w) + w = w[world_rank::world_size] + # Initialize the GS size to be the average dist of the 3 nearest neighbors dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) @@ -260,6 +271,9 @@ def create_splats_with_optimizers( ("opacities", torch.nn.Parameter(opacities), opacities_lr), ] + if use_hom_coords: + params.append(("w", torch.nn.Parameter(w), 0.0002 * scene_scale)) + if feature_dim is None: # color is SH coefficients. 
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] @@ -366,6 +380,7 @@ def __init__( visible_adam=cfg.visible_adam, batch_size=cfg.batch_size, feature_dim=feature_dim, + use_hom_coords=cfg.use_hom_coords, device=self.device, world_rank=world_rank, world_size=world_size, @@ -495,6 +510,11 @@ def rasterize_splats( scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + if cfg.use_hom_coords: + w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) + means = means * w_inv + scales = scales * w_inv + image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( @@ -560,6 +580,12 @@ def train(self): self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) ), ] + if cfg.use_hom_coords: + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["w"], gamma=0.1 ** (0.005 / max_steps) + ) + ) if cfg.pose_opt: # pose optimization has a learning rate schedule schedulers.append( @@ -801,6 +827,11 @@ def train(self): means = self.splats["means"] scales = self.splats["scales"] + if cfg.use_hom_coords: + w_inv = 1.0 / torch.exp(self.splats["w"]).unsqueeze(1) + means = means * w_inv + scales = torch.log(torch.exp(scales) * w_inv) + quats = self.splats["quats"] opacities = self.splats["opacities"] export_splats( @@ -1211,7 +1242,23 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): ), ), } + cfg = tyro.extras.overridable_config_cli(configs) + if cfg.use_hom_coords: + cfg.max_steps = 50_000 + if isinstance(cfg.strategy, DefaultStrategy): + cfg.strategy.refine_stop_iter = 30_000 + cfg.strategy.refine_every = 200 + cfg.strategy.reset_every = 6_000 + cfg.strategy.refine_start_iter = 1_500 + cfg.strategy.prune_too_big = False + elif isinstance(cfg.strategy, MCMCStrategy): + cfg.strategy.refine_start_iter = 1_500 + cfg.strategy.refine_stop_iter = 40_000 + cfg.strategy.refine_every = 200 + else: + assert_never(cfg.strategy) + 
cfg.adjust_steps(cfg.steps_scaler) # try import extra dependencies diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 63597033a..8b0facc76 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -103,12 +103,18 @@ def main(local_rank: int, world_rank, world_size: int, args): means, quats, scales, opacities, sh0, shN = [], [], [], [], [], [] for ckpt_path in args.ckpt: ckpt = torch.load(ckpt_path, map_location=device)["splats"] - means.append(ckpt["means"]) + if "w" in ckpt: + w_inv = 1.0 / torch.exp(ckpt["w"]).unsqueeze(1) + means.append(ckpt["means"] * w_inv) + scales.append(torch.exp(ckpt["scales"]) * w_inv) + else: + means.append(ckpt["means"]) + scales.append(torch.exp(ckpt["scales"])) quats.append(F.normalize(ckpt["quats"], p=2, dim=-1)) - scales.append(torch.exp(ckpt["scales"])) opacities.append(torch.sigmoid(ckpt["opacities"])) sh0.append(ckpt["sh0"]) shN.append(ckpt["shN"]) + means = torch.cat(means, dim=0) quats = torch.cat(quats, dim=0) scales = torch.cat(scales, dim=0) diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 9075278a9..a1e2b81ab 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -83,6 +83,7 @@ class DefaultStrategy(Strategy): prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 refine_scale2d_stop_iter: int = 0 + prune_too_big: bool = True refine_start_iter: int = 500 refine_stop_iter: int = 15_000 reset_every: int = 3000 @@ -272,10 +273,12 @@ def _grow_gs( device = grads.device is_grad_high = grads > self.grow_grad2d - is_small = ( - torch.exp(params["scales"]).max(dim=-1).values - <= self.grow_scale3d * state["scene_scale"] - ) + scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + scales = scales * w_inv + + is_small = scales.max(dim=-1).values <= self.grow_scale3d * state["scene_scale"] is_dupli = is_grad_high & is_small n_dupli = is_dupli.sum().item() @@ -317,10 +320,13 @@ def _prune_gs( 
step: int, ) -> int: is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa - if step > self.reset_every: + if step > self.reset_every and self.prune_too_big: + scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + scales = scales * w_inv is_too_big = ( - torch.exp(params["scales"]).max(dim=-1).values - > self.prune_scale3d * state["scene_scale"] + scales.max(dim=-1).values > self.prune_scale3d * state["scene_scale"] ) # The official code also implements sreen-size pruning but # it's actually not being used due to a bug: diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 83c90a25e..26544cf7d 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -7,7 +7,7 @@ from gsplat import quat_scale_to_covar_preci from gsplat.relocation import compute_relocation -from gsplat.utils import normalized_quat_to_rotmat +from gsplat.utils import normalized_quat_to_rotmat, xyz_to_polar @torch.no_grad() @@ -142,6 +142,11 @@ def split( rest = torch.where(~mask)[0] scales = torch.exp(params["scales"][sel]) + means = params["means"][sel] + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"][sel]).unsqueeze(1) + scales = scales * w_inv + means = means * w_inv quats = F.normalize(params["quats"][sel], dim=-1) rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] samples = torch.einsum( @@ -151,12 +156,21 @@ def split( torch.randn(2, len(scales), 3, device=device), ) # [2, N, 3] + means = (means + samples).reshape(-1, 3) + if "w" in params: + w, _ = xyz_to_polar(means) # [2N] + means = means * w.unsqueeze(1) # [2N, 3] + def param_fn(name: str, p: Tensor) -> Tensor: repeats = [2] + [1] * (p.dim() - 1) if name == "means": - p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] + p_split = means # [2N, 3] + elif name == "w": + p_split = torch.log(w) elif name == "scales": - p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] + p_split = torch.log(torch.exp(params["scales"][sel]) / 
1.6).repeat( + 2, 1 + ) # [2N, 3] elif name == "opacities" and revised_opacity: new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel])) p_split = torch.logit(new_opacities).repeat(repeats) # [2N] @@ -269,9 +283,14 @@ def relocate( probs = opacities[alive_indices].flatten() # ensure its shape is [N,] sampled_idxs = _multinomial_sample(probs, n, replacement=True) sampled_idxs = alive_indices[sampled_idxs] + + scales = torch.exp(params["scales"])[sampled_idxs] + if "w" in params: + w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1) + scales = scales / w new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + scales=scales, ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -281,7 +300,10 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + if "w" in params: + p[sampled_idxs] = torch.log(new_scales * w) + else: + p[sampled_idxs] = torch.log(new_scales) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p, requires_grad=p.requires_grad) @@ -311,9 +333,14 @@ def sample_add( eps = torch.finfo(torch.float32).eps probs = opacities.flatten() sampled_idxs = _multinomial_sample(probs, n, replacement=True) + + scales = torch.exp(params["scales"])[sampled_idxs] + if "w" in params: + w = torch.exp(params["w"][sampled_idxs]).unsqueeze(1) + scales = scales / w new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + scales=scales, ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -323,7 +350,10 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + if "w" in params: + p[sampled_idxs] = 
torch.log(new_scales * w) + else: + p[sampled_idxs] = torch.log(new_scales) p_new = torch.cat([p, p[sampled_idxs]]) return torch.nn.Parameter(p_new, requires_grad=p.requires_grad) @@ -348,7 +378,13 @@ def inject_noise_to_position( scaler: float, ): opacities = torch.sigmoid(params["opacities"].flatten()) + means = params["means"] scales = torch.exp(params["scales"]) + if "w" in params: + w_inv = 1.0 / torch.exp(params["w"]).unsqueeze(1) + means = means * w_inv + scales = scales * w_inv + covars, _ = quat_scale_to_covar_preci( params["quats"], scales, @@ -366,4 +402,11 @@ def op_sigmoid(x, k=100, x0=0.995): * scaler ) noise = torch.einsum("bij,bj->bi", covars, noise) - params["means"].add_(noise) + means = means + noise + + if "w" in params: + w, _ = xyz_to_polar(means) + means = means * w.unsqueeze(1) + params["w"].data = torch.log(w.clamp_min(1e-8)) + + params["means"].data = means diff --git a/gsplat/utils.py b/gsplat/utils.py index e56692958..50e393246 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -133,6 +133,12 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: return mat.reshape(quat.shape[:-1] + (3, 3)) +def xyz_to_polar(means): + x, y, z = means[:, 0], means[:, 1], means[:, 2] + r = torch.sqrt(x**2 + y**2 + z**2).clamp_min(1e-8) + return 1 / r, r + + def log_transform(x): return torch.sign(x) * torch.log1p(torch.abs(x))