Skip to content

Commit 124df23

Browse files
committed
fixed bug when using upright with wrong options
1 parent 9620909 commit 124df23

File tree

1 file changed

+90
-85
lines changed

1 file changed

+90
-85
lines changed

src/deep_image_matching/image_matching.py

Lines changed: 90 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def __init__(
280280
self.extraction = config.extractor["name"]
281281
self.matching = config.matcher["name"]
282282
self.pair_file = config.general["pair_file"]
283+
self.rotated_images = []
283284

284285
# self.existing_colmap_model = config.general["db_path"]
285286
# if config.general["retrieval"] == "covisibility":
@@ -349,13 +350,17 @@ def run(self):
349350
None
350351
"""
351352
# Generate pairs to be matched
352-
pair_path = self.generate_pairs()
353+
self.generate_pairs()
353354
timer.update("generate_pairs")
354355

355356
# Try to rotate images so they will be all "upright", useful for deep-learning approaches that usually are not rotation invariant
356357
if self.config.general["upright"] in ["custom", "2clusters", "exif"]:
357358
self.rotate_upright_images(self.config.general["upright"])
358359
timer.update("rotate_upright_images")
360+
elif self.config.general["upright"] is True:
361+
logger.warning(
362+
"The 'upright' option should be one of ['custom', '2clusters', 'exif']. Proceeding without rotating images."
363+
)
359364

360365
# Extract features
361366
feature_path = self.extract_features()
@@ -365,7 +370,7 @@ def run(self):
365370
match_path = self.match_pairs(feature_path)
366371

367372
# If features have been extracted on "upright" images, this function bring features back to their original image orientation
368-
if self.config.general["upright"]:
373+
if self.config.general["upright"] in ["custom", "2clusters", "exif"]:
369374
self.rotate_back_features(feature_path)
370375
timer.update("rotate_back_features")
371376

@@ -374,12 +379,12 @@ def run(self):
374379

375380
return feature_path, match_path
376381

377-
def generate_pairs(self, **kwargs) -> Path:
382+
def generate_pairs(self, **kwargs) -> None:
378383
"""
379-
Generates pairs of images for matching.
384+
Generates pairs of images for matching and stores them in the 'self.pairs' attribute.
380385
381386
Returns:
382-
Path: The path to the pair file containing the generated pairs of images.
387+
None
383388
"""
384389
if self.pair_file is not None and self.strategy == "custom_pairs":
385390
if not self.pair_file.exists():
@@ -403,7 +408,86 @@ def generate_pairs(self, **kwargs) -> Path:
403408
)
404409
self.pairs = pairs_generator.run()
405410

406-
return self.pair_file
411+
return None
412+
413+
def extract_features(self) -> Path:
414+
"""
415+
Extracts features from the images using the specified local feature extraction method.
416+
417+
Returns:
418+
Path: The path to the directory containing the extracted features.
419+
420+
Raises:
421+
ValueError: If the local feature extraction method is invalid or not supported.
422+
423+
"""
424+
logger.info(f"Extracting features with {self.extraction}...")
425+
logger.info(f"{self.extraction} configuration: ")
426+
pprint(self.config.extractor)
427+
428+
# Extract features
429+
for img in tqdm(self.image_list):
430+
feature_path = self._extractor.extract(img)
431+
432+
torch.cuda.empty_cache()
433+
logger.info("Features extracted!")
434+
435+
return feature_path
436+
437+
def match_pairs(self, feature_path: Path, try_full_image: bool = False) -> Path:
438+
"""
439+
Matches features using a specified matching method.
440+
441+
Args:
442+
feature_path (Path): The path to the directory containing the extracted features.
443+
try_full_image (bool, optional): Whether to try matching the full image. Defaults to False.
444+
445+
Returns:
446+
Path: The path to the directory containing the matches.
447+
448+
Raises:
449+
ValueError: If the feature path does not exist.
450+
"""
451+
452+
logger.info(f"Matching features with {self.matching}...")
453+
logger.info(f"{self.matching} configuration: ")
454+
pprint(self.config.matcher)
455+
456+
# Check that feature_path exists
457+
feature_path = Path(feature_path)
458+
if not feature_path.exists():
459+
raise ValueError(f"Feature path {feature_path} does not exist")
460+
461+
# Define matches path
462+
matches_path = feature_path.parent / "matches.h5"
463+
464+
# Match pairs
465+
logger.info("Matching features...")
466+
logger.info("")
467+
for i, pair in enumerate(tqdm(self.pairs)):
468+
name0 = pair[0].name if isinstance(pair[0], Path) else pair[0]
469+
name1 = pair[1].name if isinstance(pair[1], Path) else pair[1]
470+
im0 = self.image_dir / name0
471+
im1 = self.image_dir / name1
472+
473+
logger.debug(f"Matching image pair: {name0} - {name1}")
474+
475+
# Run matching
476+
self._matcher.match(
477+
feature_path=feature_path,
478+
matches_path=matches_path,
479+
img0=im0,
480+
img1=im1,
481+
try_full_image=try_full_image,
482+
)
483+
timer.update("Match pair")
484+
485+
# NOTE: Geometric verif. has been moved to the end of the matching process
486+
487+
torch.cuda.empty_cache()
488+
timer.print("matching")
489+
490+
return matches_path
407491

408492
def rotate_upright_images(
409493
self, strategy, resize_size=500, n_cores=4, multi_processing=False
@@ -612,85 +696,6 @@ def rotate_upright_images(
612696
logger.info(f"Images rotated and saved in {path_to_upright_dir}")
613697
gc.collect()
614698

615-
def extract_features(self) -> Path:
616-
"""
617-
Extracts features from the images using the specified local feature extraction method.
618-
619-
Returns:
620-
Path: The path to the directory containing the extracted features.
621-
622-
Raises:
623-
ValueError: If the local feature extraction method is invalid or not supported.
624-
625-
"""
626-
logger.info(f"Extracting features with {self.extraction}...")
627-
logger.info(f"{self.extraction} configuration: ")
628-
pprint(self.config.extractor)
629-
630-
# Extract features
631-
for img in tqdm(self.image_list):
632-
feature_path = self._extractor.extract(img)
633-
634-
torch.cuda.empty_cache()
635-
logger.info("Features extracted!")
636-
637-
return feature_path
638-
639-
def match_pairs(self, feature_path: Path, try_full_image: bool = False) -> Path:
640-
"""
641-
Matches features using a specified matching method.
642-
643-
Args:
644-
feature_path (Path): The path to the directory containing the extracted features.
645-
try_full_image (bool, optional): Whether to try matching the full image. Defaults to False.
646-
647-
Returns:
648-
Path: The path to the directory containing the matches.
649-
650-
Raises:
651-
ValueError: If the feature path does not exist.
652-
"""
653-
654-
logger.info(f"Matching features with {self.matching}...")
655-
logger.info(f"{self.matching} configuration: ")
656-
pprint(self.config.matcher)
657-
658-
# Check that feature_path exists
659-
feature_path = Path(feature_path)
660-
if not feature_path.exists():
661-
raise ValueError(f"Feature path {feature_path} does not exist")
662-
663-
# Define matches path
664-
matches_path = feature_path.parent / "matches.h5"
665-
666-
# Match pairs
667-
logger.info("Matching features...")
668-
logger.info("")
669-
for i, pair in enumerate(tqdm(self.pairs)):
670-
name0 = pair[0].name if isinstance(pair[0], Path) else pair[0]
671-
name1 = pair[1].name if isinstance(pair[1], Path) else pair[1]
672-
im0 = self.image_dir / name0
673-
im1 = self.image_dir / name1
674-
675-
logger.debug(f"Matching image pair: {name0} - {name1}")
676-
677-
# Run matching
678-
self._matcher.match(
679-
feature_path=feature_path,
680-
matches_path=matches_path,
681-
img0=im0,
682-
img1=im1,
683-
try_full_image=try_full_image,
684-
)
685-
timer.update("Match pair")
686-
687-
# NOTE: Geometric verif. has been moved to the end of the matching process
688-
689-
torch.cuda.empty_cache()
690-
timer.print("matching")
691-
692-
return matches_path
693-
694699
def rotate_back_features(self, feature_path: Path) -> None:
695700
"""
696701
Rotates back the features.

0 commit comments

Comments
 (0)