Skip to content

Commit ac0cef5

Browse files
committed
docstring
1 parent cb826c6 commit ac0cef5

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

examples/blur_opt.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from dataclasses import dataclass
21
import torch
32
import torch.nn as nn
43
from torch import Tensor
@@ -7,14 +6,13 @@
76
from gsplat.utils import log_transform
87

98

10-
@dataclass
119
class BlurOptModule(nn.Module):
1210
"""Blur optimization module."""
1311

14-
num_warmup_steps: int = 2000
15-
1612
def __init__(self, n: int, embed_dim: int = 4):
1713
super().__init__()
14+
self.num_warmup_steps = 2000
15+
1816
self.embeds = torch.nn.Embedding(n, embed_dim)
1917
self.means_encoder = get_encoder(3, 3)
2018
self.depths_encoder = get_encoder(3, 1)
@@ -76,7 +74,12 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
7674
return blur_mask
7775

7876
def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
79-
"""Mask mean loss."""
77+
"""Loss function for regularizing the blur mask by controlling its mean.
78+
79+
The loss function is designed to diverge to +infinity at 0 and 1. This
80+
prevents the mask from collapsing to predicting all 0s or 1s. It is also
81+
bias towards 0 to encourage sparsity. During warmup, we set this bias very
82+
high to start with a sparse and not collapsed blur mask."""
8083
x = blur_mask.mean()
8184
if step <= self.num_warmup_steps:
8285
a = 20

0 commit comments

Comments
 (0)