@@ -252,6 +252,55 @@ def init_from_colmap(self, root_path: str, observer_pts):
252252 file_pts , observer_pts , file_rgb , use_observer_pts = self .conf .initialization .use_observation_points
253253 )
254254
255+ def init_from_accumulated_point_cloud (self , pc_path : str , observer_pts ):
256+ """
257+ Initialize gaussians from an accumulated point cloud PLY file.
258+ Similar to init_from_colmap but loads from a given PLY file instead of sparse/0/points3D.txt
259+
260+ Args:
261+ pc_path: Path to the PLY point cloud file
262+ observer_pts: Observer points tensor for scale initialization
263+ """
264+ logger .info (f"Loading accumulated point cloud from { pc_path } ..." )
265+
266+ # Read PLY file
267+ plydata = PlyData .read (pc_path )
268+ vertices = plydata ['vertex' ]
269+
270+ # Extract XYZ coordinates
271+ xyz = np .stack ([
272+ vertices ['x' ],
273+ vertices ['y' ],
274+ vertices ['z' ]
275+ ], axis = 1 ).astype (np .float32 )
276+
277+ # Extract RGB colors (check if they exist)
278+ if 'red' in vertices and 'green' in vertices and 'blue' in vertices :
279+ rgb = np .stack ([
280+ vertices ['red' ],
281+ vertices ['green' ],
282+ vertices ['blue' ]
283+ ], axis = 1 ).astype (np .uint8 )
284+ else :
285+ # If no colors, initialize with random colors
286+ logger .warning ("No RGB data found in point cloud, using random colors" )
287+ rgb = np .random .randint (0 , 256 , size = (len (vertices ), 3 ), dtype = np .uint8 )
288+
289+ # Convert to torch tensors
290+ file_pts = torch .tensor (xyz , dtype = torch .float32 , device = self .device )
291+ file_rgb = torch .tensor (rgb , dtype = torch .uint8 , device = self .device )
292+
293+ logger .info (f"Loaded { len (file_pts )} points from accumulated point cloud" )
294+
295+ # Initialize using the same method as COLMAP
296+ assert file_rgb .dtype == torch .uint8 , "Expecting RGB values to be in [0, 255] range"
297+ self .default_initialize_from_points (
298+ file_pts ,
299+ observer_pts ,
300+ file_rgb ,
301+ use_observer_pts = self .conf .initialization .use_observation_points
302+ )
303+
255304 def init_from_pretrained_point_cloud (self , pc_path : str , set_optimizable_parameters : bool = True ):
256305 data = PlyData .read (pc_path )
257306 num_gaussians = len (data ["vertex" ])
0 commit comments