Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-graph-jacobian.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ jobs:
PYTHONPATH: /tmp/pydeps:${{ github.workspace }}
run: |
python -c "import torch; print('torch', torch.__version__)"
python -m pytest -q tests/autograd/test_graph_jacobian.py
python -m pytest -q tests/autograd
73 changes: 73 additions & 0 deletions datapipes/bal_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os

import torch

DTYPE = torch.float64


def _rotvec_to_quat_xyzw(rotvec: torch.Tensor) -> torch.Tensor:
theta = torch.linalg.norm(rotvec, dim=-1, keepdim=True)
half_theta = 0.5 * theta
cos_half = torch.cos(half_theta)
sin_half = torch.sin(half_theta)

eps = 1e-12
scale = torch.where(theta > eps, sin_half / theta, 0.5 - (theta * theta) / 48.0)
xyz = rotvec * scale
return torch.cat([xyz, cos_half], dim=-1)


def read_bal_data(file_name: str, use_quat: bool = False) -> dict:
    """Parse a Bundle Adjustment in the Large (BAL) problem text file.

    Expected layout:
        <num_cameras> <num_points> <num_observations>
        <camera_index> <point_index> <x> <y>   (one line per observation)
        <camera parameters>                    (n_cameras * 9 lines)
        <point parameters>                     (n_points * 3 lines)

    Each camera carries 9 parameters: a Rodrigues rotation vector (3),
    a translation (3), focal length f and radial distortion terms k1, k2.

    Parameters
    ----------
    file_name : str
        Path to the decompressed BAL problem file.
    use_quat : bool
        When True, the rotation is converted to a unit quaternion and each
        camera row is [tx, ty, tz, qx, qy, qz, qw, f, k1, k2] (10 values);
        otherwise each row is [tx, ty, tz, rx, ry, rz, f, k1, k2] (9 values).

    Returns
    -------
    dict
        Keys: problem_name, camera_params, points_3d, points_2d,
        camera_index_of_observations, point_index_of_observations.
    """
    with open(file_name, "r") as fh:
        n_cameras, n_points, n_observations = (int(tok) for tok in fh.readline().split())

        camera_indices = torch.empty(n_observations, dtype=torch.int64)
        point_indices = torch.empty(n_observations, dtype=torch.int64)
        points_2d = torch.empty((n_observations, 2), dtype=DTYPE)

        for row in range(n_observations):
            cam_idx, pt_idx, u, v = fh.readline().split()
            camera_indices[row] = int(cam_idx)
            point_indices[row] = int(pt_idx)
            points_2d[row, 0] = float(u)
            points_2d[row, 1] = float(v)

        # One value per line: first the camera parameter block...
        cam_values = [float(fh.readline()) for _ in range(n_cameras * 9)]
        camera_params = torch.tensor(cam_values, dtype=DTYPE).reshape(n_cameras, 9)

        # ...then the 3-D point coordinates.
        pt_values = [float(fh.readline()) for _ in range(n_points * 3)]
        points_3d = torch.tensor(pt_values, dtype=DTYPE).reshape(n_points, 3)

    if use_quat:
        quat = _rotvec_to_quat_xyzw(camera_params[:, :3])
        camera_params = torch.cat([camera_params[:, 3:6], quat, camera_params[:, 6:]], dim=1)
    else:
        camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], dim=1)

    return {
        "problem_name": os.path.splitext(os.path.basename(file_name))[0],
        "camera_params": camera_params.to(DTYPE),
        "points_3d": points_3d,
        "points_2d": points_2d,
        "camera_index_of_observations": camera_indices,
        "point_index_of_observations": point_indices,
    }

147 changes: 40 additions & 107 deletions datapipes/bal_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,24 @@
Link to the dataset: https://grail.cs.washington.edu/projects/bal/
"""

import torch, os, warnings
import numpy as np
import os
import warnings

import torch
from functools import partial
from operator import itemgetter, methodcaller
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
from torchvision.transforms import Compose
from scipy.spatial.transform import Rotation
from torchdata.datapipes.iter import HttpReader, IterableWrapper, FileOpener
import pypose as pp
from operator import methodcaller

DTYPE = torch.float64
from .bal_io import DTYPE, read_bal_data

# ignore bs4 warning
warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
def _torchdata():
    """Lazily import the torchdata datapipe classes used by this module.

    Returns the tuple (HttpReader, IterableWrapper, FileOpener).

    Raises
    ------
    ImportError
        If torchdata is not installed, with a hint pointing parsing-only
        users to datapipes.bal_io.read_bal_data instead.
    """
    try:
        # Imported inside the function so that merely importing this module
        # does not require torchdata; only the streaming utilities do.
        from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
    except ImportError as e:
        raise ImportError(
            "torchdata is required for datapipes.bal_loader streaming utilities. "
            "If you only need parsing, import read_bal_data from datapipes.bal_io."
        ) from e
    return HttpReader, IterableWrapper, FileOpener

# only export __all__
__ALL__ = ['build_pipeline', 'read_bal_data', 'DATA_URL', 'ALL_DATASETS']
Expand All @@ -46,8 +50,22 @@ def _not_none(s):

# extract problem file urls from the problem url
def _problem_lister(*problem_url, cache_dir):
HttpReader, IterableWrapper, FileOpener = _torchdata()
try:
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
except ImportError as e:
raise ImportError(
"bs4 is required for datapipes.bal_loader streaming utilities. "
"If you only need parsing, import read_bal_data from datapipes.bal_io."
) from e

warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)

def _cache_path(url: str) -> str:
return os.path.join(cache_dir, os.path.basename(url))

problem_list_dp = IterableWrapper(problem_url).on_disk_cache(
filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]),
filepath_fn=_cache_path,
)
problem_list_dp = HttpReader(problem_list_dp).end_caching(same_filepath_fn=True)

Expand All @@ -69,113 +87,28 @@ def _problem_lister(*problem_url, cache_dir):

# download and decompress the problem files
def _download_pipe(cache_dir, url_dp, suffix: str):
HttpReader, _, _ = _torchdata()

def _cache_path(url: str) -> str:
return os.path.join(cache_dir, os.path.basename(url))

def _strip_suffix(path: str) -> str:
return path.split(suffix)[0]

# cache compressed files
cache_compressed = url_dp.on_disk_cache(
filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]) ,
filepath_fn=_cache_path,
)
cache_compressed = HttpReader(cache_compressed).end_caching(same_filepath_fn=True)
# cache decompressed files
cache_decompressed = cache_compressed.on_disk_cache(
filepath_fn=Compose([partial(str.split, sep=suffix), itemgetter(0)]),
filepath_fn=_strip_suffix,
)
cache_decompressed = cache_decompressed.open_files(mode="b").load_from_bz2().end_caching(
same_filepath_fn=True
)
return cache_decompressed

def read_bal_data(file_name: str, use_quat=False) -> dict:
"""
Read a Bundle Adjustment in the Large dataset.

