Skip to content

Commit d95b929

Browse files
committed
delayed start instead of warmup
1 parent 68289ae commit d95b929

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

examples/blur_opt.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ class BlurOptModule(nn.Module):
1111

1212
def __init__(self, n: int, embed_dim: int = 4):
1313
super().__init__()
14-
self.num_warmup_steps = 2000
15-
1614
self.embeds = torch.nn.Embedding(n, embed_dim)
1715
self.means_encoder = get_encoder(3, 3)
1816
self.depths_encoder = get_encoder(3, 1)
@@ -73,19 +71,15 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
7371
blur_mask = torch.sigmoid(mlp_out)
7472
return blur_mask
7573

76-
def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
74+
def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2):
7775
"""Loss function for regularizing the blur mask by controlling its mean.
7876
7977
The loss function diverges to +infinity at 0 and 1. This prevents the mask
80-
from collapsing all 0s or 1s. It is also biased towards 0 to encourage
81-
sparsity. During warmup, the bias is even higher to start with a sparse mask."""
78+
from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity.
79+
"""
8280
x = blur_mask.mean()
83-
if step <= self.num_warmup_steps:
84-
a = 3
85-
b = 0.1
86-
else:
87-
a = 1
88-
b = 0.1
81+
a = 2.0
82+
b = 0.1
8983
maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
9084
return maskloss
9185

examples/simple_trainer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ class Config:
153153
# Learning rate for blur optimization
154154
blur_opt_lr: float = 1e-3
155155
# Regularization for blur mask
156-
blur_mask_reg: float = 0.002
156+
blur_mask_reg: float = 0.001
157+
# Blur start iteration
158+
blur_start_iter: int = 2_000
157159

158160
# Enable bilateral grid. (experimental)
159161
use_bilateral_grid: bool = False
@@ -651,7 +653,7 @@ def train(self):
651653
if cfg.random_bkgd:
652654
bkgd = torch.rand(1, 3, device=device)
653655
colors = colors + bkgd * (1.0 - alphas)
654-
if cfg.blur_opt:
656+
if cfg.blur_opt and step >= cfg.blur_start_iter:
655657
blur_mask = self.blur_module.predict_mask(image_ids, depths)
656658
renders_blur, _, _ = self.rasterize_splats(
657659
camtoworlds=camtoworlds,
@@ -704,8 +706,8 @@ def train(self):
704706
if cfg.use_bilateral_grid:
705707
tvloss = 10 * total_variation_loss(self.bil_grids.grids)
706708
loss += tvloss
707-
if cfg.blur_opt:
708-
loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step)
709+
if cfg.blur_opt and step >= cfg.blur_start_iter:
710+
loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask)
709711

710712
# regularizations
711713
if cfg.opacity_reg > 0.0:
@@ -865,6 +867,8 @@ def train(self):
865867
self.eval(step, stage="train")
866868
self.eval(step, stage="val")
867869
self.render_traj(step)
870+
if step % 1000 == 0:
871+
self.eval(step, stage="vis")
868872

869873
# run compression
870874
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
@@ -890,14 +894,17 @@ def eval(self, step: int, stage: str = "val"):
890894
world_rank = self.world_rank
891895
world_size = self.world_size
892896

893-
dataset = self.trainset if stage == "train" else self.valset
897+
dataset = self.valset if stage == "val" else self.trainset
894898
dataloader = torch.utils.data.DataLoader(
895899
dataset, batch_size=1, shuffle=False, num_workers=1
896900
)
897901

898902
ellipse_time = 0
899903
metrics = defaultdict(list)
900904
for i, data in enumerate(dataloader):
905+
if stage == "vis":
906+
if i % 5 != 0:
907+
continue
901908
camtoworlds = data["camtoworld"].to(device)
902909
Ks = data["K"].to(device)
903910
pixels = data["image"].to(device) / 255.0
@@ -929,7 +936,7 @@ def eval(self, step: int, stage: str = "val"):
929936

930937
colors = torch.clamp(colors, 0.0, 1.0)
931938
canvas_list = [pixels, colors]
932-
if self.cfg.blur_opt and stage == "train":
939+
if self.cfg.blur_opt and stage != "val":
933940
blur_mask = self.blur_module.predict_mask(image_ids, depths)
934941
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
935942
renders_blur, _, _ = self.rasterize_splats(

0 commit comments

Comments
 (0)