Skip to content

Commit 6493a79

Browse files
authored
Update dependencies and improve documentation (#34)
* Update dependencies and improve documentation - Updated flake8 to version 5.0.4 in .pre-commit-config.yaml. - Added gdown to requirements.txt for model downloading. - Changed dataset download link in get_started.md and conftest.py to use Google Drive. - Refactored model download logic in bmnet.py to use gdown. - Introduced gdown_download_model function in miscs.py for consistent model downloading. * Update lint workflow to use Ubuntu 22.04 and remove unnecessary Ruby installation steps
1 parent 76b621e commit 6493a79

File tree

8 files changed

+74
-57
lines changed

8 files changed

+74
-57
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ concurrency:
88

99
jobs:
1010
lint:
11-
runs-on: ubuntu-18.04
11+
runs-on: ubuntu-22.04
1212
steps:
1313
- uses: actions/checkout@v2
1414
- name: Set up Python 3.7
@@ -17,9 +17,6 @@ jobs:
1717
python-version: 3.7
1818
- name: Install pre-commit hook
1919
run: |
20-
sudo apt-add-repository ppa:brightbox/ruby-ng -y
21-
sudo apt-get update
22-
sudo apt-get install -y ruby2.7
2320
pip install pre-commit
2421
pre-commit install
2522
- name: Linting

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ repos:
44
rev: v14.0.1
55
hooks:
66
- id: clang-format
7-
- repo: https://gitlab.com/pycqa/flake8.git
8-
rev: 3.8.3
7+
- repo: https://github.com/PyCQA/flake8
8+
rev: 5.0.4
99
hooks:
1010
- id: flake8
1111
- repo: https://github.com/asottile/seed-isort-config.git

docs/en/get_started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
XRLocalization provides a flexible tool that can easily perform visual localization offline and online.
33
Given a query image, XRLocalization estimates a 6DoF pose from a pre-reconstructed map. A tiny dataset
44
is provided for convenience. Download the dataset from
5-
[here](https://openxrlab-share-mainland.oss-cn-hangzhou.aliyuncs.com/xrlocalization/meta/xrloc-test-meta.tar.gz). The
5+
[here](https://drive.google.com/file/d/1vKfCDWtZ1ui5t5sYjlF_EAqj1mVM1zk-/view?usp=sharing). The
66
folder extracted from the dataset is shown below.
77
```commandline
88
├── map

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ scipy>=1.6.1
99
setuptools>=46.1.3
1010
torch>=1.1.0
1111
tqdm
12-
xrprimer==0.5.2
12+
xrprimer==0.5.2
13+
gdown

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
def fixture():
1111
if os.path.exists(dataset_dir):
1212
shutil.rmtree(dataset_dir)
13-
url = 'https://openxrlab-share-mainland.oss-cn-hangzhou.aliyuncs.com/xrlocalization/meta/xrloc-test-meta.tar.gz'
14-
command = ['wget', '--no-check-certificate', url]
13+
url = 'https://docs.google.com/uc?id=1vKfCDWtZ1ui5t5sYjlF_EAqj1mVM1zk-'
14+
command = ['gdown', url]
1515
subprocess.run(command, check=True)
1616
command = ['tar', '-xf', 'xrloc-test-meta.tar.gz']
1717
subprocess.run(command, check=True)

xrloc/localizer.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,23 @@ def __init__(self, map_path, config=default_config):
3737
head_logging('XRLocalization')
3838
self.config = config
3939
database_path = os.path.join(map_path, 'database.bin')
40-
40+
4141
if os.path.exists(database_path):
4242
self.database = ImageDatabase(database_path)
4343
self.database.create()
4444
else:
45-
pairs = [name for name in os.listdir(map_path)
46-
if name.startswith('pairs-query')]
45+
pairs = [
46+
name for name in os.listdir(map_path)
47+
if name.startswith('pairs-query')
48+
]
4749
if len(pairs) == 0:
4850
raise ValueError(
4951
'Not found database under map: {}'.format(map_path))
5052
else:
5153
self.pairs = PairsDatabase(os.path.join(map_path, pairs[0]))
5254

5355
self.reconstruction = Reconstruction(map_path)
54-
56+
5557
if hasattr(self, 'database'):
5658
self.gextractor = Extractor(self.config['global_feature'])
5759
self.lextractor = Extractor(self.config['local_feature'])
@@ -61,8 +63,9 @@ def __init__(self, map_path, config=default_config):
6163
config_logging(self.config)
6264

6365
if self.config['mode'] == '2D2D' and self.config['matcher'] == 'gam':
64-
raise ValueError('Loc mode {} is not compatible with matcher {}'.format(
65-
self.config['mode'], self.config['matcher']))
66+
raise ValueError(
67+
'Loc mode {} is not compatible with matcher {}'.format(
68+
self.config['mode'], self.config['matcher']))
6669

6770
head_logging('Init Success')
6871

@@ -91,15 +94,19 @@ def geo_localize(self, data):
9194
if hasattr(self, 'database'):
9295
image_feature = self.gextractor.extract(data)
9396
image_ids = self.database.retrieve(image_feature,
94-
self.config['retrieval_num'])
97+
self.config['retrieval_num'])
9598
return image_ids
9699
elif hasattr(self, 'pairs') and isinstance(data, str):
97-
image_names = self.pairs.image_retrieve(data, self.config['retrieval_num'])
98-
image_ids = [self.reconstruction.name_to_id(name) for name in image_names]
100+
image_names = self.pairs.image_retrieve(
101+
data, self.config['retrieval_num'])
102+
image_ids = [
103+
self.reconstruction.name_to_id(name) for name in image_names
104+
]
99105
return np.array([id for id in image_ids if id != -1])
100106

101107
def feature_match_2d3d(self, query_points, query_point_descriptors,
102-
train_points, train_point_descriptors, width, height):
108+
train_points, train_point_descriptors, width,
109+
height):
103110
"""Feature matching phase."""
104111
query_feat = {
105112
'shape': np.array([height, width]),
@@ -114,58 +121,57 @@ def feature_match_2d3d(self, query_points, query_point_descriptors,
114121
return pred['matches'], pred['scores']
115122

116123
def establish_correspondences_2d2d(self, query_feat, image_ids):
117-
'''Establish 2D-3D correspondences depend on 2D2D matching
118-
'''
124+
"""Establish 2D-3D correspondences depend on 2D2D matching."""
119125
logging.info('Scene size: {0}'.format(len(image_ids)))
120-
match_indices = np.ones(len(query_feat['points']), dtype=int)*-1
126+
match_indices = np.ones(len(query_feat['points']), dtype=int) * -1
121127
match_priors = np.zeros(len(query_feat['points']))
122128
for image_id in image_ids:
123129
ref_image = self.reconstruction.image_at(image_id)
124130
ref_feat = {
125-
'points': ref_image.xys,
126-
'descs': self.reconstruction.point3d_features(ref_image.point3D_ids),
127-
'scores': np.ones(len(ref_image.xys)),
128-
'shape': np.array([600, 600]) # TODO
131+
'points':
132+
ref_image.xys,
133+
'descs':
134+
self.reconstruction.point3d_features(ref_image.point3D_ids),
135+
'scores':
136+
np.ones(len(ref_image.xys)),
137+
'shape':
138+
np.array([600, 600]) # TODO
129139
}
130140
pred = self.matcher.match(query_feat, ref_feat)
131141
matches, scores = pred['matches'], pred['scores']
132-
142+
133143
reserve_matches = matches[:, scores > match_priors[matches[0]]]
134144
reserve_scores = scores[scores > match_priors[matches[0]]]
135145
if len(reserve_scores) > 0:
136-
match_indices[reserve_matches[0]] = ref_image.point3D_ids[reserve_matches[1]]
146+
match_indices[reserve_matches[0]] = ref_image.point3D_ids[
147+
reserve_matches[1]]
137148
match_priors[reserve_matches[0]] = reserve_scores
138149

139150
if len(match_priors[match_indices != -1]) > 400:
140151
break
141152

142153
point3d_ids = match_indices[match_indices != -1]
143-
points3d = self.reconstruction.point3d_coordinates(
144-
point3d_ids)
154+
points3d = self.reconstruction.point3d_coordinates(point3d_ids)
145155
points2d = query_feat['points'][match_indices != -1]
146156
priors = match_priors[match_indices != -1]
147157
logging.info('Match number: {0}'.format(len(priors)))
148158
return points2d, points3d, priors
149159

150-
151160
def establish_correspondences_2d3d(self, feat, image_ids):
152-
'''Establish 2D-3D correspondences depend on 2D3D matching
153-
'''
161+
"""Establish 2D-3D correspondences depend on 2D3D matching."""
154162
point3d_ids = self.reconstruction.visible_points(image_ids)
155-
point3ds = self.reconstruction.point3d_coordinates(
156-
point3d_ids)
157-
point3d_descs = self.reconstruction.point3d_features(
158-
point3d_ids)
159-
logging.info('3d points size: {0}'.format(
160-
point3d_descs.shape[1]))
163+
point3ds = self.reconstruction.point3d_coordinates(point3d_ids)
164+
point3d_descs = self.reconstruction.point3d_features(point3d_ids)
165+
logging.info('3d points size: {0}'.format(point3d_descs.shape[1]))
161166

162167
# Matching
163-
matches, priors = self.feature_match_2d3d(feat['points'],
164-
feat['descs'],
165-
point3ds,
166-
point3d_descs,
167-
feat['shape'][1], # width
168-
feat['shape'][0])
168+
matches, priors = self.feature_match_2d3d(
169+
feat['points'],
170+
feat['descs'],
171+
point3ds,
172+
point3d_descs,
173+
feat['shape'][1], # width
174+
feat['shape'][0])
169175
logging.info('Match number: {0}'.format(matches.shape[1]))
170176
points2d = feat['points'][matches[0]]
171177
points3d = point3ds[matches[1]]
@@ -202,19 +208,22 @@ def refine_localize(self, image, camera, ref_image_ids):
202208
logging.info('Coarse location number: {0}'.format(len(scenes)))
203209

204210
best_ret = {
205-
'ninlier': 0, 'qvec': np.array([1, 0, 0, 0]),
206-
'tvec': np.array([0, 0, 0]), 'mask': None
211+
'ninlier': 0,
212+
'qvec': np.array([1, 0, 0, 0]),
213+
'tvec': np.array([0, 0, 0]),
214+
'mask': None
207215
}
208216
for i, image_ids in enumerate(scenes[:self.config['max_scene_num']]):
209217
# Establish 2D-3D correspondences
210218
if self.config['mode'] == '2D3D':
211-
points2d, points3d, priors = self.establish_correspondences_2d3d(
212-
feat, image_ids)
219+
points2d, points3d, priors = \
220+
self.establish_correspondences_2d3d(feat, image_ids)
213221
elif self.config['mode'] == '2D2D':
214-
points2d, points3d, priors = self.establish_correspondences_2d2d(
215-
feat, image_ids)
222+
points2d, points3d, priors = \
223+
self.establish_correspondences_2d2d(feat, image_ids)
216224

217-
if len(priors) < 3: continue
225+
if len(priors) < 3:
226+
continue
218227

219228
# Pose estimation
220229
ret = self.prior_guided_pose_estimation(points2d, points3d, priors,
@@ -226,4 +235,4 @@ def refine_localize(self, image, camera, ref_image_ids):
226235
best_ret = ret
227236
if best_ret['ninlier'] > self.config['max_inlier']:
228237
break
229-
return best_ret
238+
return best_ret

xrloc/matchers/bmnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch.nn as nn
55

6-
from xrloc.utils.miscs import download_model, get_parent_dir
6+
from xrloc.utils.miscs import gdown_download_model, get_parent_dir
77
from scipy.optimize import linear_sum_assignment
88

99

@@ -135,11 +135,11 @@ def __init__(self, config=default_config):
135135
# 'hpooling'] else None
136136
self.conv = nn.Conv1d(self.config['channels2'][-1], 1, kernel_size=1)
137137

138-
url = 'https://openxrlab-share-mainland.oss-cn-hangzhou.aliyuncs.com/xrlocalization/weights/bmnet.pth'
138+
url = 'https://docs.google.com/uc?id=1RqkkGc5WomkP7aDgLbCXLqlc8jcwdlcN'
139139
model_dir = get_parent_dir(__file__) + '/../models/'
140140

141141
model_name = 'bmnet.pth'
142-
download_model(url, model_dir, model_name)
142+
gdown_download_model(url, model_dir, model_name)
143143
model_path = os.path.join(model_dir, model_name)
144144

145145
self.load_state_dict(

xrloc/utils/miscs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ def download_model(url, model_dir, model_name):
107107
subprocess.run(command, check=True)
108108

109109

110+
def gdown_download_model(url, model_dir, model_name):
111+
model_path = os.path.join(model_dir, model_name)
112+
if os.path.exists(model_path):
113+
return
114+
os.makedirs(model_dir, exist_ok=True)
115+
print('Downloading the {} model from {}.'.format(model_name, url))
116+
command = ['gdown', url, '-O', model_path]
117+
subprocess.run(command, check=True)
118+
119+
110120
def head_logging(info: str, width=50):
111121
logging.info('=' * width)
112122
left = '*' * int((width - len(info) - 2) / 2) + ' '

0 commit comments

Comments
 (0)