
Commit fede42c

comments

1 parent 79da6bf commit fede42c

1 file changed: +19 -16 lines changed

gsplat/strategy/ops.py

Lines changed: 19 additions & 16 deletions
@@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Te
 
 @torch.no_grad()
 def _update_param_with_optimizer(
-    param_fn: Callable[[str, Tensor, bool], Tensor],
+    param_fn: Callable[[str, Tensor], Tensor],
     optimizer_fn: Callable[[str, Tensor], Tensor],
     params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
     optimizers: Dict[str, torch.optim.Optimizer],
@@ -69,10 +69,13 @@ def _update_param_with_optimizer
 
     for name in names:
         param = params[name]
-        new_param = param_fn(name, param, param.requires_grad)
+        new_param = param_fn(name, param)
         params[name] = new_param
         if name not in optimizers:
-            assert not param.requires_grad
+            assert not param.requires_grad, (
+                f"Optimizer for {name} is not found, but the parameter is trainable."
+                f"Got requires_grad={param.requires_grad}"
+            )
             continue
         optimizer = optimizers[name]
         for i in range(len(optimizer.param_groups)):
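
For readers skimming the diff, here is a minimal sketch of the callback contract after this change: `param_fn` now takes only the parameter name and tensor, and derives `requires_grad` from the tensor itself. The toy `means` parameter, the selection mask, and the call below are illustrative only (assuming a working gsplat install) and are not part of the commit.

import torch

from gsplat.strategy.ops import _update_param_with_optimizer

# Toy state (illustrative): one trainable parameter and its optimizer.
params = torch.nn.ParameterDict({"means": torch.nn.Parameter(torch.randn(100, 3))})
optimizers = {"means": torch.optim.Adam([params["means"]], lr=1e-3)}

sel = torch.arange(50)  # e.g. keep only the first half of the Gaussians


def param_fn(name: str, p: torch.Tensor) -> torch.Tensor:
    # New two-argument signature: requires_grad comes from the parameter itself.
    return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)


def optimizer_fn(key: str, v: torch.Tensor) -> torch.Tensor:
    # Slice per-parameter optimizer state (e.g. Adam's exp_avg) to match the new shape.
    return v[sel]


_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)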
@@ -103,8 +106,8 @@ def duplicate(
     device = mask.device
     sel = torch.where(mask)[0]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
-        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=requires_grad)
+    def param_fn(name: str, p: Tensor) -> Tensor:
+        return torch.nn.Parameter(torch.cat([p, p[sel]]), requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)])
@@ -148,7 +151,7 @@ def split(
         torch.randn(2, len(scales), 3, device=device),
     )  # [2, N, 3]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         repeats = [2] + [1] * (p.dim() - 1)
         if name == "means":
             p_split = (p[sel] + samples).reshape(-1, 3)  # [2N, 3]
@@ -160,7 +163,7 @@ def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
         else:
             p_split = p[sel].repeat(repeats)
         p_new = torch.cat([p[rest], p_split])
-        p_new = torch.nn.Parameter(p_new, requires_grad=requires_grad)
+        p_new = torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
         return p_new
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
@@ -193,8 +196,8 @@ def remove(
     """
     sel = torch.where(~mask)[0]
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
-        return torch.nn.Parameter(p[sel], requires_grad=requires_grad)
+    def param_fn(name: str, p: Tensor) -> Tensor:
+        return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -222,10 +225,10 @@ def reset_opa(
         value: The value to reset the opacities
     """
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item())
-            return torch.nn.Parameter(opacities, requires_grad=requires_grad)
+            return torch.nn.Parameter(opacities, requires_grad=p.requires_grad)
         else:
             raise ValueError(f"Unexpected parameter name: {name}")
 
@@ -274,13 +277,13 @@ def relocate(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
         p[dead_indices] = p[sampled_idxs]
-        return torch.nn.Parameter(p, requires_grad=requires_grad)
+        return torch.nn.Parameter(p, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
@@ -316,13 +319,13 @@ def sample_add(
     )
     new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)
 
-    def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
+    def param_fn(name: str, p: Tensor) -> Tensor:
         if name == "opacities":
             p[sampled_idxs] = torch.logit(new_opacities)
         elif name == "scales":
             p[sampled_idxs] = torch.log(new_scales)
-        p = torch.cat([p, p[sampled_idxs]])
-        return torch.nn.Parameter(p, requires_grad=requires_grad)
+        p_new = torch.cat([p, p[sampled_idxs]])
+        return torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
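
One observation on the `sample_add` hunk (my reading, not stated in the commit): `param_fn` runs inside `_update_param_with_optimizer`, which is decorated with `@torch.no_grad()`, so rebinding `p` to the `torch.cat` result before reading `p.requires_grad` would pick up the flag of the freshly created tensor (False under no_grad) rather than the original parameter's. Keeping `p` bound to the original and using `p_new` preserves the flag, as this standalone sketch shows.

import torch

p = torch.nn.Parameter(torch.randn(4, 3))  # requires_grad=True

with torch.no_grad():
    # Rebinding the name would lose the original flag: tensors created
    # under no_grad have requires_grad=False.
    rebound = torch.cat([p, p[:2]])
    print(rebound.requires_grad)  # False

    # Keeping `p` bound to the original parameter preserves the flag.
    p_new = torch.cat([p, p[:2]])
    new_param = torch.nn.Parameter(p_new, requires_grad=p.requires_grad)
    print(new_param.requires_grad)  # True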
