@@ -42,9 +42,8 @@ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
4242class ModelConfig :
4343
4444 n_cameras : int = - 1
45- pose_opt_type : Literal ["sfm" , "mlp" , "7dmlp" ] = "sfm"
45+ pose_opt_type : Literal ["sfm" , "mlp" ] = "sfm"
4646 cam_scale : float = 1.0
47- scale : float = 1e-3 # Used for 7dmlp
4847 mlp_width : int = 64
4948 mlp_depth : int = 2
5049
@@ -58,7 +57,6 @@ class OptimizationConfig:
5857 shceduler_type : Literal ["step" , "cosine" , "none" ] = "none"
5958 eps : float = 1e-15
6059 max_steps : int = 30_000
61- opt_test : bool = False # TODO: remove it
6260
6361class CameraOptModule (nn .Module ):
6462 """Camera pose optimization module."""
@@ -166,80 +164,7 @@ def forward(self, camtoworlds: torch.Tensor, embed_ids: torch.Tensor) -> torch.T
166164 transform [..., :3 , 3 ] = dx * self .cam_scale
167165
168166 return torch .matmul (camtoworlds , transform )
169-
170- class CameraOptModule7dMLP (torch .nn .Module ):
171- """Camera pose optimization module using MLP."""
172-
173- def __init__ (self , n : int , mlp_width : int = 256 , mlp_depth : int = 2 , scale : float = 1e-6 ):
174- super ().__init__ ()
175- # Identity rotation in 6D representation
176- self .register_buffer ("identity" , torch .tensor ([1.0 , 0.0 , 0.0 , 0.0 , 1.0 , 0.0 ]))
177-
178- # Initial embeddings for each camera
179- self .num_cams = n
180-
181- # MLP layers
182- activation = torch .nn .ELU (inplace = True )
183- layers = []
184- layers .append (torch .nn .Linear (7 , mlp_width ))
185- layers .append (activation )
186- for _ in range (mlp_depth - 1 ):
187- layers .append (torch .nn .Linear (mlp_width , mlp_width ))
188- layers .append (activation )
189- # Output layer produces 9D adjustments (3D position + 6D rotation)
190- layers .append (torch .nn .Linear (mlp_width , 6 ))
191- self .mlp = torch .nn .Sequential (* layers )
192-
193- self .scale = scale
194-
195- def zero_init (self ):
196- # torch.nn.init.zeros_(self.embeds.weight)
197- #torch.nn.init.normal_(self.embeds.weight)
198- # Also initialize the last layer of MLP with small weights
199- # torch.nn.init.zeros_(self.mlp[-1].weight)
200- # torch.nn.init.zeros_(self.mlp[-1].bias)
201- pass
202-
203- def random_init (self , std : float ):
204- # torch.nn.init.normal_(self.embeds.weight, std=std)
205- # Initialize the last layer of MLP with small weights
206- torch .nn .init .normal_ (self .mlp [- 1 ].weight , std = std )
207- torch .nn .init .normal_ (self .mlp [- 1 ].bias , std = std )
208-
209- def forward (self , camtoworlds : torch .Tensor , embed_ids : torch .Tensor ) -> torch .Tensor :
210- """Adjust camera pose based on MLP outputs with SGLD noise.
211-
212- Args:
213- camtoworlds: (..., 4, 4)
214- embed_ids: (...,)
215-
216- Returns:
217- updated camtoworlds: (..., 4, 4)
218- """
219- assert camtoworlds .shape [:- 2 ] == embed_ids .shape
220- if camtoworlds .ndim == 2 :
221- camtoworlds = camtoworlds .unsqueeze (0 )
222- if embed_ids .ndim == 0 :
223- embed_ids = embed_ids .unsqueeze (0 )
224- batch_shape = camtoworlds .shape [:- 2 ]
225-
226- # Get embeddings and process through MLP with noise
227- r_init = rotation_matrix_to_axis_angle (camtoworlds [..., :3 , :3 ])
228- t_init = camtoworlds [..., :3 , 3 ]
229-
230- mlp_input = torch .cat ((embed_ids [..., None ], r_init , t_init ), dim = - 1 ) # (..., 7)
231-
232- out = self .mlp (mlp_input ) * self .scale
233-
234- r = out [..., :3 ] + r_init
235- t = out [..., 3 :] + t_init
236- R = axis_angle_to_rotation_matrix (r )
237-
238- camtoworlds_corrected = torch .eye (4 , device = camtoworlds .device ).repeat ((* batch_shape , 1 , 1 ))
239- camtoworlds_corrected [..., :3 , :3 ] = R
240- camtoworlds_corrected [..., :3 , 3 ] = t
241-
242- return camtoworlds_corrected .squeeze ()
167+
243168
244169@dataclass
245170class GSplatCameraOptRenderer (GSplatV1Renderer ):
@@ -281,13 +206,6 @@ def _setup_model(self, device=None):
281206 mlp_depth = self .config .model .mlp_depth ,
282207 cam_scale = self .config .model .cam_scale
283208 )
284- elif self .config .model .pose_opt_type == "7dmlp" :
285- self .model = CameraOptModule7dMLP (
286- n = self .config .model .n_cameras ,
287- mlp_width = self .config .model .mlp_width ,
288- mlp_depth = self .config .model .mlp_depth ,
289- scale = self .config .model .scale
290- )
291209 else :
292210 self .model = CameraOptModule (self .config .model .n_cameras )
293211
0 commit comments