COLMAP's Parser relied on SceneManager, which is no longer available in recent versions of pycolmap. I rewrote the Parser on top of pycolmap's Reconstruction API. It worked fine in tests with simple_trainer.py, and training with depth_loss also ran normally.
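For reference, the core of the change is swapping the old SceneManager loading calls for a single pycolmap.Reconstruction. A minimal sketch (the SceneManager method names on the old side are quoted from memory):

import pycolmap

# Old: manager = SceneManager(colmap_dir); manager.load_cameras();
#      manager.load_images(); manager.load_points3D()
# New: one constructor reads the whole sparse model.
reconstruction = pycolmap.Reconstruction(colmap_dir)
cameras = reconstruction.cameras    # {camera_id: pycolmap.Camera}
images = reconstruction.images      # {image_id: pycolmap.Image}
points3D = reconstruction.points3D  # {point3D_id: pycolmap.Point3D}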
The new parser is as follows:
import pycolmap
import json
import os
from typing import Any, Dict, List, Optional
import cv2
import imageio.v2 as imageio
import numpy as np
import torch
from tqdm import tqdm
from typing_extensions import assert_never
from concurrent.futures import ThreadPoolExecutor
from .normalize import (
align_principal_axes,
similarity_from_cameras,
transform_cameras,
transform_points,
)
def _get_rel_paths(path_dir: str) -> List[str]:
"""Recursively get relative paths of files in a directory."""
paths = []
for dp, dn, fn in os.walk(path_dir):
for f in fn:
paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
return paths
def _resize_image_folder(
    image_dir: str, resized_dir: str, factor: int, num_threads: Optional[int] = None
) -> str:
"""
Resize all images in a folder using OpenCV (highest quality) + multithreading.
Args:
image_dir (str): Original image folder
resized_dir (str): Output folder
factor (int): Downscale factor
num_threads (int or None): Number of threads, default 8
"""
if num_threads is None:
num_threads = 8
print(
f"Downscaling images by {factor}x from {image_dir} to {resized_dir} using {num_threads} threads."
)
os.makedirs(resized_dir, exist_ok=True)
image_files = _get_rel_paths(image_dir)
def process_one(image_file):
in_path = os.path.join(image_dir, image_file)
out_path = os.path.join(resized_dir, os.path.splitext(image_file)[0] + ".png")
if os.path.isfile(out_path):
return # skip existing
image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
if image is None:
return
        # Drop the alpha channel if present (guard ndim for grayscale images).
        if image.ndim == 3 and image.shape[-1] > 3:
image = image[..., :3]
h, w = image.shape[:2]
new_w, new_h = int(round(w / factor)), int(round(h / factor))
resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
cv2.imwrite(out_path, resized)
with ThreadPoolExecutor(max_workers=num_threads) as ex:
list(tqdm(ex.map(process_one, image_files), total=len(image_files)))
return resized_dir
class Parser:
"""COLMAP Parser."""
def __init__(
self,
data_dir: str,
factor: int = 1,
normalize: bool = False,
test_every: int = 8,
):
self.data_dir = data_dir
self.factor = factor
self.normalize = normalize
self.test_every = test_every
colmap_dir = os.path.join(data_dir, "sparse/0")
if not os.path.exists(colmap_dir):
colmap_dir = os.path.join(data_dir, "sparse")
        assert os.path.exists(
            colmap_dir
        ), f"COLMAP directory {colmap_dir} does not exist."
reconstruction = pycolmap.Reconstruction(colmap_dir)
imdata = list(reconstruction.images.values())
w2c_mats = []
camera_ids = []
Ks_dict = dict()
params_dict = dict()
imsize_dict = dict() # width, height
mask_dict = dict()
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
for img in imdata:
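            # cam_from_world is COLMAP's world-to-camera rigid transform; stack
            # its rotation matrix and translation into a 4x4 homogeneous matrix.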
rot = img.cam_from_world().rotation.matrix()
trans = img.cam_from_world().translation.reshape(3, 1)
w2c_mat = np.concatenate(
[np.concatenate([rot, trans], axis=1), bottom], axis=0
)
w2c_mats.append(w2c_mat)
# support different camera intrinsics
camera_id = img.camera_id
camera_ids.append(camera_id)
cam = reconstruction.cameras[camera_id]
            # cam.params ordering differs per model (SIMPLE_PINHOLE has only
            # (f, cx, cy), for example), so build K with the model-agnostic
            # Camera.calibration_matrix() instead of unpacking cam.params by
            # position.
            K = cam.calibration_matrix()
            K[:2, :] /= factor
            Ks_dict[camera_id] = K
            type_ = cam.model.name
            # SIMPLE_PINHOLE (f, cx, cy)
            if type_ == "SIMPLE_PINHOLE":
                params = np.empty(0, dtype=np.float32)
                camtype = "perspective"
            # PINHOLE (fx, fy, cx, cy)
            elif type_ == "PINHOLE":
                params = np.empty(0, dtype=np.float32)
                camtype = "perspective"
            # SIMPLE_RADIAL (f, cx, cy, k1)
            elif type_ == "SIMPLE_RADIAL":
                params = np.array([cam.params[3], 0.0, 0.0, 0.0], dtype=np.float32)
                camtype = "perspective"
            # RADIAL (f, cx, cy, k1, k2)
            elif type_ == "RADIAL":
                params = np.array(
                    [cam.params[3], cam.params[4], 0.0, 0.0], dtype=np.float32
                )
                camtype = "perspective"
            # OPENCV (fx, fy, cx, cy, k1, k2, p1, p2)
            elif type_ == "OPENCV":
                params = np.array(cam.params[4:], dtype=np.float32)
                camtype = "perspective"
            # OPENCV_FISHEYE (fx, fy, cx, cy, k1, k2, k3, k4)
            elif type_ == "OPENCV_FISHEYE":
                params = np.array(cam.params[4:], dtype=np.float32)
                camtype = "fisheye"
            else:
                params = np.empty(0, dtype=np.float32)
                camtype = type_  # unsupported model; caught by the assert below
            params_dict[camera_id] = params
            imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
            mask_dict[camera_id] = None
        assert camtype == "perspective" or camtype == "fisheye", (
            f"Only perspective and fisheye cameras are supported, got {type_}"
        )
print(
f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
)
if len(imdata) == 0:
raise ValueError("No images found in COLMAP.")
        if type_ not in ("SIMPLE_PINHOLE", "PINHOLE"):
            print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
w2c_mats = np.stack(w2c_mats, axis=0)
# Convert extrinsics to camera-to-world.
camtoworlds = np.linalg.inv(w2c_mats)
# Image names from COLMAP. No need for permuting the poses according to
# image names anymore.
image_names = [k.name for k in imdata]
# Previous Nerf results were generated with images sorted by filename,
# ensure metrics are reported on the same test set.
inds = np.argsort(image_names)
image_names = [image_names[i] for i in inds]
camtoworlds = camtoworlds[inds]
camera_ids = [camera_ids[i] for i in inds]
# Load extended metadata. Used by Bilarf dataset.
self.extconf = {
"spiral_radius_scale": 1.0,
"no_factor_suffix": False,
}
extconf_file = os.path.join(data_dir, "ext_metadata.json")
if os.path.exists(extconf_file):
with open(extconf_file) as f:
self.extconf.update(json.load(f))
# Load bounds if possible (only used in forward facing scenes).
self.bounds = np.array([0.01, 1.0])
posefile = os.path.join(data_dir, "poses_bounds.npy")
if os.path.exists(posefile):
self.bounds = np.load(posefile)[:, -2:]
# Load images.
if factor > 1 and not self.extconf["no_factor_suffix"]:
image_dir_suffix = f"_{factor}"
else:
image_dir_suffix = ""
colmap_image_dir = os.path.join(data_dir, "images")
image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
for d in [image_dir, colmap_image_dir]:
if not os.path.exists(d):
raise ValueError(f"Image folder {d} does not exist.")
# Downsampled images may have different names vs images used for COLMAP,
# so we need to map between the two sorted lists of files.
colmap_files = sorted(_get_rel_paths(colmap_image_dir))
image_files = sorted(_get_rel_paths(image_dir))
if factor > 1 and os.path.splitext(image_files[0])[1].lower() == ".jpg":
image_dir = _resize_image_folder(
colmap_image_dir, image_dir + "_png", factor=factor
)
image_files = sorted(_get_rel_paths(image_dir))
colmap_to_image = dict(zip(colmap_files, image_files))
image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
# 3D points and {image_name -> [point_idx]}
points = np.vstack(
[pt.xyz for pt in reconstruction.points3D.values()], dtype=np.float32
)
points_rgb = np.vstack(
[pt.color for pt in reconstruction.points3D.values()], dtype=np.float32
).astype(np.uint8)
points_err = np.array(
[pt.error for pt in reconstruction.points3D.values()], dtype=np.float32
)
point_indices = dict()
# image_id -> image_name
        # Iterate over all image objects in the reconstruction.
image_id_to_name = {
img_id: image.name for img_id, image in reconstruction.images.items()
}
# point_id -> point_idx
point_id_idx = {
point_id: idx
for idx, (point_id, _) in enumerate(reconstruction.points3D.items())
}
# image_name -> point_idx
for point_id, point3D in reconstruction.points3D.items():
for track_element in point3D.track.elements:
image_id = track_element.image_id
image_name = image_id_to_name[image_id]
point_idx = point_id_idx[point_id]
point_indices.setdefault(image_name, []).append(point_idx)
        # Convert the per-image index lists to int32 arrays.
point_indices = {
k: np.array(v).astype(np.int32) for k, v in point_indices.items()
}
# Normalize the world space
if normalize:
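            # Two-step normalization (as I understand gsplat's normalize
            # helpers): similarity_from_cameras recenters/rescales the scene
            # from the camera poses, then align_principal_axes rotates the
            # point cloud's principal axes onto the world axes.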
T1 = similarity_from_cameras(camtoworlds)
camtoworlds = transform_cameras(T1, camtoworlds)
points = transform_points(T1, points)
T2 = align_principal_axes(points)
camtoworlds = transform_cameras(T2, camtoworlds)
points = transform_points(T2, points)
transform = T2 @ T1
# Fix for up side down. We assume more points towards
# the bottom of the scene which is true when ground floor is
# present in the images.
if np.median(points[:, 2]) > np.mean(points[:, 2]):
# rotate 180 degrees around x axis such that z is flipped
T3 = np.array(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, -1.0, 0.0, 0.0],
[0.0, 0.0, -1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
camtoworlds = transform_cameras(T3, camtoworlds)
points = transform_points(T3, points)
transform = T3 @ transform
else:
transform = np.eye(4)
self.image_names = image_names # List[str], (num_images,)
self.image_paths = image_paths # List[str], (num_images,)
self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
self.camera_ids = camera_ids # List[int], (num_images,)
self.Ks_dict = Ks_dict # Dict of camera_id -> K
self.params_dict = params_dict # Dict of camera_id -> params
self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
self.mask_dict = mask_dict # Dict of camera_id -> mask
self.points = points # np.ndarray, (num_points, 3)
self.points_err = points_err # np.ndarray, (num_points,)
self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
self.transform = transform # np.ndarray, (4, 4)
# load one image to check the size. In the case of tanksandtemples dataset, the
# intrinsics stored in COLMAP corresponds to 2x upsampled images.
actual_image = imageio.imread(self.image_paths[0])[..., :3]
actual_height, actual_width = actual_image.shape[:2]
colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]]
s_height, s_width = actual_height / colmap_height, actual_width / colmap_width
for camera_id, K in self.Ks_dict.items():
K[0, :] *= s_width
K[1, :] *= s_height
self.Ks_dict[camera_id] = K
width, height = self.imsize_dict[camera_id]
self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height))
# undistortion
self.mapx_dict = dict()
self.mapy_dict = dict()
self.roi_undist_dict = dict()
for camera_id in self.params_dict.keys():
params = self.params_dict[camera_id]
if len(params) == 0:
continue # no distortion
assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
assert camera_id in self.params_dict, (
f"Missing params for camera {camera_id}"
)
K = self.Ks_dict[camera_id]
width, height = self.imsize_dict[camera_id]
if camtype == "perspective":
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
K, params, (width, height), 0
)
mapx, mapy = cv2.initUndistortRectifyMap(
K, params, None, K_undist, (width, height), cv2.CV_32FC1
)
mask = None
elif camtype == "fisheye":
fx = K[0, 0]
fy = K[1, 1]
cx = K[0, 2]
cy = K[1, 2]
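                # Build forward fisheye remap grids: treat the normalized
                # radius as theta and scale by the distortion polynomial
                # 1 + k1*theta^2 + k2*theta^4 + k3*theta^6 + k4*theta^8
                # (same approximation as the original gsplat parser).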
grid_x, grid_y = np.meshgrid(
np.arange(width, dtype=np.float32),
np.arange(height, dtype=np.float32),
indexing="xy",
)
x1 = (grid_x - cx) / fx
y1 = (grid_y - cy) / fy
theta = np.sqrt(x1**2 + y1**2)
r = (
1.0
+ params[0] * theta**2
+ params[1] * theta**4
+ params[2] * theta**6
+ params[3] * theta**8
)
mapx = (fx * x1 * r + width // 2).astype(np.float32)
mapy = (fy * y1 * r + height // 2).astype(np.float32)
# Use mask to define ROI
mask = np.logical_and(
np.logical_and(mapx > 0, mapy > 0),
np.logical_and(mapx < width - 1, mapy < height - 1),
)
y_indices, x_indices = np.nonzero(mask)
y_min, y_max = y_indices.min(), y_indices.max() + 1
x_min, x_max = x_indices.min(), x_indices.max() + 1
mask = mask[y_min:y_max, x_min:x_max]
K_undist = K.copy()
K_undist[0, 2] -= x_min
K_undist[1, 2] -= y_min
roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min]
else:
assert_never(camtype)
self.mapx_dict[camera_id] = mapx
self.mapy_dict[camera_id] = mapy
self.Ks_dict[camera_id] = K_undist
self.roi_undist_dict[camera_id] = roi_undist
self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3])
self.mask_dict[camera_id] = mask
# size of the scene measured by cameras
camera_locations = camtoworlds[:, :3, 3]
scene_center = np.mean(camera_locations, axis=0)
dists = np.linalg.norm(camera_locations - scene_center, axis=1)
self.scene_scale = np.max(dists)
class Dataset:
"""A simple dataset class."""
def __init__(
self,
parser: Parser,
split: str = "train",
patch_size: Optional[int] = None,
load_depths: bool = False,
):
self.parser = parser
self.split = split
self.patch_size = patch_size
self.load_depths = load_depths
indices = np.arange(len(self.parser.image_names))
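        # Every test_every-th image goes to the eval split; the rest train.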
if split == "train":
self.indices = indices[indices % self.parser.test_every != 0]
else:
self.indices = indices[indices % self.parser.test_every == 0]
def __len__(self):
return len(self.indices)
def __getitem__(self, item: int) -> Dict[str, Any]:
index = self.indices[item]
image = imageio.imread(self.parser.image_paths[index])[..., :3]
camera_id = self.parser.camera_ids[index]
K = self.parser.Ks_dict[camera_id].copy() # undistorted K
params = self.parser.params_dict[camera_id]
camtoworlds = self.parser.camtoworlds[index]
mask = self.parser.mask_dict[camera_id]
if len(params) > 0:
# Images are distorted. Undistort them.
mapx, mapy = (
self.parser.mapx_dict[camera_id],
self.parser.mapy_dict[camera_id],
)
image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
x, y, w, h = self.parser.roi_undist_dict[camera_id]
image = image[y : y + h, x : x + w]
if self.patch_size is not None:
# Random crop.
h, w = image.shape[:2]
x = np.random.randint(0, max(w - self.patch_size, 1))
y = np.random.randint(0, max(h - self.patch_size, 1))
image = image[y : y + self.patch_size, x : x + self.patch_size]
K[0, 2] -= x
K[1, 2] -= y
data = {
"K": torch.from_numpy(K).float(),
"camtoworld": torch.from_numpy(camtoworlds).float(),
"image": torch.from_numpy(image).float(),
"image_id": item, # the index of the image in the dataset
}
if mask is not None:
data["mask"] = torch.from_numpy(mask).bool()
if self.load_depths:
            # Project the 3D points onto the image plane to get depths.
worldtocams = np.linalg.inv(camtoworlds)
image_name = self.parser.image_names[index]
point_indices = self.parser.point_indices[image_name]
points_world = self.parser.points[point_indices]
points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
points_proj = (K @ points_cam.T).T
points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
depths = points_cam[:, 2] # (M,)
# filter out points outside the image
selector = (
(points[:, 0] >= 0)
& (points[:, 0] < image.shape[1])
& (points[:, 1] >= 0)
& (points[:, 1] < image.shape[0])
& (depths > 0)
)
points = points[selector]
depths = depths[selector]
data["points"] = torch.from_numpy(points).float()
data["depths"] = torch.from_numpy(depths).float()
return data
if __name__ == "__main__":
import argparse
import imageio.v2 as imageio
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="data/360_v2/garden")
parser.add_argument("--factor", type=int, default=4)
args = parser.parse_args()
# Parse COLMAP data.
parser = Parser(
data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8
)
dataset = Dataset(parser, split="train", load_depths=True)
print(f"Dataset: {len(dataset)} images.")
    os.makedirs("results", exist_ok=True)
    writer = imageio.get_writer("results/points.mp4", fps=30)
for data in tqdm(dataset, desc="Plotting points"):
image = data["image"].numpy().astype(np.uint8)
points = data["points"].numpy()
depths = data["depths"].numpy()
for x, y in points:
cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1)
writer.append_data(image)
    writer.close()