@@ -39,6 +39,20 @@ def cm_prune(x_):
3939 return cm_BlRdGn (norm_x )
4040
4141
42+ def cm_grad2d (xy ):
43+ """2D grad. colormap: yellow (0, 0) -> green (1, 0) -> red (0, 1) -> blue (1, 1)."""
44+ tl = np .array ([1.0 , 0 , 0 ]) # red
45+ tr = np .array ([0 , 0.0 , 1 ]) # blue
46+ ll = np .array ([1.0 , 1.0 , 0 ]) # yellow
47+ lr = np .array ([0 , 1.0 , 0 ]) # green
48+
49+ xy = np .clip (xy , 0 , 1 )
50+ x = xy [..., :1 ]
51+ y = xy [..., - 1 :]
52+ rgb = (1 - x ) * (1 - y ) * ll + x * (1 - y ) * lr + x * y * tr + (1 - x ) * y * tl
53+ return rgb .clip (0 , 1 )
54+
55+
4256def plot_images (imgs , titles = None , cmaps = "gray" , dpi = 100 , pad = 0.5 , adaptive = True ):
4357 """Plot a set of images horizontally.
4458 Args:
@@ -49,9 +63,11 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
4963 """
5064 # conversion to (H, W, 3) for torch.Tensor
5165 imgs = [
52- img .permute (1 , 2 , 0 ).cpu ().numpy ()
53- if (isinstance (img , torch .Tensor ) and img .dim () == 3 )
54- else img
66+ (
67+ img .permute (1 , 2 , 0 ).cpu ().numpy ()
68+ if (isinstance (img , torch .Tensor ) and img .dim () == 3 )
69+ else img
70+ )
5571 for img in imgs
5672 ]
5773
@@ -122,7 +138,10 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axe
122138 kpts1 = kpts1 .cpu ().numpy ()
123139 assert len (kpts0 ) == len (kpts1 )
124140 if color is None :
125- color = matplotlib .cm .hsv (np .random .rand (len (kpts0 ))).tolist ()
141+ kpts_norm = (kpts0 - kpts0 .min (axis = 0 , keepdims = True )) / np .ptp (
142+ kpts0 , axis = 0 , keepdims = True
143+ )
144+ color = cm_grad2d (kpts_norm ) # gradient color
126145 elif len (color ) > 0 and not isinstance (color [0 ], (tuple , list )):
127146 color = [color ] * len (kpts0 )
128147
0 commit comments