Skip to content

Commit 84cfa10

Browse files
committed
Adapt cusfm format for training
1 parent e55c114 commit 84cfa10

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

configs/apps/cusfm_3dgut_mcmc.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# @package _global_
2+
3+
# order in which configs override each other (/* - denotes a relative search path)
4+
defaults:
5+
- /base_mcmc
6+
- /dataset: colmap
7+
- /initialization: accumulated_point_cloud
8+
- /render: 3dgut
9+
- _self_
10+
11+
# overwrite of default parameters
12+
val_frequency: 999999 # never validate
13+
14+
initialization:
15+
accumulated_point_cloud_path: ??? # Path to accumulated point cloud PLY file to be provided by user at runtime
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
method: accumulated_point_cloud
2+
observation_scale_factor: 0.01
3+
use_observation_points: true
4+
accumulated_point_cloud_path: ??? # Path to accumulated point cloud PLY file
5+

threedgrut/datasets/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,8 @@ def read_colmap_extrinsics_text(path):
514514
with open(path, "r") as fid:
515515
# Skip comment lines and get valid lines
516516
lines = (line.strip() for line in fid)
517-
lines = (line for line in lines if line and not line.startswith("#"))
517+
# lines = (line for line in lines if line and not line.startswith("#"))
518+
lines = (line for line in lines if line == "" or not line.startswith("#"))
518519
# Process lines in pairs (image info + points info)
519520
try:
520521
while True:

threedgrut/model/model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,55 @@ def init_from_colmap(self, root_path: str, observer_pts):
247247
self.default_initialize_from_points(file_pts, observer_pts, file_rgb,
248248
use_observer_pts=self.conf.initialization.use_observation_points)
249249

250+
def init_from_accumulated_point_cloud(self, pc_path: str, observer_pts):
251+
"""
252+
Initialize gaussians from an accumulated point cloud PLY file.
253+
Similar to init_from_colmap but loads from a given PLY file instead of sparse/0/points3D.txt
254+
255+
Args:
256+
pc_path: Path to the PLY point cloud file
257+
observer_pts: Observer points tensor for scale initialization
258+
"""
259+
logger.info(f"Loading accumulated point cloud from {pc_path}...")
260+
261+
# Read PLY file
262+
plydata = PlyData.read(pc_path)
263+
vertices = plydata['vertex']
264+
265+
# Extract XYZ coordinates
266+
xyz = np.stack([
267+
vertices['x'],
268+
vertices['y'],
269+
vertices['z']
270+
], axis=1).astype(np.float32)
271+
272+
# Extract RGB colors (check if they exist)
273+
if 'red' in vertices and 'green' in vertices and 'blue' in vertices:
274+
rgb = np.stack([
275+
vertices['red'],
276+
vertices['green'],
277+
vertices['blue']
278+
], axis=1).astype(np.uint8)
279+
else:
280+
# If no colors, initialize with random colors
281+
logger.warning("No RGB data found in point cloud, using random colors")
282+
rgb = np.random.randint(0, 256, size=(len(vertices), 3), dtype=np.uint8)
283+
284+
# Convert to torch tensors
285+
file_pts = torch.tensor(xyz, dtype=torch.float32, device=self.device)
286+
file_rgb = torch.tensor(rgb, dtype=torch.uint8, device=self.device)
287+
288+
logger.info(f"Loaded {len(file_pts)} points from accumulated point cloud")
289+
290+
# Initialize using the same method as COLMAP
291+
assert file_rgb.dtype == torch.uint8, "Expecting RGB values to be in [0, 255] range"
292+
self.default_initialize_from_points(
293+
file_pts,
294+
observer_pts,
295+
file_rgb,
296+
use_observer_pts=self.conf.initialization.use_observation_points
297+
)
298+
250299
def init_from_pretrained_point_cloud(self, pc_path: str, set_optimizable_parameters: bool = True):
251300
data = PlyData.read(pc_path)
252301
num_gaussians = len(data["vertex"])

threedgrut/trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ def setup_training(self, conf: DictConfig, model: MixtureOfGaussians, train_data
252252
train_dataset.get_observer_points(), dtype=torch.float32, device=self.device
253253
)
254254
model.init_from_colmap(conf.path, observer_points)
255+
case "accumulated_point_cloud":
256+
observer_points = torch.tensor(
257+
train_dataset.get_observer_points(), dtype=torch.float32, device=self.device
258+
)
259+
ply_path = conf.initialization.accumulated_point_cloud_path
260+
logger.info(f"Initializing from accumulated point cloud: {ply_path}")
261+
model.init_from_accumulated_point_cloud(ply_path, observer_points)
255262
case "point_cloud":
256263
try:
257264
ply_path = os.path.join(conf.path, "point_cloud.ply")

0 commit comments

Comments
 (0)