23
23
from typing import Dict , List , Literal , Optional , Tuple , Type , Union
24
24
25
25
import torch
26
- from gsplat .strategy import DefaultStrategy
26
+ from gsplat .strategy import DefaultStrategy , MCMCStrategy
27
27
28
28
try :
29
29
from gsplat .rendering import rasterization
@@ -156,6 +156,16 @@ class SplatfactoModelConfig(ModelConfig):
156
156
"""Shape of the bilateral grid (X, Y, W)"""
157
157
color_corrected_metrics : bool = False
158
158
"""If True, apply color correction to the rendered images before computing the metrics."""
159
+ strategy : Literal ["default" , "mcmc" ] = "default"
160
+ """Densification strategy to use. The default strategy is used when none is specified; mcmc is the other supported option."""
161
+ max_gs_num : int = 1_000_000
162
+ """Maximum number of Gaussians (GSs). Defaults to 1_000_000."""
163
+ noise_lr : float = 5e5
164
+ """MCMC sampling noise learning rate. Defaults to 5e5."""
165
+ mcmc_opacity_reg : float = 0.01
166
+ """Regularization weight for opacity in the MCMC strategy. Only applied when using the MCMC strategy and the value is greater than 0."""
167
+ mcmc_scale_reg : float = 0.01
168
+ """Regularization weight for scale in the MCMC strategy. Only applied when using the MCMC strategy and the value is greater than 0."""
159
169
160
170
161
171
class SplatfactoModel (Model ):
@@ -249,24 +259,40 @@ def populate_modules(self):
249
259
)
250
260
251
261
# Strategy for GS densification
252
- self .strategy = DefaultStrategy (
253
- prune_opa = self .config .cull_alpha_thresh ,
254
- grow_grad2d = self .config .densify_grad_thresh ,
255
- grow_scale3d = self .config .densify_size_thresh ,
256
- grow_scale2d = self .config .split_screen_size ,
257
- prune_scale3d = self .config .cull_scale_thresh ,
258
- prune_scale2d = self .config .cull_screen_size ,
259
- refine_scale2d_stop_iter = self .config .stop_screen_size_at ,
260
- refine_start_iter = self .config .warmup_length ,
261
- refine_stop_iter = self .config .stop_split_at ,
262
- reset_every = self .config .reset_alpha_every * self .config .refine_every ,
263
- refine_every = self .config .refine_every ,
264
- pause_refine_after_reset = self .num_train_data + self .config .refine_every ,
265
- absgrad = self .config .use_absgrad ,
266
- revised_opacity = False ,
267
- verbose = True ,
268
- )
269
- self .strategy_state = self .strategy .initialize_state (scene_scale = 1.0 )
262
+ if self .config .strategy == "default" :
263
+ # Strategy for GS densification
264
+ self .strategy = DefaultStrategy (
265
+ prune_opa = self .config .cull_alpha_thresh ,
266
+ grow_grad2d = self .config .densify_grad_thresh ,
267
+ grow_scale3d = self .config .densify_size_thresh ,
268
+ grow_scale2d = self .config .split_screen_size ,
269
+ prune_scale3d = self .config .cull_scale_thresh ,
270
+ prune_scale2d = self .config .cull_screen_size ,
271
+ refine_scale2d_stop_iter = self .config .stop_screen_size_at ,
272
+ refine_start_iter = self .config .warmup_length ,
273
+ refine_stop_iter = self .config .stop_split_at ,
274
+ reset_every = self .config .reset_alpha_every * self .config .refine_every ,
275
+ refine_every = self .config .refine_every ,
276
+ pause_refine_after_reset = self .num_train_data + self .config .refine_every ,
277
+ absgrad = self .config .use_absgrad ,
278
+ revised_opacity = False ,
279
+ verbose = True ,
280
+ )
281
+ self .strategy_state = self .strategy .initialize_state (scene_scale = 1.0 )
282
+ elif self .config .strategy == "mcmc" :
283
+ self .strategy = MCMCStrategy (
284
+ cap_max = self .config .max_gs_num ,
285
+ noise_lr = self .config .noise_lr ,
286
+ refine_start_iter = self .config .warmup_length ,
287
+ refine_stop_iter = self .config .stop_split_at ,
288
+ refine_every = self .config .refine_every ,
289
+ min_opacity = self .config .cull_alpha_thresh ,
290
+ verbose = False ,
291
+ )
292
+ self .strategy_state = self .strategy .initialize_state ()
293
+ else :
294
+ raise ValueError (f"""Splatfacto does not support strategy { self .config .strategy }
295
+ Currently, the supported strategies include default and mcmc.""" )
270
296
271
297
@property
272
298
def colors (self ):
@@ -338,14 +364,26 @@ def set_background(self, background_color: torch.Tensor):
338
364
339
365
def step_post_backward (self , step ):
340
366
assert step == self .step
341
- self .strategy .step_post_backward (
342
- params = self .gauss_params ,
343
- optimizers = self .optimizers ,
344
- state = self .strategy_state ,
345
- step = self .step ,
346
- info = self .info ,
347
- packed = False ,
348
- )
367
+ if isinstance (self .strategy , DefaultStrategy ):
368
+ self .strategy .step_post_backward (
369
+ params = self .gauss_params ,
370
+ optimizers = self .optimizers ,
371
+ state = self .strategy_state ,
372
+ step = self .step ,
373
+ info = self .info ,
374
+ packed = False ,
375
+ )
376
+ elif isinstance (self .strategy , MCMCStrategy ):
377
+ self .strategy .step_post_backward (
378
+ params = self .gauss_params ,
379
+ optimizers = self .optimizers ,
380
+ state = self .strategy_state ,
381
+ step = step ,
382
+ info = self .info ,
383
+ lr = self .schedulers ["means" ].get_last_lr ()[0 ], # the learning rate for the "means" attribute of the GS
384
+ )
385
+ else :
386
+ raise ValueError (f"Unknown strategy { self .strategy } " )
349
387
350
388
def get_training_callbacks (
351
389
self , training_callback_attributes : TrainingCallbackAttributes
@@ -369,6 +407,7 @@ def get_training_callbacks(
369
407
def step_cb (self , optimizers : Optimizers , step ):
370
408
self .step = step
371
409
self .optimizers = optimizers .optimizers
410
+ self .schedulers = optimizers .schedulers
372
411
373
412
def get_gaussian_param_groups (self ) -> Dict [str , List [Parameter ]]:
374
413
# Here we explicitly use the means, scales as parameters so that the user can override this function and
@@ -529,7 +568,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
529
568
render_mode = render_mode ,
530
569
sh_degree = sh_degree_to_use ,
531
570
sparse_grad = False ,
532
- absgrad = self .strategy .absgrad ,
571
+ absgrad = self .strategy .absgrad if isinstance ( self . strategy , DefaultStrategy ) else False ,
533
572
rasterize_mode = self .config .rasterize_mode ,
534
573
# set some threshold to disregard small gaussians for faster rendering.
535
574
# radius_clip=3.0,
@@ -651,6 +690,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
651
690
"scale_reg" : scale_reg ,
652
691
}
653
692
693
+ # Losses for mcmc
694
+ if self .config .strategy == "mcmc" :
695
+ if self .config .mcmc_opacity_reg > 0.0 :
696
+ mcmc_opacity_reg = (
697
+ self .config .mcmc_opacity_reg * torch .abs (torch .sigmoid (self .gauss_params ["opacities" ])).mean ()
698
+ )
699
+ loss_dict ["mcmc_opacity_reg" ] = mcmc_opacity_reg
700
+ if self .config .mcmc_scale_reg > 0.0 :
701
+ mcmc_scale_reg = self .config .mcmc_scale_reg * torch .abs (torch .exp (self .gauss_params ["scales" ])).mean ()
702
+ loss_dict ["mcmc_scale_reg" ] = mcmc_scale_reg
703
+
654
704
if self .training :
655
705
# Add loss from camera optimizer
656
706
self .camera_optimizer .get_loss_dict (loss_dict )
0 commit comments