 Link to the dataset: https://grail.cs.washington.edu/projects/bal/
 """
 
-import torch, os, warnings
-import numpy as np
+import os
+import warnings
+
+import torch
 from functools import partial
-from operator import itemgetter, methodcaller
-from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
-from torchvision.transforms import Compose
-from scipy.spatial.transform import Rotation
-from torchdata.datapipes.iter import HttpReader, IterableWrapper, FileOpener
-import pypose as pp
+from operator import methodcaller
 
-DTYPE = torch.float64
+from .bal_io import DTYPE, read_bal_data
 
-# ignore bs4 warning
-warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
+def _torchdata():
+    try:
+        from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+    except ImportError as e:
+        raise ImportError(
+            "torchdata is required for datapipes.bal_loader streaming utilities. "
+            "If you only need parsing, import read_bal_data from datapipes.bal_io."
+        ) from e
+    return HttpReader, IterableWrapper, FileOpener
 
 # only export __all__
 __ALL__ = ['build_pipeline', 'read_bal_data', 'DATA_URL', 'ALL_DATASETS']
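With this change, importing the module no longer pulls in torchdata, torchvision, scipy, or bs4: the optional imports are deferred into `_torchdata()` and `_problem_lister()`, and `read_bal_data` now lives in `datapipes.bal_io`. A minimal sketch of the two entry points this enables (the file name is illustrative):

    # parsing only, no torchdata installed:
    from datapipes.bal_io import read_bal_data
    problem = read_bal_data('problem-49-7776-pre.txt')

    # streaming utilities resolve the optional dependency on first use:
    HttpReader, IterableWrapper, FileOpener = _torchdata()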
@@ -46,8 +50,22 @@ def _not_none(s):
 
 # extract problem file urls from the problem url
 def _problem_lister(*problem_url, cache_dir):
+    HttpReader, IterableWrapper, FileOpener = _torchdata()
+    try:
+        from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
+    except ImportError as e:
+        raise ImportError(
+            "bs4 is required for datapipes.bal_loader streaming utilities. "
+            "If you only need parsing, import read_bal_data from datapipes.bal_io."
+        ) from e
+
+    warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)
+
+    def _cache_path(url: str) -> str:
+        return os.path.join(cache_dir, os.path.basename(url))
+
     problem_list_dp = IterableWrapper(problem_url).on_disk_cache(
-        filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]),
+        filepath_fn=_cache_path,
     )
     problem_list_dp = HttpReader(problem_list_dp).end_caching(same_filepath_fn=True)
 
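The old `Compose([os.path.basename, partial(os.path.join, cache_dir)])` pressed torchvision's transform container into service as function composition; `_cache_path` expresses the same URL-to-cache-path mapping directly. Roughly (URL illustrative):

    cache_dir = 'bal_data'
    url = 'https://grail.cs.washington.edu/projects/bal/data/ladybug/problem-49-7776-pre.txt.bz2'
    _cache_path(url)  # == os.path.join(cache_dir, os.path.basename(url))
                      # -> 'bal_data/problem-49-7776-pre.txt.bz2'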
@@ -69,113 +87,28 @@ def _problem_lister(*problem_url, cache_dir):
 
 # download and decompress the problem files
 def _download_pipe(cache_dir, url_dp, suffix: str):
+    HttpReader, _, _ = _torchdata()
+
+    def _cache_path(url: str) -> str:
+        return os.path.join(cache_dir, os.path.basename(url))
+
+    def _strip_suffix(path: str) -> str:
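+        # e.g. with suffix='.bz2': 'problem-49-7776-pre.txt.bz2' -> 'problem-49-7776-pre.txt'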
+        return path.split(suffix)[0]
+
     # cache compressed files
     cache_compressed = url_dp.on_disk_cache(
-        filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]),
+        filepath_fn=_cache_path,
     )
     cache_compressed = HttpReader(cache_compressed).end_caching(same_filepath_fn=True)
     # cache decompressed files
     cache_decompressed = cache_compressed.on_disk_cache(
-        filepath_fn=Compose([partial(str.split, sep=suffix), itemgetter(0)]),
+        filepath_fn=_strip_suffix,
     )
     cache_decompressed = cache_decompressed.open_files(mode="b").load_from_bz2().end_caching(
         same_filepath_fn=True
     )
     return cache_decompressed
 
-def read_bal_data(file_name: str, use_quat=False) -> dict:
-    """
-    Read a Bundle Adjustment in the Large dataset.
-
-    Referenced Scipy's BAL loader: https://scipy-cookbook.readthedocs.io/items/bundle_adjustment.html
-
-    According to the BAL official documentation, each problem is provided as a text file in the following format:
-
-        <num_cameras> <num_points> <num_observations>
-        <camera_index_1> <point_index_1> <x_1> <y_1>
-        ...
-        <camera_index_num_observations> <point_index_num_observations> <x_num_observations> <y_num_observations>
-        <camera_1>
-        ...
-        <camera_num_cameras>
-        <point_1>
-        ...
-        <point_num_points>
-
-    Camera and point indices start from 0. Each camera is a set of 9 parameters: R, t, f, k1, and k2. The rotation R is specified as a Rodrigues' vector.
-
-    Parameters
-    ----------
-    file_name : str
-        The decompressed file of the dataset.
-
-    Returns
-    -------
-    dict
-        A dictionary containing the following fields:
-        - problem_name: str
-            The name of the problem.
-        - camera_params: torch.Tensor (n_cameras, 9 or 10)
-            Contains camera parameters for each camera. If use_quat is True, the shape is (n_cameras, 10).
-        - points_3d: torch.Tensor (n_points, 3)
-            Contains initial estimates of point coordinates in the world frame.
-        - points_2d: torch.Tensor (n_observations, 2)
-            Contains measured 2-D coordinates of points projected on images in each observation.
-        - camera_index_of_observations: torch.Tensor (n_observations,)
-            Contains indices of cameras (from 0 to n_cameras - 1) involved in each observation.
-        - point_index_of_observations: torch.Tensor (n_observations,)
-            Contains indices of points (from 0 to n_points - 1) involved in each observation.
-    """
-    with open(file_name, "r") as file:
-        n_cameras, n_points, n_observations = map(
-            int, file.readline().split())
-
-        camera_indices = torch.empty(n_observations, dtype=torch.int64)
-        point_indices = torch.empty(n_observations, dtype=torch.int64)
-        points_2d = torch.empty((n_observations, 2), dtype=DTYPE)
-
-        for i in range(n_observations):
-            camera_index, point_index, x, y = file.readline().split()
-            camera_indices[i] = int(camera_index)
-            point_indices[i] = int(point_index)
-            points_2d[i, 0] = float(x)
-            points_2d[i, 1] = float(y)
-
-        camera_params = torch.empty(n_cameras * 9, dtype=DTYPE)
-        for i in range(n_cameras * 9):
-            camera_params[i] = float(file.readline())
-        camera_params = camera_params.reshape((n_cameras, -1))
-
-        points_3d = torch.empty(n_points * 3, dtype=DTYPE)
-        for i in range(n_points * 3):
-            points_3d[i] = float(file.readline())
-        points_3d = points_3d.reshape((n_points, -1))
-
-    if use_quat:
-        # convert the Rodrigues vector (camera_params[:, :3]) to a unit
-        # quaternion; after the concatenation below, columns 3:7 hold it
-        r = pp.so3(camera_params[:, :3])
-        q = r.Exp()
-        # [tx, ty, tz, qx, qy, qz, qw, f, k1, k2]
-        camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], dim=1)
-    else:
-        # [tx, ty, tz, rx, ry, rz, f, k1, k2]
-        camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], dim=1)
-
-    camera_params = camera_params.to(DTYPE)
-
-    return {'problem_name': os.path.splitext(os.path.basename(file_name))[0],  # str
-            'camera_params': camera_params,  # torch.Tensor (n_cameras, 9 or 10)
-            'points_3d': points_3d,  # torch.Tensor (n_points, 3)
-            'points_2d': points_2d,  # torch.Tensor (n_observations, 2)
-            'camera_index_of_observations': camera_indices,  # torch.Tensor (n_observations,)
-            'point_index_of_observations': point_indices,  # torch.Tensor (n_observations,)
-            }
-
 def build_pipeline(dataset='ladybug', cache_dir='bal_data', use_quat=False):
     """
     Build a pipeline for the Bundle Adjustment in the Large dataset.