
Commit 43c41bf
committed: lightgluestick train pipeline & small bugs & README
1 parent: a70bb7a

8 files changed

README.md

Lines changed: 82 additions & 0 deletions
@@ -104,6 +104,30 @@ Since we use points and lines to solve for the homography, we use a different ro
 
 </details>
 
+<details>
+<summary>[Evaluating LightGlueStick]</summary>
+
+To evaluate LightGlueStick on HPatches, run:
+```bash
+python -m gluefactory.eval.hpatches --conf superpoint+lsd+lightgluestick-official --overwrite
+```
+You should expect the following results:
+```
+{'H_error_dlt@1px': 0.3725,
+ 'H_error_dlt@3px': 0.6803,
+ 'H_error_dlt@5px': 0.7806,
+ 'H_error_ransac@1px': 0.3907,
+ 'H_error_ransac@3px': 0.6973,
+ 'H_error_ransac@5px': 0.7947,
+ 'H_error_ransac_mAA': 0.6275666666666667,
+ 'mH_error_dlt': nan,
+ 'mH_error_ransac': 0.606,
+ 'mnum_keypoints': 2287.25,
+ 'mnum_matches': 1117.0,
+ 'mprec@1px': 0.281,
+ 'mprec@3px': 0.936}
+```
+</details>
 
 #### MegaDepth-1500
 
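As a sanity check on the numbers in the added README block, the reported `H_error_ransac_mAA` is simply the mean of the three per-threshold RANSAC accuracies:

```python
# Mean Average Accuracy (mAA) over the 1/3/5 px RANSAC homography thresholds
# reported in the snippet above.
acc = [0.3907, 0.6973, 0.7947]
mAA = sum(acc) / len(acc)
# matches the README's 'H_error_ransac_mAA' value of roughly 0.62757
```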
@@ -153,6 +177,19 @@ python -m gluefactory.eval.megadepth1500 --conf gluefactory/configs/superpoint+l
 
 </details>
 
+<details>
+<summary>[Evaluating LightGlueStick]</summary>
+
+To evaluate the pre-trained SuperPoint+LightGlueStick model on MegaDepth-1500, run:
+```bash
+python -m gluefactory.eval.megadepth1500 --conf superpoint+lsd+lightgluestick-official
+# or the adaptive variant
+python -m gluefactory.eval.megadepth1500 --conf superpoint+lsd+lightgluestick-official \
+    model.matcher.depth_confidence=0.95
+```
+
+</details>
+
 <details>
 
 Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees:
@@ -214,6 +251,21 @@ AP_lines: 69.22
 
 </details>
 
+<details>
+<summary>[Evaluating LightGlueStick]</summary>
+
+To evaluate LightGlueStick on ETH3D, run:
+```bash
+python -m gluefactory.eval.eth3d --conf superpoint+lsd+lightgluestick-official
+```
+You should expect the following results:
+```
+AP: 78.13
+AP_lines: 74.62
+```
+
+</details>
+
 #### Image Matching Challenge 2021
 Coming soon!
 
@@ -308,16 +360,46 @@ We then fine-tune the model on the MegaDepth dataset:
 ```bash
 python -m gluefactory.train gluestick_MD --conf gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml --distributed
 ```
+
 Note that we used the training splits `train_scenes.txt` and `valid_scenes.txt` to train the original model, which contain some overlap with the IMC challenge. The new default splits are now `train_scenes_clean.txt` and `valid_scenes_clean.txt`, without this overlap.
 
 </details>
 
+<details>
+<summary>[Training LightGlueStick]</summary>
+
+We first pre-train LightGlueStick on the homography dataset:
+```bash
+python -m gluefactory.train lightgluestick_H --conf gluefactory/configs/superpoint+lsd+lightgluestick_homography.yaml --distributed
+```
+Feel free to use any other experiment name. Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
+
+We then fine-tune the model on the MegaDepth dataset:
+```bash
+python -m gluefactory.train lightgluestick_MD --conf gluefactory/configs/superpoint+lsd+lightgluestick_megadepth.yaml --distributed
+```
+
+To speed up training on MegaDepth, we suggest caching the local features before training:
+
+```bash
+# extract features
+python -m gluefactory.scripts.export_megadepth --method sp_lsd_wireframe --num_workers 8
+# run training with cached features (change data.load_features.path depending on the export parameters). We cache 1500 keypoints and 512 lines.
+python -m gluefactory.train lightgluestick_MD \
+    --conf gluefactory/configs/superpoint+lsd+lightgluestick_megadepth.yaml \
+    train.load_experiment=lightgluestick_H \
+    data.load_features.do=True
+```
+
+</details>
+
 ### Available models
 Glue Factory supports training and evaluating the following deep matchers:
 | Model | Training? | Evaluation? |
 | --------- | --------- | ----------- |
 | [LightGlue](https://github.com/cvg/LightGlue) |||
 | [GlueStick](https://github.com/cvg/GlueStick) |||
+| [LightGlueStick](https://github.com/aubingazhib/LightGlueStick) |||
 | [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) |||
 | [LoFTR](https://github.com/zju3dv/LoFTR) |||
 
gluefactory/configs/superpoint+lsd+lightgluestick.yaml renamed to gluefactory/configs/superpoint+lsd+lightgluestick-official.yaml

Lines changed: 2 additions & 25 deletions
@@ -1,5 +1,5 @@
 model:
-  name: gluefactory.models.two_view_pipeline
+  name: two_view_pipeline
   extractor:
     name: gluefactory.models.lines.wireframe
     point_extractor:
@@ -18,30 +18,20 @@ model:
       merge_line_endpoints: True
       nms_radius: 3
   matcher:
-    name: gluefactory.models.matchers.lightgluestick
+    name: gluefactory.models.matchers.lightgluestick_pretrained
     depth_confidence: -1
     width_confidence: -1
     filter_threshold: 0.1
-    line_threshold: 3
-    tau: 3
-    method: "mean"
-    weights: superpoint # This will download weights from internet
 
   # ground_truth: # for ETH3D, comment otherwise
   #   name: gluefactory.models.matchers.depth_matcher
   #   use_lines: True
-
 benchmarks:
   hpatches:
     eval:
       use_lines: True
       estimator: homography_est
       ransac_th: -1 # [1., 1.5, 2., 2.5, 3.]
-  scannet:
-    eval:
-      use_lines: True
-      estimator: homography_est
-      ransac_th: -1
   megadepth1500:
     data:
       preprocessing:
@@ -50,19 +40,6 @@ benchmarks:
     eval:
       estimator: poselib
       ransac_th: -1
-  megadepth1500_match_eval:
-    data:
-      preprocessing:
-        side: long
-        resize: 1600
-    model:
-      ground_truth:
-        name: gluefactory.models.matchers.depth_matcher
-        use_lines: True
-    eval:
-      eval_lines: True
-      estimator: poselib
-      ransac_th: -1
   eth3d:
     model:
       ground_truth:

gluefactory/configs/superpoint+lsd+lightgluestick_homography.yaml

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ model:
     th_negative: 5
   matcher:
     name: gluefactory.models.matchers.lightgluestick
-    weights: superpoint
     input_dim: 256
     descriptor_dim: 256
     flash: false

gluefactory/eval/eth3d.py

Lines changed: 2 additions & 4 deletions
@@ -97,12 +97,10 @@ def get_predictions(self, experiment_dir, model=None, overwrite=False):
         return pred_file
 
     def run_eval(self, loader, pred_file):
-        eval_conf = self.conf.eval
         r = eval_dataset(loader, pred_file)
         if self.conf.eval.eval_lines:
-            r.update(eval_dataset(loader, pred_file, conf=eval_conf, suffix="_lines"))
+            r.update(eval_dataset(loader, pred_file, suffix="_lines"))
         s = {}
-
         return s, {}, r
 
 
@@ -199,4 +197,4 @@ def plot_pr_curve(
         results,
         dst_file="eth3d_pr_curve_lines.pdf",
         suffix="_lines",
-    )
+    )

gluefactory/models/cache_loader.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def _forward(self, data):
         pred = batch_to_device(pred, device)
         for k, v in pred.items():
             for pattern in self.conf.scale:
-                if k.startswith(pattern):
+                if k.startswith(pattern) and not k.startswith("lines_junc"):
                     view_idx = k.replace(pattern, "")
                     scales = (
                         data["scales"]
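The guard added in this hunk matters because `lines_junc_*` entries hold integer indices into the keypoint array, not pixel coordinates, so rescaling them along with the coordinates would corrupt them. A plain-Python sketch with simplified shapes (the pattern list and scale factor are hypothetical stand-ins for `self.conf.scale` and `data["scales"]`):

```python
# Cached predictions: coordinates must be rescaled to the resized image,
# but junction indices must not be touched.
pred = {
    "keypoints0": [[10.0, 20.0], [30.0, 40.0]],  # pixel coordinates
    "lines_junc_idx0": [0, 1],                   # indices into keypoints0
}
scale = 2.0  # hypothetical resize factor

for k, v in pred.items():
    for pattern in ("keypoints", "lines"):  # stand-in for self.conf.scale
        if k.startswith(pattern) and not k.startswith("lines_junc"):
            pred[k] = [[c * scale for c in pt] for pt in v]
```

Without the `lines_junc` exclusion, `"lines_junc_idx0"` would match the `"lines"` pattern and its indices would be doubled.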

gluefactory/models/matchers/lightgluestick.py

Lines changed: 34 additions & 22 deletions
@@ -44,9 +44,9 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
 def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
     return (t * freqs[0]) + (rotate_half(t) * freqs[1])
 
-def create_mask(lines_junc_idx, num_nodes):
+def create_mask(lines_junc_idx):
     # Get batch size and number of connections
-    bs = lines_junc_idx.shape[0]
+    bs, num_nodes = lines_junc_idx.shape
     # Create an empty mask
     mask = torch.eye(num_nodes, dtype=torch.float32).unsqueeze(0).repeat(bs, 1, 1)
 
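The diff shows only the head of `create_mask`: the fix infers `num_nodes` from the shape of `lines_junc_idx` instead of taking it as an argument. The rest of the body is not shown in this commit, so the following plain-Python sketch of the connectivity it plausibly builds is an assumption, not the committed code: start from an identity mask and additionally connect the two endpoint junctions of each line, assuming consecutive entries of `lines_junc_idx` are endpoint pairs.

```python
# Hypothetical single-batch sketch: two lines with endpoint pairs (0,1) and (2,3).
lines_junc_idx = [0, 1, 2, 3]
num_nodes = len(lines_junc_idx)  # inferred from the shape, as in the fix

# Identity mask: every junction attends to itself (torch.eye in the real code).
mask = [[float(i == j) for j in range(num_nodes)] for i in range(num_nodes)]

# Connect the two endpoints of each line (assumed pairing convention).
for a, b in zip(lines_junc_idx[0::2], lines_junc_idx[1::2]):
    mask[a][b] = mask[b][a] = 1.0
```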
@@ -196,6 +196,7 @@ def forward(
         self,
         x: torch.Tensor,
         encoding: torch.Tensor,
+        mask_ffn: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
 
     ) -> torch.Tensor:
@@ -207,7 +208,7 @@ def forward(
         context = self.inner_attn(q, k, v, mask=mask)
         message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
 
-        return x + self.ffn(torch.cat([x, message], -1))
+        return x + self.ffn(torch.cat([x, message], -1)) * mask_ffn.unsqueeze(-1)
 
 class CrossBlock(nn.Module):
     def __init__(
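The `mask_ffn` tensor threaded through the block above acts as a per-slot validity mask: multiplying the FFN output by `mask_ffn.unsqueeze(-1)` zeroes the residual update on padded junction slots, so padding never drifts away from its input value. A plain-Python sketch with toy scalar features (names taken from the diff, sizes made up):

```python
# One scalar feature per junction slot; the last two slots are padding.
x = [1.0, 2.0, 3.0, 4.0]          # current features
update = [0.5, 0.5, 0.5, 0.5]     # hypothetical FFN output
mask_ffn = [True, True, False, False]

# x + ffn(...) * mask_ffn: padded slots keep their old value unchanged.
out = [xi + ui * mi for xi, ui, mi in zip(x, update, mask_ffn)]
```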
@@ -280,6 +281,8 @@ forward(
         desc1,
         encoding0,
         encoding1,
+        mask_ffn0,
+        mask_ffn1,
         mask0: Optional[torch.Tensor] = None,
         mask1: Optional[torch.Tensor] = None,
     ):
@@ -290,10 +293,9 @@ forward(
         n_endpoints1 = mask1.shape[-1]
 
         desc0[:, : n_endpoints0, :] = self.line_layer(desc0[:, : n_endpoints0, :], \
-            encoding0[:, :, :, : n_endpoints0, :], mask0)
+            encoding0[:, :, :, : n_endpoints0, :], mask_ffn0, mask0)
         desc1[:, : n_endpoints1, :] = self.line_layer(desc1[:, : n_endpoints1, :], \
-            encoding1[:, :, :, : n_endpoints1, :], mask1)
-
+            encoding1[:, :, :, : n_endpoints1, :], mask_ffn1, mask1)
         return self.cross_attn(desc0, desc1)
 
 
@@ -427,7 +429,7 @@ class LightGlueStick(BaseModel):
         "mp": False,  # enable mixed precision
         "depth_confidence": -1,  # early stopping, disable with -1
         "width_confidence": -1,  # point pruning, disable with -1
-        "filter_threshold": 0.1,  # match threshold
+        "filter_threshold": 0.0,  # match threshold
         "checkpointed": False,
         "weights": None,  # either a path or the name of pretrained weights (disk, ...)
         "keypoint_encoder": [32, 64, 128, 256],
@@ -483,10 +485,10 @@ def _init(self, conf) -> None:
         )
 
         self.loss_fn = NLLLoss(conf.loss)
-        self.i = 0
 
         state_dict = None
         if conf.weights is not None:
+            # weights can be either a path or an existing file from official LG
            if Path(conf.weights).exists():
                 state_dict = torch.load(conf.weights, map_location="cpu")
             elif (Path(DATA_PATH) / conf.weights).exists():
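The commented hunk in `_init` resolves `conf.weights` in order: an existing absolute/relative path first, then a path under the data root. A hedged stand-alone sketch of that resolution order (the `DATA_PATH` value and the `resolve_weights` helper are hypothetical; the real code falls through to downloaded weights):

```python
from pathlib import Path

DATA_PATH = Path("/tmp/gluefactory_data")  # hypothetical data root

def resolve_weights(weights: str):
    # 1) try the string as a filesystem path
    if Path(weights).exists():
        return Path(weights)
    # 2) try it relative to the data root
    if (DATA_PATH / weights).exists():
        return DATA_PATH / weights
    # 3) neither exists: the real code would fetch named pretrained weights
    return None
```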
@@ -629,6 +631,8 @@ def _forward(self, data: dict) -> dict:
         do_early_stop = self.conf.depth_confidence > 0 and not self.training
         do_point_pruning = self.conf.width_confidence > 0 and not self.training
 
+        all_desc0, all_desc1 = [], []
+
         if do_point_pruning:
             ind0 = torch.arange(0, m, device=device)[None]
             ind1 = torch.arange(0, n, device=device)[None]
@@ -637,18 +641,30 @@ def _forward(self, data: dict) -> dict:
             prune1 = torch.ones_like(ind1)
         token0, token1 = None, None
 
-        n_endpoints0 = lines_junc_idx0.max() + 1
-        n_endpoints1 = lines_junc_idx1.max() + 1
-
         # pre-compute masks for LG-LMP
-        mask0 = create_mask(lines_junc_idx0, n_endpoints0).unsqueeze(1).bool().to(lines_junc_idx0.device)
-        mask1 = create_mask(lines_junc_idx1, n_endpoints1).unsqueeze(1).bool().to(lines_junc_idx1.device)
+        mask0 = create_mask(lines_junc_idx0).unsqueeze(1).bool().to(lines_junc_idx0.device)
+        mask1 = create_mask(lines_junc_idx1).unsqueeze(1).bool().to(lines_junc_idx1.device)
+
+        max_indices0 = lines_junc_idx0.max(1).values
+        max_indices1 = lines_junc_idx1.max(1).values
+
+        mask_ffn0 = torch.arange(mask0.shape[-1], device=mask0.device).unsqueeze(0) <= max_indices0.unsqueeze(1)
+        mask_ffn1 = torch.arange(mask1.shape[-1], device=mask1.device).unsqueeze(0) <= max_indices1.unsqueeze(1)
 
         for i in range(self.conf.n_layers):
-            torch.cuda.synchronize()  # Synchronize before starting the timer
+            if self.conf.checkpointed and self.training:
+                desc0, desc1 = checkpoint(
+                    self.transformers[i], desc0, desc1, encoding0, encoding1, \
+                    mask_ffn0, mask_ffn1, mask0, mask1, use_reentrant=True
+                )
+            else:
+                desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1, \
+                    mask_ffn0, mask_ffn1, mask0, mask1)
 
-            desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1, \
-                mask0, mask1)
+            if self.training or i == self.conf.n_layers - 1:
+                all_desc0.append(desc0)
+                all_desc1.append(desc1)
+                continue  # no early stopping or adaptive width at last layer
 
             # only for eval
             if do_early_stop:
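The new checkpointed branch above uses `torch.utils.checkpoint.checkpoint` to trade compute for memory during training: layer activations are freed after the forward pass and recomputed during backward. A minimal stand-alone sketch with a toy module (not the LightGlueStick transformer), showing that gradients match the plain forward/backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# Checkpointed forward: intermediate activations inside `layer` are
# recomputed in backward. use_reentrant=True matches the flag in the diff.
y = checkpoint(layer, x, use_reentrant=True)
y.sum().backward()

# Reference: plain (non-checkpointed) forward/backward on the same input.
x_ref = x.detach().clone().requires_grad_(True)
layer(x_ref).sum().backward()
```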
@@ -659,17 +675,13 @@ def _forward(self, data: dict) -> dict:
             if do_point_pruning:
                 assert b == 1
                 scores0 = self.log_assignment[i].get_matchability(desc0)
-
-                scores0[0, : n_endpoints0] = 1.0
                 prunemask0 = self.get_pruning_mask(token0, scores0, i)
                 keep0 = torch.where(prunemask0)[1]
                 ind0 = ind0.index_select(1, keep0)
                 desc0 = desc0.index_select(1, keep0)
                 encoding0 = encoding0.index_select(-2, keep0)
                 prune0[:, ind0] += 1
                 scores1 = self.log_assignment[i].get_matchability(desc1)
-
-                scores1[0, : n_endpoints1] = 1.0
                 prunemask1 = self.get_pruning_mask(token1, scores1, i)
                 keep1 = torch.where(prunemask1)[1]
                 ind1 = ind1.index_select(1, keep1)
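This hunk also removes the lines that forced junction endpoints to survive pruning (`scores0[0, : n_endpoints0] = 1.0`), so junctions are now pruned like any other keypoint. The pruning mechanics themselves can be sketched in plain Python; a fixed threshold stands in for `get_pruning_mask`, and lists stand in for `index_select`:

```python
# Hypothetical matchability scores and per-keypoint state for one image.
scores = [0.9, 0.1, 0.8, 0.2]
desc = ["d0", "d1", "d2", "d3"]
ind = [0, 1, 2, 3]  # original keypoint indices

# Keep only confident keypoints (stand-in for get_pruning_mask + torch.where).
keep = [i for i, s in enumerate(scores) if s > 0.5]

# Gather surviving indices and descriptors with the same index list,
# mirroring the chained index_select calls above.
ind = [ind[i] for i in keep]
desc = [desc[i] for i in keep]
```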
@@ -703,12 +715,12 @@ def _forward(self, data: dict) -> dict:
             "log_assignment": scores,
             "prune0": prune0,
             "prune1": prune1,
-            "early_exit_layer_idx": i + 1
+            "ref_descriptors0": torch.stack(all_desc0, 1),
+            "ref_descriptors1": torch.stack(all_desc1, 1)
         }
 
         if n_lines0 > 0 and n_lines1 > 0:
             m0_lines, m1_lines, mscores0_lines, mscores1_lines = filter_matches(line_scores, self.conf.filter_threshold)
-
             pred["line_log_assignment"] = line_scores
             pred["line_matches0"] = m0_lines
             pred["line_matches1"] = m1_lines
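The `ref_descriptors*` entries added to `pred` collect one descriptor tensor per layer during training (see the `all_desc0.append(desc0)` bookkeeping earlier in this commit) and stack them along a new layer dimension, which is the usual setup for supervising intermediate layers. A small sketch of the resulting shape, with made-up toy sizes:

```python
import torch

# Toy sizes: batch, keypoints, descriptor dim, number of layers.
b, n, dim, n_layers = 1, 5, 4, 3

# During training, every layer's output descriptors are appended.
all_desc0 = [torch.randn(b, n, dim) for _ in range(n_layers)]

# torch.stack(all_desc0, 1) groups them as (batch, n_layers, n, dim).
ref0 = torch.stack(all_desc0, 1)
```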
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+from lightgluestick import LightGlueStick as LightGlueStick_
+from omegaconf import OmegaConf
+
+from ..base_model import BaseModel
+
+
+class LightGlueStick(BaseModel):
+    default_conf = {"features": "superpoint", **LightGlueStick_.default_conf}
+
+    def _init(self, conf):
+        dconf = OmegaConf.to_container(conf)
+        self.net = LightGlueStick_(dconf)
+        self.set_initialized()
+
+    def _forward(self, data):
+        return self.net(data)
+
+    def loss(self, pred, data):
+        raise NotImplementedError

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ urls = {Repository = "https://github.com/cvg/glue-factory"}
 
 [project.optional-dependencies]
 extra = [
+    "lightgluestick @ git+https://github.com/aubingazhib/LightGlueStick.git",
     "pycolmap",
     "poselib",
     "pytlsd @ git+https://github.com/iago-suarez/pytlsd.git@4180ab8990ae68cc9c8797c63aa1dc47b2c714da",
