55
66import cv2
77import numpy as np
8+ import rasterio
89
910from deep_image_matching .utils .database import (
1011 COLMAPDatabase ,
@@ -151,7 +152,7 @@ def ShowColmapMatches(self, plot_config: dict):
151152 img1_path ,
152153 keypoints0 ,
153154 keypoints1 ,
154- self .two_views_matches [(id0 , id1 )],
155+ self .matches [( id0 , id1 )], #"self. two_views_matches[(id0, id1)]," or "self.matches[(id0, id1)],"
155156 plot_config ,
156157 )
157158
@@ -169,13 +170,90 @@ def GeneratePlot(
169170 thickness = plot_config ["thickness" ]
170171 space_between_images = plot_config ["space_between_images" ]
171172
172- # Load images
173- img0 = cv2 .imread (str (img0_path ))
174- img1 = cv2 .imread (str (img1_path ))
175-
176- # Convert keypoints to integers
177- kpts0_int = np .round (kpts0 ).astype (int )
178- kpts1_int = np .round (kpts1 ).astype (int )
173+ # Load images using rasterio
def load_image_with_rasterio(img_path):
    """Read a raster file with rasterio and return a 3-channel uint8 image.

    Args:
        img_path: Path (or path-like) of the image to open; it is passed to
            ``rasterio.open`` as a string.

    Returns:
        np.ndarray of shape (rows, cols, 3) and dtype uint8.

    Notes:
        * Rasters with more than 3 bands keep only the first 3 bands.
        * Rasters with fewer than 3 bands replicate the first band into all
          three channels so downstream drawing code always gets 3 channels.
        * Non-uint8 data is converted as follows, based on the largest
          finite value ``peak``:
            - ``peak <= 1``   -> values are multiplied by 255,
            - ``peak <= 255`` -> values are kept as they are,
            - otherwise       -> values are scaled linearly so that
                                 ``peak`` maps to 255 (avoids saturating
                                 e.g. 16-bit imagery to pure white).
          NaN and -inf become 0, +inf becomes ``peak``; the result is
          clipped to [0, 255] before the cast so nothing wraps around.
    """
    with rasterio.open(str(img_path)) as src:
        img_data = src.read()

    # Convert from (bands, rows, cols) to (rows, cols, bands).
    img = np.transpose(img_data, (1, 2, 0))

    band_count = img.shape[2]
    if band_count > 3:
        # More than 3 bands - take the first 3 (typically RGB).
        img = img[:, :, :3]
    elif band_count < 3:
        # One or two bands - show the first band as 3-channel grayscale.
        img = np.repeat(img[:, :, :1], 3, axis=2)

    if img.dtype != np.uint8:
        data = img.astype(np.float64)
        finite = np.isfinite(data)
        peak = float(data[finite].max()) if finite.any() else 0.0
        data = np.nan_to_num(data, nan=0.0, posinf=peak, neginf=0.0)

        if peak <= 1.0:
            scale = 255.0
        elif peak <= 255.0:
            scale = 1.0
        else:
            scale = 255.0 / peak

        img = np.clip(data * scale, 0, 255).astype(np.uint8)

    # NOTE(review): bands are returned in file order (typically RGB) and no
    # RGB->BGR swap is applied. If the canvas is later written or shown
    # with OpenCV, colors will appear swapped - confirm the intended order.
    return img
201+
202+ img0 = load_image_with_rasterio (img0_path )
203+ img1 = load_image_with_rasterio (img1_path )
204+
205+ # Check if images are too large and resize if necessary
206+ max_dimension = plot_config .get ("max_dimension" , 4000 ) # Maximum dimension for visualization
207+ scale_factor0 = 1.0
208+ scale_factor1 = 1.0
209+
210+ # Calculate scale factors for each image
211+ if max (img0 .shape [:2 ]) > max_dimension :
212+ scale_factor0 = max_dimension / max (img0 .shape [:2 ])
213+ new_size0 = (int (img0 .shape [1 ] * scale_factor0 ), int (img0 .shape [0 ] * scale_factor0 ))
214+ img0 = cv2 .resize (img0 , new_size0 , interpolation = cv2 .INTER_AREA )
215+ print (f"Resized img0 by factor { scale_factor0 :.3f} to { img0 .shape [1 ]} x{ img0 .shape [0 ]} " )
216+
217+ if max (img1 .shape [:2 ]) > max_dimension :
218+ scale_factor1 = max_dimension / max (img1 .shape [:2 ])
219+ new_size1 = (int (img1 .shape [1 ] * scale_factor1 ), int (img1 .shape [0 ] * scale_factor1 ))
220+ img1 = cv2 .resize (img1 , new_size1 , interpolation = cv2 .INTER_AREA )
221+ print (f"Resized img1 by factor { scale_factor1 :.3f} to { img1 .shape [1 ]} x{ img1 .shape [0 ]} " )
222+
223+ # Scale keypoints to match resized images
224+ kpts0_scaled = kpts0 * scale_factor0
225+ kpts1_scaled = kpts1 * scale_factor1
226+
227+ # Filter out invalid keypoints (NaN, inf) and convert to integers
228+ valid_mask0 = np .isfinite (kpts0_scaled ).all (axis = 1 )
229+ valid_mask1 = np .isfinite (kpts1_scaled ).all (axis = 1 )
230+
231+ kpts0_valid = kpts0_scaled [valid_mask0 ]
232+ kpts1_valid = kpts1_scaled [valid_mask1 ]
233+
234+ kpts0_int = np .round (kpts0_valid ).astype (int )
235+ kpts1_int = np .round (kpts1_valid ).astype (int )
236+
237+ print (f"Valid keypoints - Img0: { len (kpts0_int )} /{ len (kpts0 )} , Img1: { len (kpts1_int )} /{ len (kpts1 )} " )
238+
239+ # Calculate final visualization size and check memory requirements
240+ final_height = max (img0 .shape [0 ], img1 .shape [0 ])
241+ final_width = img0 .shape [1 ] + img1 .shape [1 ] + space_between_images
242+ estimated_memory_gb = (final_height * final_width * 3 ) / (1024 ** 3 )
243+
244+ print (f"Final visualization size: { final_width } x{ final_height } ({ estimated_memory_gb :.2f} GB)" )
245+
246+ if estimated_memory_gb > 8.0 : # If still too large after resizing
247+ print ("Warning: Visualization still requires significant memory. Consider using smaller max_dimension." )
248+ additional_scale = min (1.0 , 8.0 / estimated_memory_gb )
249+ if additional_scale < 1.0 :
250+ new_size0 = (int (img0 .shape [1 ] * additional_scale ), int (img0 .shape [0 ] * additional_scale ))
251+ new_size1 = (int (img1 .shape [1 ] * additional_scale ), int (img1 .shape [0 ] * additional_scale ))
252+ img0 = cv2 .resize (img0 , new_size0 , interpolation = cv2 .INTER_AREA )
253+ img1 = cv2 .resize (img1 , new_size1 , interpolation = cv2 .INTER_AREA )
254+ kpts0_int = np .round (kpts0_int * additional_scale ).astype (int )
255+ kpts1_int = np .round (kpts1_int * additional_scale ).astype (int )
256+ print (f"Applied additional scaling factor { additional_scale :.3f} " )
179257
180258 # Create a new image to draw matches
181259 img_matches = np .zeros (
@@ -193,19 +271,58 @@ def GeneratePlot(
193271 ] = (255 , 255 , 255 )
194272
195273 if show_keypoints :
196- # Show keypoints
274+ # Show valid keypoints within image bounds
197275 for kpt in kpts0_int :
198- kpt = tuple (kpt )
199- cv2 .circle (img_matches , kpt , radius , (0 , 0 , 255 ), thickness )
276+ if 0 <= kpt [0 ] < img0 .shape [1 ] and 0 <= kpt [1 ] < img0 .shape [0 ]:
277+ kpt_tuple = tuple (kpt )
278+ cv2 .circle (img_matches , kpt_tuple , radius , (0 , 0 , 255 ), thickness )
200279
201280 for kpt in kpts1_int :
202- kpt = tuple (kpt + np .array ([img0 .shape [1 ], 0 ]))
203- cv2 .circle (img_matches , kpt , radius , (0 , 0 , 255 ), thickness )
204-
205- # Draw lines and circles for matches
281+ kpt_shifted = kpt + np .array ([img0 .shape [1 ] + space_between_images , 0 ])
282+ if (0 <= kpt [0 ] < img1 .shape [1 ] and 0 <= kpt [1 ] < img1 .shape [0 ] and
283+ 0 <= kpt_shifted [0 ] < img_matches .shape [1 ] and 0 <= kpt_shifted [1 ] < img_matches .shape [0 ]):
284+ kpt_tuple = tuple (kpt_shifted )
285+ cv2 .circle (img_matches , kpt_tuple , radius , (0 , 0 , 255 ), thickness )
286+
287+ # Filter matches to only include those with valid keypoints and within image bounds
288+ valid_matches = []
289+
290+ # Create mapping from original indices to filtered indices
291+ valid_idx0_map = {}
292+ valid_idx1_map = {}
293+
294+ for i , is_valid in enumerate (valid_mask0 ):
295+ if is_valid :
296+ valid_idx0_map [i ] = len (valid_idx0_map )
297+
298+ for i , is_valid in enumerate (valid_mask1 ):
299+ if is_valid :
300+ valid_idx1_map [i ] = len (valid_idx1_map )
301+
206302 for match in matches :
207- pt1 = tuple (kpts0_int [match [0 ]])
208- pt2 = tuple (np .array (kpts1_int [match [1 ]]) + np .array ([img0 .shape [1 ], 0 ]))
303+ orig_idx0 , orig_idx1 = match [0 ], match [1 ]
304+
305+ # Check if both original indices had valid keypoints
306+ if orig_idx0 in valid_idx0_map and orig_idx1 in valid_idx1_map :
307+ new_idx0 = valid_idx0_map [orig_idx0 ]
308+ new_idx1 = valid_idx1_map [orig_idx1 ]
309+
310+ # Check if indices are within filtered arrays
311+ if new_idx0 < len (kpts0_int ) and new_idx1 < len (kpts1_int ):
312+ kpt0 = kpts0_int [new_idx0 ]
313+ kpt1 = kpts1_int [new_idx1 ]
314+
315+ # Check if keypoints are within image bounds
316+ if (0 <= kpt0 [0 ] < img0 .shape [1 ] and 0 <= kpt0 [1 ] < img0 .shape [0 ] and
317+ 0 <= kpt1 [0 ] < img1 .shape [1 ] and 0 <= kpt1 [1 ] < img1 .shape [0 ]):
318+ valid_matches .append ((new_idx0 , new_idx1 ))
319+
320+ print (f"Valid matches: { len (valid_matches )} /{ len (matches )} " )
321+
322+ # Draw lines and circles for valid matches
323+ for idx0 , idx1 in valid_matches :
324+ pt1 = tuple (kpts0_int [idx0 ])
325+ pt2 = tuple (kpts1_int [idx1 ] + np .array ([img0 .shape [1 ] + space_between_images , 0 ]))
209326
210327 # Draw a line connecting the keypoints
211328 cv2 .line (img_matches , pt1 , pt2 , (0 , 255 , 0 ), thickness )
@@ -274,20 +391,29 @@ def parse_args():
274391 required = False ,
275392 default = 1500 ,
276393 )
394+ parser .add_argument (
395+ "--max_dimension" ,
396+ type = int ,
397+ help = "Maximum dimension (width or height) for individual images before visualization" ,
398+ required = False ,
399+ default = 4000 ,
400+ )
277401 args = parser .parse_args ()
278402
279403 return args
280404
281405
282406def main ():
407+ args = parse_args ()
408+
283409 plot_config = {
284410 "show_keypoints" : True ,
285411 "radius" : 5 ,
286412 "thickness" : 2 ,
287413 "space_between_images" : 0 ,
414+ "max_dimension" : args .max_dimension ,
288415 }
289416
290- args = parse_args ()
291417 database_path = Path (args .database )
292418 out_dir = Path (args .output )
293419 imgs_dir = Path (args .imgsdir )
0 commit comments