
Commit 0b95e90

Add BAL samples for JTJ diagonal test (#12)

* [refactor] move read_bal_data to a separate module
* local file system test
* [test] fetch from uw
* [ci] rename and hook up
* add negative test

1 parent 3b5dba4 commit 0b95e90

File tree

4 files changed, +411 −108 lines changed


.github/workflows/test-graph-jacobian.yml

Lines changed: 1 addition & 1 deletion

@@ -36,4 +36,4 @@ jobs:
           PYTHONPATH: /tmp/pydeps:${{ github.workspace }}
         run: |
           python -c "import torch; print('torch', torch.__version__)"
-          python -m pytest -q tests/autograd/test_graph_jacobian.py
+          python -m pytest -q tests/autograd
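
The workflow now runs everything under tests/autograd rather than the single test_graph_jacobian.py file, so the new BAL-backed tests (including the negative test) are picked up without further CI edits. For a local run, a minimal sketch assuming pytest and torch are installed and the repository root is the working directory:

    # Local equivalent of the CI step (sketch only, not part of the commit).
    import sys

    import pytest
    import torch

    print("torch", torch.__version__)                # same sanity print as the workflow
    sys.exit(pytest.main(["-q", "tests/autograd"]))  # propagate pytest's exit code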

datapipes/bal_io.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+import os
+
+import torch
+
+DTYPE = torch.float64
+
+
+def _rotvec_to_quat_xyzw(rotvec: torch.Tensor) -> torch.Tensor:
+    theta = torch.linalg.norm(rotvec, dim=-1, keepdim=True)
+    half_theta = 0.5 * theta
+    cos_half = torch.cos(half_theta)
+    sin_half = torch.sin(half_theta)
+
+    eps = 1e-12
+    scale = torch.where(theta > eps, sin_half / theta, 0.5 - (theta * theta) / 48.0)
+    xyz = rotvec * scale
+    return torch.cat([xyz, cos_half], dim=-1)
+
+
+def read_bal_data(file_name: str, use_quat: bool = False) -> dict:
+    """
+    Read a Bundle Adjustment in the Large dataset problem text file.
+
+    Format:
+        <num_cameras> <num_points> <num_observations>
+        <camera_index> <point_index> <x> <y>   (repeated n_observations)
+        <camera parameters>                    (n_cameras * 9 lines)
+        <point parameters>                     (n_points * 3 lines)
+
+    Each camera has 9 parameters: Rodrigues rotvec (3), translation (3), f, k1, k2.
+    This loader outputs either:
+        - use_quat=False: [tx, ty, tz, rx, ry, rz, f, k1, k2]      (9)
+        - use_quat=True:  [tx, ty, tz, qx, qy, qz, qw, f, k1, k2]  (10)
+    """
+    with open(file_name, "r") as file:
+        n_cameras, n_points, n_observations = map(int, file.readline().split())
+
+        camera_indices = torch.empty(n_observations, dtype=torch.int64)
+        point_indices = torch.empty(n_observations, dtype=torch.int64)
+        points_2d = torch.empty((n_observations, 2), dtype=DTYPE)
+
+        for i in range(n_observations):
+            camera_index, point_index, x, y = file.readline().split()
+            camera_indices[i] = int(camera_index)
+            point_indices[i] = int(point_index)
+            points_2d[i, 0] = float(x)
+            points_2d[i, 1] = float(y)
+
+        camera_params = torch.empty(n_cameras * 9, dtype=DTYPE)
+        for i in range(n_cameras * 9):
+            camera_params[i] = float(file.readline())
+        camera_params = camera_params.reshape((n_cameras, 9))
+
+        points_3d = torch.empty(n_points * 3, dtype=DTYPE)
+        for i in range(n_points * 3):
+            points_3d[i] = float(file.readline())
+        points_3d = points_3d.reshape((n_points, 3))
+
+    if use_quat:
+        q = _rotvec_to_quat_xyzw(camera_params[:, :3])
+        camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], dim=1)
+    else:
+        camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], dim=1)
+
+    return {
+        "problem_name": os.path.splitext(os.path.basename(file_name))[0],
+        "camera_params": camera_params.to(DTYPE),
+        "points_3d": points_3d,
+        "points_2d": points_2d,
+        "camera_index_of_observations": camera_indices,
+        "point_index_of_observations": point_indices,
+    }
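
The new module depends only on torch, which makes it easy to smoke-test in isolation. Below is a minimal sketch, not part of this commit, that generates a one-camera/one-point BAL file and checks the loader output; for a rotation of pi/2 about x, the expected quaternion is [sin(pi/4), 0, 0, cos(pi/4)]:

    # Hypothetical smoke test for datapipes.bal_io (illustration, not in the commit).
    import math
    import tempfile

    import torch

    from datapipes.bal_io import _rotvec_to_quat_xyzw, read_bal_data

    # Header, one observation, 9 camera values (rotvec, translation, f, k1, k2),
    # then one 3-D point; each parameter sits on its own line, as in BAL files.
    values = [math.pi / 2, 0.0, 0.0,   # rotvec: 90 degrees about x
              0.1, 0.2, 0.3,           # translation
              500.0, -1e-7, 5e-13]     # f, k1, k2
    lines = ["1 1 1", "0 0 -332.65 262.09"] + [str(v) for v in values] + ["1.0", "2.0", "3.0"]
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("\n".join(lines) + "\n")

    prob = read_bal_data(f.name, use_quat=True)
    assert prob["camera_params"].shape == (1, 10)   # [t(3), q_xyzw(4), f, k1, k2]
    assert prob["points_3d"].shape == (1, 3) and prob["points_2d"].shape == (1, 2)

    q = _rotvec_to_quat_xyzw(torch.tensor([[math.pi / 2, 0.0, 0.0]], dtype=torch.float64))
    expected = [math.sin(math.pi / 4), 0.0, 0.0, math.cos(math.pi / 4)]
    assert torch.allclose(q, torch.tensor([expected], dtype=torch.float64))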

datapipes/bal_loader.py

Lines changed: 40 additions & 107 deletions

@@ -9,20 +9,24 @@
 Link to the dataset: https://grail.cs.washington.edu/projects/bal/
 """
 
-import torch, os, warnings
-import numpy as np
+import os
+import warnings
+
+import torch
 from functools import partial
-from operator import itemgetter, methodcaller
-from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
-from torchvision.transforms import Compose
-from scipy.spatial.transform import Rotation
-from torchdata.datapipes.iter import HttpReader, IterableWrapper, FileOpener
-import pypose as pp
+from operator import methodcaller
 
-DTYPE = torch.float64
+from .bal_io import DTYPE, read_bal_data
 
-# ignore bs4 warning
-warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
+def _torchdata():
+    try:
+        from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+    except ImportError as e:
+        raise ImportError(
+            "torchdata is required for datapipes.bal_loader streaming utilities. "
+            "If you only need parsing, import read_bal_data from datapipes.bal_io."
+        ) from e
+    return HttpReader, IterableWrapper, FileOpener
 
 # only export __all__
 __ALL__ = ['build_pipeline', 'read_bal_data', 'DATA_URL', 'ALL_DATASETS']

@@ -46,8 +50,22 @@ def _not_none(s):
 
 # extract problem file urls from the problem url
 def _problem_lister(*problem_url, cache_dir):
+    HttpReader, IterableWrapper, FileOpener = _torchdata()
+    try:
+        from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
+    except ImportError as e:
+        raise ImportError(
+            "bs4 is required for datapipes.bal_loader streaming utilities. "
+            "If you only need parsing, import read_bal_data from datapipes.bal_io."
+        ) from e
+
+    warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
+
+    def _cache_path(url: str) -> str:
+        return os.path.join(cache_dir, os.path.basename(url))
+
     problem_list_dp = IterableWrapper(problem_url).on_disk_cache(
-        filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]),
+        filepath_fn=_cache_path,
     )
     problem_list_dp = HttpReader(problem_list_dp).end_caching(same_filepath_fn=True)

@@ -69,113 +87,28 @@ def _problem_lister(*problem_url, cache_dir):
 
 # download and decompress the problem files
 def _download_pipe(cache_dir, url_dp, suffix: str):
+    HttpReader, _, _ = _torchdata()
+
+    def _cache_path(url: str) -> str:
+        return os.path.join(cache_dir, os.path.basename(url))
+
+    def _strip_suffix(path: str) -> str:
+        return path.split(suffix)[0]
+
     # cache compressed files
     cache_compressed = url_dp.on_disk_cache(
-        filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]),
+        filepath_fn=_cache_path,
     )
     cache_compressed = HttpReader(cache_compressed).end_caching(same_filepath_fn=True)
     # cache decompressed files
     cache_decompressed = cache_compressed.on_disk_cache(
-        filepath_fn=Compose([partial(str.split, sep=suffix), itemgetter(0)]),
+        filepath_fn=_strip_suffix,
     )
     cache_decompressed = cache_decompressed.open_files(mode="b").load_from_bz2().end_caching(
         same_filepath_fn=True
     )
     return cache_decompressed
 
-def read_bal_data(file_name: str, use_quat=False) -> dict:
-    """
-    Read a Bundle Adjustment in the Large dataset.
-
-    Referenced Scipy's BAL loader: https://scipy-cookbook.readthedocs.io/items/bundle_adjustment.html
-
-    According to the BAL official documentation, each problem is provided as a text file in the following format:
-
-        <num_cameras> <num_points> <num_observations>
-        <camera_index_1> <point_index_1> <x_1> <y_1>
-        ...
-        <camera_index_num_observations> <point_index_num_observations> <x_num_observations> <y_num_observations>
-        <camera_1>
-        ...
-        <camera_num_cameras>
-        <point_1>
-        ...
-        <point_num_points>
-
-    Where the camera and point indices start from 0. Each camera is a set of 9 parameters - R, t, f, k1 and k2. The rotation R is specified as a Rodrigues' vector.
-
-    Parameters
-    ----------
-    file_name : str
-        The decompressed file of the dataset.
-
-    Returns
-    -------
-    dict
-        A dictionary containing the following fields:
-        - problem_name: str
-            The name of the problem.
-        - camera_params: torch.Tensor (n_cameras, 9 or 10)
-            contains camera parameters for each camera. If use_quat is True, the shape is (n_cameras, 10).
-        - points_3d: torch.Tensor (n_points, 3)
-            contains initial estimates of point coordinates in the world frame.
-        - points_2d: torch.Tensor (n_observations, 2)
-            contains measured 2-D coordinates of points projected on images in each observation.
-        - camera_index_of_observations: torch.Tensor (n_observations,)
-            contains indices of cameras (from 0 to n_cameras - 1) involved in each observation.
-        - point_index_of_observations: torch.Tensor (n_observations,)
-            contains indices of points (from 0 to n_points - 1) involved in each observation.
-    """
-    with open(file_name, "r") as file:
-        n_cameras, n_points, n_observations = map(
-            int, file.readline().split())
-
-        camera_indices = torch.empty(n_observations, dtype=torch.int64)
-        point_indices = torch.empty(n_observations, dtype=torch.int64)
-        points_2d = torch.empty((n_observations, 2), dtype=DTYPE)
-
-        for i in range(n_observations):
-            tmp_line = file.readline()
-            camera_index, point_index, x, y = tmp_line.split()
-            camera_indices[i] = int(camera_index)
-            point_indices[i] = int(point_index)
-            points_2d[i, 0] = float(x)
-            points_2d[i, 1] = float(y)
-
-        camera_params = torch.empty(n_cameras * 9, dtype=DTYPE)
-        for i in range(n_cameras * 9):
-            camera_params[i] = float(file.readline())
-        camera_params = camera_params.reshape((n_cameras, -1))
-
-        points_3d = torch.empty(n_points * 3, dtype=DTYPE)
-        for i in range(n_points * 3):
-            points_3d[i] = float(file.readline())
-        points_3d = points_3d.reshape((n_points, -1))
-
-    if use_quat:
-        # convert Rodrigues vector to unit quaternion for camera rotation
-        # camera_params[0:3] is the Rodrigues vector
-        # after conversion, camera_params[0:4] is the unit quaternion
-        # r = Rotation.from_rotvec(camera_params[:, :3])
-        # q = r.as_quat()
-        r = pp.so3(camera_params[:, :3])
-        q = r.Exp()
-        # [tx, ty, tz, q0, q1, q2, q3, f, k1, k2]
-        camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], axis=1)
-    else:
-        camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], axis=1)
-
-    # convert camera_params to torch.Tensor
-    camera_params = torch.tensor(camera_params).to(DTYPE)
-
-    return {'problem_name': os.path.splitext(os.path.basename(file_name))[0], # str
-            'camera_params': camera_params, # torch.Tensor (n_cameras, 9 or 10)
-            'points_3d': points_3d, # torch.Tensor (n_points, 3)
-            'points_2d': points_2d, # torch.Tensor (n_observations, 2)
-            'camera_index_of_observations': camera_indices, # torch.Tensor (n_observations,)
-            'point_index_of_observations': point_indices, # torch.Tensor (n_observations,)
-            }
-
 def build_pipeline(dataset='ladybug', cache_dir='bal_data', use_quat=False):
     """
     Build a pipeline for the Bundle Adjustment in the Large dataset.
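
Two points of this refactor are worth spelling out. First, the removed torchvision Compose chains only composed two ordinary functions, so a local closure is an exact replacement. Second, because torchdata and bs4 now load lazily inside the helpers, importing datapipes.bal_io pulls in nothing beyond torch. A short sketch under those assumptions (the URL and file name are illustrative, not taken from the commit):

    # Illustration of the filepath_fn refactor (sketch, not part of the commit).
    import os

    cache_dir = "bal_data"

    def _cache_path(url: str) -> str:
        # equivalent to Compose([os.path.basename, partial(os.path.join, cache_dir)]):
        # basename first, then join under the cache directory
        return os.path.join(cache_dir, os.path.basename(url))

    url = "https://grail.cs.washington.edu/projects/bal/data/ladybug/problem-49-7776-pre.txt.bz2"
    local = _cache_path(url)
    assert local == os.path.join("bal_data", "problem-49-7776-pre.txt.bz2")
    assert local.split(".bz2")[0].endswith(".txt")  # what _strip_suffix does for suffix=".bz2"

    # Parsing needs only torch: this import succeeds without torchdata/bs4 installed,
    # while the streaming helpers raise a clear ImportError only when actually called.
    from datapipes.bal_io import read_bal_data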
