Skip to content

Commit e7d2eb8

Browse files
committed
update elfotr
1 parent bd2b975 commit e7d2eb8

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

utils.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,40 @@ def sp_lg(sp, lg, ref, tgt):
105105
ref = np.transpose(ref,(1,2,0))
106106
ref = cv2.cvtColor(np.uint8(ref*255), cv2.COLOR_RGB2BGR)
107107

108-
return mkpts_0, mkpts_1, time_det, time_mat
108+
return mkpts_0, mkpts_1, time_det, time_mat
109+
110+
def eloftr(model, ref, tgt):
111+
"""
112+
Function to perform LoFTR pipeline.
113+
114+
Args:
115+
model: LoFTR model
116+
ref: reference image
117+
tgt: target image
118+
119+
Returns:
120+
kpts0: keypoints of reference image
121+
kpts1: keypoints of target image
122+
time_det: time taken for detection
123+
time_mat: time taken for matching
124+
"""
125+
126+
ref = cv2.resize(ref, (ref.shape[1]//32*32, ref.shape[0]//32*32))
127+
tgt = cv2.resize(tgt, (tgt.shape[1]//32*32, tgt.shape[0]//32*32))
128+
129+
ref = torch.from_numpy(ref)[None][None].cuda()/255.
130+
tgt = torch.from_numpy(tgt)[None][None].cuda()/255.
131+
132+
batch = {'image0': ref, 'image1': tgt}
133+
134+
tik = time.time()
135+
with torch.no_grad():
136+
model(batch)
137+
tok = time.time()
138+
time_total = tok - tik
139+
140+
mkpts0 = batch['mkpts0_f'].cpu().numpy()
141+
mkpts1 = batch['mkpts1_f'].cpu().numpy()
142+
mconf = batch['mconf'].cpu().numpy()
143+
144+
return mkpts0, mkpts1, 0, time_total

video_matching.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@
1010
import torch
1111
import argparse
1212

13+
from copy import deepcopy
1314
from xfeat.xfeat import XFeat
1415
from lightglue import LightGlue, SuperPoint
16+
from efficientloftr.loftr import LoFTR, full_default_cfg, opt_default_cfg, reparameter
1517

16-
from utils import warp_corners_and_draw_matches, sp_lg
18+
from utils import warp_corners_and_draw_matches, sp_lg, eloftr
1719

1820

1921
if __name__ == '__main__':
2022
parser = argparse.ArgumentParser()
2123
parser.add_argument('--ref', type=str, help='Path to the reference image', default='assets/groot/groot.jpg')
2224
parser.add_argument('--tgt', type=str, help='Path to the target video', default='assets/groot/groot.mp4')
23-
parser.add_argument('--method', type=str, help='Method to use for image matching (xfeat+mnn, sp+lg)', default='xfeat+mnn')
25+
parser.add_argument('--method', type=str, help='Method to use for image matching (xfeat+mnn, sp+lg, loftr)', default='xfeat+mnn')
2426
parser.add_argument('--save_path', type=str, help='Path to save the output video', default='output.mp4')
2527

2628
args = parser.parse_args()
@@ -36,6 +38,13 @@
3638
elif method == 'sp+lg':
3739
extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
3840
matcher = LightGlue(features="superpoint").eval().to(device)
41+
print(f'Load superpoint and lightglue to {device}')
42+
elif method == 'loftr':
43+
_default_cfg = deepcopy(full_default_cfg)
44+
loftr = LoFTR(config=_default_cfg)
45+
loftr.load_state_dict(torch.load("weights/eloftr_outdoor.ckpt")['state_dict'])
46+
loftr = reparameter(loftr)
47+
loftr = loftr.eval().to(device)
3948
else:
4049
raise ValueError(f'Unknown method: {method}')
4150

@@ -72,8 +81,8 @@
7281
mkpts_0, mkpts_1, time_det, time_mat = xfeat.match_xfeat(ref, frame, top_k = 4096)
7382
elif method == 'sp+lg':
7483
mkpts_0, mkpts_1, time_det, time_mat = sp_lg(extractor, matcher, ref, frame)
75-
76-
84+
elif method == 'loftr':
85+
mkpts_0, mkpts_1, time_det, time_mat = eloftr(loftr, ref, frame)
7786
time_total = time_det + time_mat
7887

7988
canvas = warp_corners_and_draw_matches(mkpts_1, mkpts_0, frame, ref, time_total)

0 commit comments

Comments
 (0)