Skip to content

Commit 8fd003c

Browse files
ApdowJNwilsonCernWq
authored andcommitted
Adapt cusfm format for training
1 parent f530c88 commit 8fd003c

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

configs/apps/cusfm_3dgut_mcmc.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# @package _global_
2+
3+
# order in which configs override each other (/* - denotes a relative search path)
4+
defaults:
5+
- /base_mcmc
6+
- /dataset: colmap
7+
- /initialization: accumulated_point_cloud
8+
- /render: 3dgut
9+
- _self_
10+
11+
# overwrite of default parameters
12+
val_frequency: 999999 # never validate
13+
14+
initialization:
15+
accumulated_point_cloud_path: ??? # Path to accumulated point cloud PLY file to be provided by user at runtime
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
method: accumulated_point_cloud
2+
observation_scale_factor: 0.01
3+
use_observation_points: true
4+
accumulated_point_cloud_path: ??? # Path to accumulated point cloud PLY file
5+

threedgrut/datasets/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,8 @@ def read_colmap_extrinsics_text(path):
479479
with open(path, "r") as fid:
480480
# Skip comment lines and get valid lines
481481
lines = (line.strip() for line in fid)
482-
lines = (line for line in lines if line and not line.startswith("#"))
482+
# lines = (line for line in lines if line and not line.startswith("#"))
483+
lines = (line for line in lines if line == "" or not line.startswith("#"))
483484
# Process lines in pairs (image info + points info)
484485
try:
485486
while True:

threedgrut/model/model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,55 @@ def init_from_colmap(self, root_path: str, observer_pts):
252252
file_pts, observer_pts, file_rgb, use_observer_pts=self.conf.initialization.use_observation_points
253253
)
254254

255+
def init_from_accumulated_point_cloud(self, pc_path: str, observer_pts):
256+
"""
257+
Initialize gaussians from an accumulated point cloud PLY file.
258+
Similar to init_from_colmap but loads from a given PLY file instead of sparse/0/points3D.txt
259+
260+
Args:
261+
pc_path: Path to the PLY point cloud file
262+
observer_pts: Observer points tensor for scale initialization
263+
"""
264+
logger.info(f"Loading accumulated point cloud from {pc_path}...")
265+
266+
# Read PLY file
267+
plydata = PlyData.read(pc_path)
268+
vertices = plydata['vertex']
269+
270+
# Extract XYZ coordinates
271+
xyz = np.stack([
272+
vertices['x'],
273+
vertices['y'],
274+
vertices['z']
275+
], axis=1).astype(np.float32)
276+
277+
# Extract RGB colors (check if they exist)
278+
if 'red' in vertices and 'green' in vertices and 'blue' in vertices:
279+
rgb = np.stack([
280+
vertices['red'],
281+
vertices['green'],
282+
vertices['blue']
283+
], axis=1).astype(np.uint8)
284+
else:
285+
# If no colors, initialize with random colors
286+
logger.warning("No RGB data found in point cloud, using random colors")
287+
rgb = np.random.randint(0, 256, size=(len(vertices), 3), dtype=np.uint8)
288+
289+
# Convert to torch tensors
290+
file_pts = torch.tensor(xyz, dtype=torch.float32, device=self.device)
291+
file_rgb = torch.tensor(rgb, dtype=torch.uint8, device=self.device)
292+
293+
logger.info(f"Loaded {len(file_pts)} points from accumulated point cloud")
294+
295+
# Initialize using the same method as COLMAP
296+
assert file_rgb.dtype == torch.uint8, "Expecting RGB values to be in [0, 255] range"
297+
self.default_initialize_from_points(
298+
file_pts,
299+
observer_pts,
300+
file_rgb,
301+
use_observer_pts=self.conf.initialization.use_observation_points
302+
)
303+
255304
def init_from_pretrained_point_cloud(self, pc_path: str, set_optimizable_parameters: bool = True):
256305
data = PlyData.read(pc_path)
257306
num_gaussians = len(data["vertex"])

threedgrut/trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,13 @@ def setup_training(self, conf: DictConfig, model: MixtureOfGaussians, train_data
256256
train_dataset.get_observer_points(), dtype=torch.float32, device=self.device
257257
)
258258
model.init_from_colmap(conf.path, observer_points)
259+
case "accumulated_point_cloud":
260+
observer_points = torch.tensor(
261+
train_dataset.get_observer_points(), dtype=torch.float32, device=self.device
262+
)
263+
ply_path = conf.initialization.accumulated_point_cloud_path
264+
logger.info(f"Initializing from accumulated point cloud: {ply_path}")
265+
model.init_from_accumulated_point_cloud(ply_path, observer_points)
259266
case "point_cloud":
260267
try:
261268
ply_path = os.path.join(conf.path, "point_cloud.ply")

0 commit comments

Comments
 (0)