Referenced Scipy's BAL loader: https://scipy-cookbook.readthedocs.io/items/bundle_adjustment.html

According to BAL official documentation, each problem is provided as a text file in the following format:

<num_cameras> <num_points> <num_observations>
<camera_index_1> <point_index_1> <x_1> <y_1>
...
<camera_index_num_observations> <point_index_num_observations> <x_num_observations> <y_num_observations>
<camera_1>
...
<camera_num_cameras>
<point_1>
...
<point_num_points>

Where, there camera and point indices start from 0. Each camera is a set of 9 parameters - R,t,f,k1 and k2. The rotation R is specified as a Rodrigues' vector.

Parameters
----------
file_name : str
The decompressed file of the dataset.

Returns
-------
dict
A dictionary containing the following fields:
- problem_name: str
The name of the problem.
- camera_params: torch.Tensor (n_cameras, 9 or 10)
contains camera parameters for each camera. If use_quat is True, the shape is (n_cameras, 10).
- points_3d: torch.Tensor (n_points, 3)
contains initial estimates of point coordinates in the world frame.
- points_2d: torch.Tensor (n_observations, 2)
contains measured 2-D coordinates of points projected on images in each observations.
- camera_index_of_observations: torch.Tensor (n_observations,)
contains indices of cameras (from 0 to n_cameras - 1) involved in each observation.
- point_index_of_observations: torch.Tensor (n_observations,)
contains indices of points (from 0 to n_points - 1) involved in each observation.
"""
with open(file_name, "r") as file:
n_cameras, n_points, n_observations = map(
int, file.readline().split())

camera_indices = torch.empty(n_observations, dtype=torch.int64)
point_indices = torch.empty(n_observations, dtype=torch.int64)
points_2d = torch.empty((n_observations, 2), dtype=DTYPE)

for i in range(n_observations):
tmp_line = file.readline()
camera_index, point_index, x, y = tmp_line.split()
camera_indices[i] = int(camera_index)
point_indices[i] = int(point_index)
points_2d[i, 0] = float(x)
points_2d[i, 1] = float(y)

camera_params = torch.empty(n_cameras * 9, dtype=DTYPE)
for i in range(n_cameras * 9):
camera_params[i] = float(file.readline())
camera_params = camera_params.reshape((n_cameras, -1))

points_3d = torch.empty(n_points * 3, dtype=DTYPE)
for i in range(n_points * 3):
points_3d[i] = float(file.readline())
points_3d = points_3d.reshape((n_points, -1))

if use_quat:
# convert Rodrigues vector to unit quaternion for camera rotation
# camera_params[0:3] is the Rodrigues vector
# after conversion, camera_params[0:4] is the unit quaternion
# r = Rotation.from_rotvec(camera_params[:, :3])
# q = r.as_quat()
r = pp.so3(camera_params[:, :3])
q = r.Exp()
# [tx, ty, tz, q0, q1, q2, q3, f, k1, k2]
camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], axis=1)
else:
camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], axis=1)

# convert camera_params to torch.Tensor
camera_params = torch.tensor(camera_params).to(DTYPE)

return {'problem_name': os.path.splitext(os.path.basename(file_name))[0], # str
'camera_params': camera_params, # torch.Tensor (n_cameras, 9 or 10)
'points_3d': points_3d, # torch.Tensor (n_points, 3)
'points_2d': points_2d, # torch.Tensor (n_observations, 2)
'camera_index_of_observations': camera_indices, # torch.Tensor (n_observations,)
'point_index_of_observations': point_indices, # torch.Tensor (n_observations,)
}

def build_pipeline(dataset='ladybug', cache_dir='bal_data', use_quat=False):
"""
Build a pipeline for the Bundle Adjustment in the Large dataset.
Expand Down
Loading