Skip to content

Commit 2d99381

Browse files
authored
Fix formatting and better match visualizations (#164)
1 parent f905a74 commit 2d99381

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

lightglue/aliked.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,11 @@ def __init__(self, **conf):
682682
radius=conf.nms_radius,
683683
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
684684
scores_th=conf.detection_threshold,
685-
n_limit=conf.max_num_keypoints
686-
if conf.max_num_keypoints > 0
687-
else self.n_limit_max,
685+
n_limit=(
686+
conf.max_num_keypoints
687+
if conf.max_num_keypoints > 0
688+
else self.n_limit_max
689+
),
688690
)
689691

690692
state_dict = torch.hub.load_state_dict_from_url(

lightglue/viz2d.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ def cm_prune(x_):
3939
return cm_BlRdGn(norm_x)
4040

4141

42+
def cm_grad2d(xy):
43+
"""2D grad. colormap: yellow (0, 0) -> green (1, 0) -> red (0, 1) -> blue (1, 1)."""
44+
tl = np.array([1.0, 0, 0]) # red
45+
tr = np.array([0, 0.0, 1]) # blue
46+
ll = np.array([1.0, 1.0, 0]) # yellow
47+
lr = np.array([0, 1.0, 0]) # green
48+
49+
xy = np.clip(xy, 0, 1)
50+
x = xy[..., :1]
51+
y = xy[..., -1:]
52+
rgb = (1 - x) * (1 - y) * ll + x * (1 - y) * lr + x * y * tr + (1 - x) * y * tl
53+
return rgb.clip(0, 1)
54+
55+
4256
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
4357
"""Plot a set of images horizontally.
4458
Args:
@@ -49,9 +63,11 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
4963
"""
5064
# conversion to (H, W, 3) for torch.Tensor
5165
imgs = [
52-
img.permute(1, 2, 0).cpu().numpy()
53-
if (isinstance(img, torch.Tensor) and img.dim() == 3)
54-
else img
66+
(
67+
img.permute(1, 2, 0).cpu().numpy()
68+
if (isinstance(img, torch.Tensor) and img.dim() == 3)
69+
else img
70+
)
5571
for img in imgs
5672
]
5773

@@ -122,7 +138,10 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axe
122138
kpts1 = kpts1.cpu().numpy()
123139
assert len(kpts0) == len(kpts1)
124140
if color is None:
125-
color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
141+
kpts_norm = (kpts0 - kpts0.min(axis=0, keepdims=True)) / np.ptp(
142+
kpts0, axis=0, keepdims=True
143+
)
144+
color = cm_grad2d(kpts_norm) # gradient color
126145
elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
127146
color = [color] * len(kpts0)
128147

0 commit comments

Comments
 (0)