Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions examples/datasets/colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,32 @@ def __init__(
image_dir_suffix = ""
colmap_image_dir = os.path.join(data_dir, "images")
image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
for d in [image_dir, colmap_image_dir]:
if not os.path.exists(d):
raise ValueError(f"Image folder {d} does not exist.")

# Check if original images folder exists
if not os.path.exists(colmap_image_dir):
raise ValueError(f"Image folder {colmap_image_dir} does not exist.")

# Downsampled images may have different names vs images used for COLMAP,
# so we need to map between the two sorted lists of files.
colmap_files = sorted(_get_rel_paths(colmap_image_dir))

# Handle image resizing if needed
if factor > 1 and not os.path.exists(image_dir):
# Check if we need to resize from JPG to PNG
if os.path.splitext(colmap_files[0])[1].lower() == ".jpg":
image_dir = _resize_image_folder(
colmap_image_dir, image_dir + "_png", factor=factor
)
else:
image_dir = _resize_image_folder(
colmap_image_dir, image_dir, factor=factor
)

# Check if the final image directory exists
if not os.path.exists(image_dir):
raise ValueError(f"Image folder {image_dir} does not exist.")

image_files = sorted(_get_rel_paths(image_dir))
if factor > 1 and os.path.splitext(image_files[0])[1].lower() == ".jpg":
image_dir = _resize_image_folder(
colmap_image_dir, image_dir + "_png", factor=factor
)
image_files = sorted(_get_rel_paths(image_dir))
colmap_to_image = dict(zip(colmap_files, image_files))
image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]

Expand Down
43 changes: 33 additions & 10 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat_viewer import GsplatViewer, GsplatRenderTabState
from nerfview import CameraState, RenderTabState, apply_float_colormap

from datetime import datetime

@dataclass
class Config:
Expand All @@ -58,7 +58,7 @@ class Config:
# Directory to save results
result_dir: str = "results/garden"
# Every N images there is a test image
test_every: int = 8
test_every: int = 22
# Random crop size for training (experimental)
patch_size: Optional[int] = None
# A global scaler that applies to the scene size related parameters
Expand All @@ -79,13 +79,13 @@ class Config:
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000, 50_000, 100_000, 150_000, 200_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000, 50_000, 100_000, 150_000, 200_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = False
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000, 50_000, 100_000, 150_000, 200_000])
# Whether to disable video generation during training and evaluation
disable_video: bool = False

Expand Down Expand Up @@ -186,6 +186,7 @@ class Config:

# Whether use fused-bilateral grid
use_fused_bilagrid: bool = False
debug_vram: bool = False # Enable CUDA memory profiling

def adjust_steps(self, factor: float):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
Expand Down Expand Up @@ -316,6 +317,11 @@ def __init__(
self.device = f"cuda:{local_rank}"

# Where to dump results.
if os.path.exists(cfg.result_dir) and os.listdir(cfg.result_dir):
raise ValueError(
f"Result directory '{cfg.result_dir}' already exists and is not empty. "
"Please specify a different directory or remove/rename the existing one."
)
os.makedirs(cfg.result_dir, exist_ok=True)

# Setup output directories.
Expand Down Expand Up @@ -736,12 +742,14 @@ def train(self):
# )

if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0:
mem = torch.cuda.max_memory_allocated() / 1024**3
allocated = torch.cuda.max_memory_allocated() / 1024**2
reserved = torch.cuda.max_memory_reserved() / 1024**2
self.writer.add_scalar("train/loss", loss.item(), step)
self.writer.add_scalar("train/l1loss", l1loss.item(), step)
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step)
self.writer.add_scalar("train/mem", mem, step)
self.writer.add_scalar("train/mem_allocated", allocated, step)
self.writer.add_scalar("train/mem_reserved", reserved, step)
if cfg.depth_loss:
self.writer.add_scalar("train/depthloss", depthloss.item(), step)
if cfg.use_bilateral_grid:
Expand Down Expand Up @@ -884,7 +892,7 @@ def train(self):
# eval the full set
if step in [i - 1 for i in cfg.eval_steps]:
self.eval(step)
self.render_traj(step)
# self.render_traj(step)

# run compression
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
Expand All @@ -902,7 +910,7 @@ def train(self):
)
# Update the scene.
self.viewer.update(step, num_train_rays_per_step)

@torch.no_grad()
def eval(self, step: int, stage: str = "val"):
"""Entry for evaluation."""
Expand Down Expand Up @@ -1177,11 +1185,26 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
step = ckpts[0]["step"]
runner.eval(step=step)
runner.render_traj(step=step)
# runner.render_traj(step=step)
if cfg.compression is not None:
runner.run_compression(step=step)
else:
if cfg.debug_vram:
torch.cuda.memory._record_memory_history(
max_entries=100_000
)

runner.train()

if cfg.debug_vram:
snapshot_path = os.path.join(
cfg.result_dir,
f"vram_snapshot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pickle"
)
torch.cuda.memory._dump_snapshot(snapshot_path)
torch.cuda.memory._record_memory_history(enabled=None) # Stop recording
print(f"VRAM snapshot saved to: {snapshot_path}")


if not cfg.disable_viewer:
runner.viewer.complete()
Expand Down