Skip to content

Commit c7bd953

Browse files
KevinXu02KevinXu02kerrj
authored
Finalize MCMC strategy and some tiny fix (#3548)
--------- Co-authored-by: KevinXu02 <[email protected]> Co-authored-by: Justin Kerr <[email protected]>
1 parent 0ced5ce commit c7bd953

File tree

3 files changed

+141
-29
lines changed

3 files changed

+141
-29
lines changed

nerfstudio/configs/method_configs.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,68 @@
696696
),
697697
},
698698
"bilateral_grid": {
699-
"optimizer": AdamOptimizerConfig(lr=5e-3, eps=1e-15),
699+
"optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
700+
"scheduler": ExponentialDecaySchedulerConfig(
701+
lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
702+
),
703+
},
704+
},
705+
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
706+
vis="viewer",
707+
)
708+
709+
method_configs["splatfacto-mcmc"] = TrainerConfig(
710+
method_name="splatfacto",
711+
steps_per_eval_image=100,
712+
steps_per_eval_batch=0,
713+
steps_per_save=2000,
714+
steps_per_eval_all_images=1000,
715+
max_num_iterations=30000,
716+
mixed_precision=False,
717+
pipeline=VanillaPipelineConfig(
718+
datamanager=FullImageDatamanagerConfig(
719+
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
720+
cache_images_type="uint8",
721+
),
722+
model=SplatfactoModelConfig(
723+
strategy="mcmc",
724+
cull_alpha_thresh=0.005,
725+
stop_split_at=25000,
726+
),
727+
),
728+
optimizers={
729+
"means": {
730+
"optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
731+
"scheduler": ExponentialDecaySchedulerConfig(
732+
lr_final=1.6e-6,
733+
max_steps=30000,
734+
),
735+
},
736+
"features_dc": {
737+
"optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15),
738+
"scheduler": None,
739+
},
740+
"features_rest": {
741+
"optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15),
742+
"scheduler": None,
743+
},
744+
"opacities": {
745+
"optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
746+
"scheduler": None,
747+
},
748+
"scales": {
749+
"optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
750+
"scheduler": None,
751+
},
752+
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
753+
"camera_opt": {
754+
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
755+
"scheduler": ExponentialDecaySchedulerConfig(
756+
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
757+
),
758+
},
759+
"bilateral_grid": {
760+
"optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
700761
"scheduler": ExponentialDecaySchedulerConfig(
701762
lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
702763
),

nerfstudio/models/splatfacto.py

+78-28
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
2424

2525
import torch
26-
from gsplat.strategy import DefaultStrategy
26+
from gsplat.strategy import DefaultStrategy, MCMCStrategy
2727

2828
try:
2929
from gsplat.rendering import rasterization
@@ -156,6 +156,16 @@ class SplatfactoModelConfig(ModelConfig):
156156
"""Shape of the bilateral grid (X, Y, W)"""
157157
color_corrected_metrics: bool = False
158158
"""If True, apply color correction to the rendered images before computing the metrics."""
159+
strategy: Literal["default", "mcmc"] = "default"
160+
"""The default strategy will be used if strategy is not specified. Other strategies, e.g. mcmc, can be used."""
161+
max_gs_num: int = 1_000_000
162+
"""Maximum number of GSs. Default to 1_000_000."""
163+
noise_lr: float = 5e5
164+
"""MCMC samping noise learning rate. Default to 5e5."""
165+
mcmc_opacity_reg: float = 0.01
166+
"""Regularization term for opacity in MCMC strategy. Only enabled when using MCMC strategy"""
167+
mcmc_scale_reg: float = 0.01
168+
"""Regularization term for scale in MCMC strategy. Only enabled when using MCMC strategy"""
159169

160170

161171
class SplatfactoModel(Model):
@@ -249,24 +259,40 @@ def populate_modules(self):
249259
)
250260

251261
# Strategy for GS densification
252-
self.strategy = DefaultStrategy(
253-
prune_opa=self.config.cull_alpha_thresh,
254-
grow_grad2d=self.config.densify_grad_thresh,
255-
grow_scale3d=self.config.densify_size_thresh,
256-
grow_scale2d=self.config.split_screen_size,
257-
prune_scale3d=self.config.cull_scale_thresh,
258-
prune_scale2d=self.config.cull_screen_size,
259-
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
260-
refine_start_iter=self.config.warmup_length,
261-
refine_stop_iter=self.config.stop_split_at,
262-
reset_every=self.config.reset_alpha_every * self.config.refine_every,
263-
refine_every=self.config.refine_every,
264-
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
265-
absgrad=self.config.use_absgrad,
266-
revised_opacity=False,
267-
verbose=True,
268-
)
269-
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
262+
if self.config.strategy == "default":
263+
# Strategy for GS densification
264+
self.strategy = DefaultStrategy(
265+
prune_opa=self.config.cull_alpha_thresh,
266+
grow_grad2d=self.config.densify_grad_thresh,
267+
grow_scale3d=self.config.densify_size_thresh,
268+
grow_scale2d=self.config.split_screen_size,
269+
prune_scale3d=self.config.cull_scale_thresh,
270+
prune_scale2d=self.config.cull_screen_size,
271+
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
272+
refine_start_iter=self.config.warmup_length,
273+
refine_stop_iter=self.config.stop_split_at,
274+
reset_every=self.config.reset_alpha_every * self.config.refine_every,
275+
refine_every=self.config.refine_every,
276+
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
277+
absgrad=self.config.use_absgrad,
278+
revised_opacity=False,
279+
verbose=True,
280+
)
281+
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
282+
elif self.config.strategy == "mcmc":
283+
self.strategy = MCMCStrategy(
284+
cap_max=self.config.max_gs_num,
285+
noise_lr=self.config.noise_lr,
286+
refine_start_iter=self.config.warmup_length,
287+
refine_stop_iter=self.config.stop_split_at,
288+
refine_every=self.config.refine_every,
289+
min_opacity=self.config.cull_alpha_thresh,
290+
verbose=False,
291+
)
292+
self.strategy_state = self.strategy.initialize_state()
293+
else:
294+
raise ValueError(f"""Splatfacto does not support strategy {self.config.strategy}
295+
Currently, the supported strategies include default and mcmc.""")
270296

271297
@property
272298
def colors(self):
@@ -338,14 +364,26 @@ def set_background(self, background_color: torch.Tensor):
338364

339365
def step_post_backward(self, step):
340366
assert step == self.step
341-
self.strategy.step_post_backward(
342-
params=self.gauss_params,
343-
optimizers=self.optimizers,
344-
state=self.strategy_state,
345-
step=self.step,
346-
info=self.info,
347-
packed=False,
348-
)
367+
if isinstance(self.strategy, DefaultStrategy):
368+
self.strategy.step_post_backward(
369+
params=self.gauss_params,
370+
optimizers=self.optimizers,
371+
state=self.strategy_state,
372+
step=self.step,
373+
info=self.info,
374+
packed=False,
375+
)
376+
elif isinstance(self.strategy, MCMCStrategy):
377+
self.strategy.step_post_backward(
378+
params=self.gauss_params,
379+
optimizers=self.optimizers,
380+
state=self.strategy_state,
381+
step=step,
382+
info=self.info,
383+
lr=self.schedulers["means"].get_last_lr()[0], # the learning rate for the "means" attribute of the GS
384+
)
385+
else:
386+
raise ValueError(f"Unknown strategy {self.strategy}")
349387

350388
def get_training_callbacks(
351389
self, training_callback_attributes: TrainingCallbackAttributes
@@ -369,6 +407,7 @@ def get_training_callbacks(
369407
def step_cb(self, optimizers: Optimizers, step):
370408
self.step = step
371409
self.optimizers = optimizers.optimizers
410+
self.schedulers = optimizers.schedulers
372411

373412
def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]:
374413
# Here we explicitly use the means, scales as parameters so that the user can override this function and
@@ -529,7 +568,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
529568
render_mode=render_mode,
530569
sh_degree=sh_degree_to_use,
531570
sparse_grad=False,
532-
absgrad=self.strategy.absgrad,
571+
absgrad=self.strategy.absgrad if isinstance(self.strategy, DefaultStrategy) else False,
533572
rasterize_mode=self.config.rasterize_mode,
534573
# set some threshold to disregrad small gaussians for faster rendering.
535574
# radius_clip=3.0,
@@ -651,6 +690,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
651690
"scale_reg": scale_reg,
652691
}
653692

693+
# Losses for mcmc
694+
if self.config.strategy == "mcmc":
695+
if self.config.mcmc_opacity_reg > 0.0:
696+
mcmc_opacity_reg = (
697+
self.config.mcmc_opacity_reg * torch.abs(torch.sigmoid(self.gauss_params["opacities"])).mean()
698+
)
699+
loss_dict["mcmc_opacity_reg"] = mcmc_opacity_reg
700+
if self.config.mcmc_scale_reg > 0.0:
701+
mcmc_scale_reg = self.config.mcmc_scale_reg * torch.abs(torch.exp(self.gauss_params["scales"])).mean()
702+
loss_dict["mcmc_scale_reg"] = mcmc_scale_reg
703+
654704
if self.training:
655705
# Add loss from camera optimizer
656706
self.camera_optimizer.get_loss_dict(loss_dict)

tests/test_train.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"neus-facto",
2929
"splatfacto",
3030
"splatfacto-big",
31+
"splatfacto-mcmc",
3132
]
3233

3334

0 commit comments

Comments
 (0)