|
10 | 10 | import torch
|
11 | 11 | import argparse
|
12 | 12 |
|
| 13 | +from copy import deepcopy |
13 | 14 | from xfeat.xfeat import XFeat
|
14 | 15 | from lightglue import LightGlue, SuperPoint
|
| 16 | +from efficientloftr.loftr import LoFTR, full_default_cfg, opt_default_cfg, reparameter |
15 | 17 |
|
16 |
| -from utils import warp_corners_and_draw_matches, sp_lg |
| 18 | +from utils import warp_corners_and_draw_matches, sp_lg, eloftr |
17 | 19 |
|
18 | 20 |
|
19 | 21 | if __name__ == '__main__':
|
20 | 22 | parser = argparse.ArgumentParser()
|
21 | 23 | parser.add_argument('--ref', type=str, help='Path to the reference image', default='assets/groot/groot.jpg')
|
22 | 24 | parser.add_argument('--tgt', type=str, help='Path to the target video', default='assets/groot/groot.mp4')
|
23 |
| - parser.add_argument('--method', type=str, help='Method to use for image matching (xfeat+mnn, sp+lg)', default='xfeat+mnn') |
| 25 | + parser.add_argument('--method', type=str, help='Method to use for image matching (xfeat+mnn, sp+lg, loftr)', default='xfeat+mnn') |
24 | 26 | parser.add_argument('--save_path', type=str, help='Path to save the output video', default='output.mp4')
|
25 | 27 |
|
26 | 28 | args = parser.parse_args()
|
|
36 | 38 | elif method == 'sp+lg':
|
37 | 39 | extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
|
38 | 40 | matcher = LightGlue(features="superpoint").eval().to(device)
|
| 41 | + print(f'Load superpoint and lightglue to {device}') |
| 42 | + elif method == 'loftr': |
| 43 | + _default_cfg = deepcopy(full_default_cfg) |
| 44 | + loftr = LoFTR(config=_default_cfg) |
| 45 | + loftr.load_state_dict(torch.load("weights/eloftr_outdoor.ckpt")['state_dict']) |
| 46 | + loftr = reparameter(loftr) |
| 47 | + loftr = loftr.eval().to(device) |
39 | 48 | else:
|
40 | 49 | raise ValueError(f'Unknown method: {method}')
|
41 | 50 |
|
|
72 | 81 | mkpts_0, mkpts_1, time_det, time_mat = xfeat.match_xfeat(ref, frame, top_k = 4096)
|
73 | 82 | elif method == 'sp+lg':
|
74 | 83 | mkpts_0, mkpts_1, time_det, time_mat = sp_lg(extractor, matcher, ref, frame)
|
75 |
| - |
76 |
| - |
| 84 | + elif method == 'loftr': |
| 85 | + mkpts_0, mkpts_1, time_det, time_mat = eloftr(loftr, ref, frame) |
77 | 86 | time_total = time_det + time_mat
|
78 | 87 |
|
79 | 88 | canvas = warp_corners_and_draw_matches(mkpts_1, mkpts_0, frame, ref, time_total)
|
|
0 commit comments