@@ -153,7 +153,9 @@ class Config:
153153 # Learning rate for blur optimization
154154 blur_opt_lr : float = 1e-3
155155 # Regularization for blur mask
156- blur_mask_reg : float = 0.002
156+ blur_mask_reg : float = 0.001
157+ # Blur start iteration
158+ blur_start_iter : int = 2_000
157159
158160 # Enable bilateral grid. (experimental)
159161 use_bilateral_grid : bool = False
@@ -651,7 +653,7 @@ def train(self):
651653 if cfg .random_bkgd :
652654 bkgd = torch .rand (1 , 3 , device = device )
653655 colors = colors + bkgd * (1.0 - alphas )
654- if cfg .blur_opt :
656+ if cfg .blur_opt and step >= cfg . blur_start_iter :
655657 blur_mask = self .blur_module .predict_mask (image_ids , depths )
656658 renders_blur , _ , _ = self .rasterize_splats (
657659 camtoworlds = camtoworlds ,
@@ -704,8 +706,8 @@ def train(self):
704706 if cfg .use_bilateral_grid :
705707 tvloss = 10 * total_variation_loss (self .bil_grids .grids )
706708 loss += tvloss
707- if cfg .blur_opt :
708- loss += cfg .blur_mask_reg * self .blur_module .mask_loss (blur_mask , step )
709+ if cfg .blur_opt and step >= cfg . blur_start_iter :
710+ loss += cfg .blur_mask_reg * self .blur_module .mask_loss (blur_mask )
709711
710712 # regularizations
711713 if cfg .opacity_reg > 0.0 :
@@ -865,6 +867,8 @@ def train(self):
865867 self .eval (step , stage = "train" )
866868 self .eval (step , stage = "val" )
867869 self .render_traj (step )
870+ if step % 1000 == 0 :
871+ self .eval (step , stage = "vis" )
868872
869873 # run compression
870874 if cfg .compression is not None and step in [i - 1 for i in cfg .eval_steps ]:
@@ -890,14 +894,17 @@ def eval(self, step: int, stage: str = "val"):
890894 world_rank = self .world_rank
891895 world_size = self .world_size
892896
893- dataset = self .trainset if stage == "train " else self .valset
897+ dataset = self .valset if stage == "val " else self .trainset
894898 dataloader = torch .utils .data .DataLoader (
895899 dataset , batch_size = 1 , shuffle = False , num_workers = 1
896900 )
897901
898902 ellipse_time = 0
899903 metrics = defaultdict (list )
900904 for i , data in enumerate (dataloader ):
905+ if stage == "vis" :
906+ if i % 5 != 0 :
907+ continue
901908 camtoworlds = data ["camtoworld" ].to (device )
902909 Ks = data ["K" ].to (device )
903910 pixels = data ["image" ].to (device ) / 255.0
@@ -929,7 +936,7 @@ def eval(self, step: int, stage: str = "val"):
929936
930937 colors = torch .clamp (colors , 0.0 , 1.0 )
931938 canvas_list = [pixels , colors ]
932- if self .cfg .blur_opt and stage == "train " :
939+ if self .cfg .blur_opt and stage != "val " :
933940 blur_mask = self .blur_module .predict_mask (image_ids , depths )
934941 canvas_list .append (blur_mask .repeat (1 , 1 , 1 , 3 ))
935942 renders_blur , _ , _ = self .rasterize_splats (
0 commit comments