@@ -46,7 +46,7 @@ def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Te
4646
4747@torch .no_grad ()
4848def _update_param_with_optimizer (
49- param_fn : Callable [[str , Tensor , bool ], Tensor ],
49+ param_fn : Callable [[str , Tensor ], Tensor ],
5050 optimizer_fn : Callable [[str , Tensor ], Tensor ],
5151 params : Union [Dict [str , torch .nn .Parameter ], torch .nn .ParameterDict ],
5252 optimizers : Dict [str , torch .optim .Optimizer ],
@@ -69,10 +69,13 @@ def _update_param_with_optimizer(
6969
7070 for name in names :
7171 param = params [name ]
72- new_param = param_fn (name , param , param . requires_grad )
72+ new_param = param_fn (name , param )
7373 params [name ] = new_param
7474 if name not in optimizers :
75- assert not param .requires_grad
75+ assert not param .requires_grad , (
76+ f"Optimizer for { name } is not found, but the parameter is trainable."
77+ f"Got requires_grad={ param .requires_grad } "
78+ )
7679 continue
7780 optimizer = optimizers [name ]
7881 for i in range (len (optimizer .param_groups )):
@@ -103,8 +106,8 @@ def duplicate(
103106 device = mask .device
104107 sel = torch .where (mask )[0 ]
105108
106- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
107- return torch .nn .Parameter (torch .cat ([p , p [sel ]]), requires_grad = requires_grad )
109+ def param_fn (name : str , p : Tensor ) -> Tensor :
110+ return torch .nn .Parameter (torch .cat ([p , p [sel ]]), requires_grad = p . requires_grad )
108111
109112 def optimizer_fn (key : str , v : Tensor ) -> Tensor :
110113 return torch .cat ([v , torch .zeros ((len (sel ), * v .shape [1 :]), device = device )])
@@ -148,7 +151,7 @@ def split(
148151 torch .randn (2 , len (scales ), 3 , device = device ),
149152 ) # [2, N, 3]
150153
151- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
154+ def param_fn (name : str , p : Tensor ) -> Tensor :
152155 repeats = [2 ] + [1 ] * (p .dim () - 1 )
153156 if name == "means" :
154157 p_split = (p [sel ] + samples ).reshape (- 1 , 3 ) # [2N, 3]
@@ -160,7 +163,7 @@ def param_fn(name: str, p: Tensor, requires_grad: bool) -> Tensor:
160163 else :
161164 p_split = p [sel ].repeat (repeats )
162165 p_new = torch .cat ([p [rest ], p_split ])
163- p_new = torch .nn .Parameter (p_new , requires_grad = requires_grad )
166+ p_new = torch .nn .Parameter (p_new , requires_grad = p . requires_grad )
164167 return p_new
165168
166169 def optimizer_fn (key : str , v : Tensor ) -> Tensor :
@@ -193,8 +196,8 @@ def remove(
193196 """
194197 sel = torch .where (~ mask )[0 ]
195198
196- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
197- return torch .nn .Parameter (p [sel ], requires_grad = requires_grad )
199+ def param_fn (name : str , p : Tensor ) -> Tensor :
200+ return torch .nn .Parameter (p [sel ], requires_grad = p . requires_grad )
198201
199202 def optimizer_fn (key : str , v : Tensor ) -> Tensor :
200203 return v [sel ]
@@ -222,10 +225,10 @@ def reset_opa(
222225 value: The value to reset the opacities
223226 """
224227
225- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
228+ def param_fn (name : str , p : Tensor ) -> Tensor :
226229 if name == "opacities" :
227230 opacities = torch .clamp (p , max = torch .logit (torch .tensor (value )).item ())
228- return torch .nn .Parameter (opacities , requires_grad = requires_grad )
231+ return torch .nn .Parameter (opacities , requires_grad = p . requires_grad )
229232 else :
230233 raise ValueError (f"Unexpected parameter name: { name } " )
231234
@@ -274,13 +277,13 @@ def relocate(
274277 )
275278 new_opacities = torch .clamp (new_opacities , max = 1.0 - eps , min = min_opacity )
276279
277- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
280+ def param_fn (name : str , p : Tensor ) -> Tensor :
278281 if name == "opacities" :
279282 p [sampled_idxs ] = torch .logit (new_opacities )
280283 elif name == "scales" :
281284 p [sampled_idxs ] = torch .log (new_scales )
282285 p [dead_indices ] = p [sampled_idxs ]
283- return torch .nn .Parameter (p , requires_grad = requires_grad )
286+ return torch .nn .Parameter (p , requires_grad = p . requires_grad )
284287
285288 def optimizer_fn (key : str , v : Tensor ) -> Tensor :
286289 v [sampled_idxs ] = 0
@@ -316,13 +319,13 @@ def sample_add(
316319 )
317320 new_opacities = torch .clamp (new_opacities , max = 1.0 - eps , min = min_opacity )
318321
319- def param_fn (name : str , p : Tensor , requires_grad : bool ) -> Tensor :
322+ def param_fn (name : str , p : Tensor ) -> Tensor :
320323 if name == "opacities" :
321324 p [sampled_idxs ] = torch .logit (new_opacities )
322325 elif name == "scales" :
323326 p [sampled_idxs ] = torch .log (new_scales )
324- p = torch .cat ([p , p [sampled_idxs ]])
325- return torch .nn .Parameter (p , requires_grad = requires_grad )
327+ p_new = torch .cat ([p , p [sampled_idxs ]])
328+ return torch .nn .Parameter (p_new , requires_grad = p . requires_grad )
326329
327330 def optimizer_fn (key : str , v : Tensor ) -> Tensor :
328331 v_new = torch .zeros ((len (sampled_idxs ), * v .shape [1 :]), device = v .device )
0 commit comments