-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrieval_system.py
More file actions
224 lines (180 loc) · 8.77 KB
/
Copy pathretrieval_system.py
File metadata and controls
224 lines (180 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# https://youtu.be/rFCtZj_r6tA
"""
Image Retrieval System
This code implements a system for finding similar images using feature-based similarity search.
It extracts visual features from images using a neural network and enables fast similarity
search through the following main components:
1. Feature Extraction: Converts images into numerical feature vectors that capture their
visual characteristics (handled by a separate ImageFeatureExtractor class)
2. Indexing:
- Processes a directory of images and extracts their features
- Stores these features in a FAISS index (Facebook AI Similarity Search)
- Maintains metadata about each indexed image (path, filename, indexing date)
3. Search:
- Takes a query image and finds the k most similar images from the indexed collection
- Uses a FAISS IndexIVFFlat index with L2 (Euclidean) distance to compare images
- Returns matched images sorted by ascending distance (smaller distance = more similar)
Note about IndexIVFFlat:
- Uses a "divide and conquer" approach
- First divides vectors into clusters/regions
- When searching:
* First finds which clusters are most relevant
* Only searches within those chosen clusters
- Requires two extra steps:
* Training: Learning how to divide vectors into clusters
* nprobe: Choosing how many clusters to check (tradeoff between speed and accuracy)
- Usually much faster for large datasets
- Might miss some matches (approximate search) but usually good enough
"""
import os
import json
import torch
import faiss
import numpy as np
from torch.utils.data import DataLoader
from typing import List, Tuple, Optional
from datetime import datetime
import logging
from feature_extractor import ImageFeatureExtractor, ImageDataset
# Configure the root logger at INFO level when this module is imported, so the
# progress messages emitted by ImageRetrievalSystem are visible by default.
# NOTE(review): basicConfig in an importable module also affects the host
# application's logging; consider moving this call to the script entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every method of ImageRetrievalSystem.
logger = logging.getLogger(__name__)
class ImageRetrievalSystem:
    """Index a directory of images with FAISS and run similarity queries.

    Images are converted to feature vectors by an ``ImageFeatureExtractor``,
    stored in a FAISS ``IndexIVFFlat`` using L2 distance, and described by a
    JSON-serialisable metadata dict keyed by the string form of each vector's
    FAISS id.
    """

    def __init__(self,
                 feature_extractor: Optional[ImageFeatureExtractor] = None,
                 index_path: Optional[str] = None,
                 metadata_path: Optional[str] = None,
                 use_gpu: bool = False,
                 n_regions: int = 100,
                 nprobe: int = 10):
        """Initialize the retrieval system with an IVF index.

        Args:
            feature_extractor: Extractor that turns an image path into a
                feature vector. A default ``ImageFeatureExtractor()`` is
                created when omitted.
            index_path: Saved FAISS index to load. Used only when
                ``metadata_path`` is also given.
            metadata_path: Matching JSON metadata file to load.
            use_gpu: Try to place the index on GPU 0. On failure a warning is
                logged and the CPU index is kept.
            n_regions: Number of IVF clusters (``nlist``) for a new index.
            nprobe: Number of clusters scanned per query (speed/recall trade-off).
        """
        self.feature_extractor = feature_extractor or ImageFeatureExtractor()
        self.feature_dim = self.feature_extractor.feature_dim
        self.n_regions = n_regions
        self.nprobe = nprobe
        self.use_gpu = use_gpu
        logger.info(f"Initializing retrieval system with dimension: {self.feature_dim}")

        self.metadata = {}
        self.is_trained = False
        # True only while self.index is a GPU index; save() relies on this
        # instead of guessing from the number of GPUs on the machine.
        self._on_gpu = False
        # Keeps the GPU resources object alive alongside the GPU index.
        self._gpu_res = None

        if index_path and metadata_path:
            # Load existing index and metadata.
            self.load(index_path, metadata_path)
        else:
            logger.info(f"Creating new IVF index with {n_regions} regions")
            self._create_index(n_regions)

    def _create_index(self, n_regions: int) -> None:
        """Replace ``self.index`` with a fresh, untrained IVF index."""
        # Keep a reference to the coarse quantizer: the IVF index uses it.
        self.quantizer = faiss.IndexFlatL2(self.feature_dim)
        self.index = faiss.IndexIVFFlat(self.quantizer, self.feature_dim,
                                        n_regions, faiss.METRIC_L2)
        self.index.nprobe = self.nprobe
        self.n_regions = n_regions
        self.is_trained = False
        self._on_gpu = False
        if self.use_gpu:
            self._move_to_gpu()

    def _move_to_gpu(self) -> None:
        """Best-effort move of ``self.index`` to GPU 0; keeps CPU on failure."""
        try:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
            self._gpu_res = res
            self._on_gpu = True
            logger.info("Successfully moved index to GPU")
        except Exception as e:
            # Deliberately non-fatal (e.g. a CPU-only FAISS build).
            logger.warning(f"Failed to use GPU, falling back to CPU: {str(e)}")

    def index_images(self,
                     image_dir: str,
                     batch_size: int = 32,
                     num_workers: int = 4) -> None:
        """Extract features for every image in ``image_dir`` and add them.

        Images are processed one at a time; an image whose extraction fails
        (or yields a vector of the wrong size) is logged and skipped. The IVF
        index is trained on the first batch it sees. Later calls append to the
        existing index, and their metadata ids continue from the current
        vector count, so earlier entries are never overwritten.

        Args:
            image_dir: Directory scanned (non-recursively) for .png, .jpg,
                .jpeg and .webp files.
            batch_size: Currently unused; kept for API compatibility.
            num_workers: Currently unused; kept for API compatibility.

        Raises:
            ValueError: If no image produced a usable feature vector.
        """
        logger.info(f"Indexing images from {image_dir}")
        # Sorted so FAISS ids are assigned in a reproducible order.
        image_paths = [
            os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir))
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
        ]

        features_list = []
        valid_paths = []
        for img_path in image_paths:
            try:
                features = self.feature_extractor.extract_features(img_path)
                # FAISS needs float32 rows; flatten exactly like search() does.
                vector = np.asarray(features, dtype=np.float32).reshape(-1)
                if vector.size != self.feature_dim:
                    raise ValueError(
                        f"expected {self.feature_dim} features, got {vector.size}")
            except Exception as e:
                logger.error(f"Error processing {img_path}: {str(e)}")
                continue
            features_list.append(vector)
            valid_paths.append(img_path)
            logger.info(f"Processed {img_path}")

        if not features_list:
            raise ValueError("No valid features extracted from images")

        all_features = np.ascontiguousarray(np.stack(features_list))
        logger.info(f"Feature array shape: {all_features.shape}")
        logger.info(f"Feature stats - Min: {all_features.min():.4f}, Max: {all_features.max():.4f}")

        if not self.is_trained:
            n_vectors = all_features.shape[0]
            if n_vectors < self.n_regions:
                # k-means training cannot form more clusters than it has
                # points, so FAISS would throw; shrink the index instead.
                logger.warning(
                    f"Only {n_vectors} training vectors for {self.n_regions} regions; "
                    f"rebuilding IVF index with {n_vectors} regions")
                self._create_index(n_vectors)
            logger.info("Training IVF index...")
            self.index.train(all_features)
            self.is_trained = True
            logger.info("Index training completed")

        # FAISS assigns sequential ids starting at the current ntotal, so the
        # metadata keys must start there too.
        start_id = self.index.ntotal
        self.index.add(all_features)
        logger.info(f"Total vectors in index: {self.index.ntotal}")

        indexed_at = datetime.now().isoformat()
        for offset, path in enumerate(valid_paths):
            self.metadata[str(start_id + offset)] = {
                'path': path,
                'filename': os.path.basename(path),
                'indexed_at': indexed_at,
            }
        logger.info(f"Successfully indexed {len(valid_paths)} images")

    def search(self,
               query_image_path: str,
               k: int = 5) -> List[Tuple[str, float]]:
        """Return up to ``k`` indexed images most similar to the query image.

        Args:
            query_image_path: Path of the query image.
            k: Maximum number of matches; clamped to the index size.

        Returns:
            ``(path, distance)`` pairs sorted by ascending distance. FAISS L2
            indexes report squared L2 distances; smaller means more similar.
            An empty list is returned when the index holds no vectors.

        Raises:
            RuntimeError: If the index has not been trained yet.
        """
        logger.info(f"Searching for similar images to {query_image_path}")
        logger.info(f"Total images in index: {self.index.ntotal}")
        # The full key list can be very large; only emit it at debug level.
        logger.debug(f"Available metadata keys: {list(self.metadata.keys())}")
        if not self.is_trained:
            raise RuntimeError("Index has not been trained. Add images first.")

        # Make sure k doesn't exceed number of indexed images.
        k = min(k, self.index.ntotal)
        if k <= 0:
            logger.warning("No matches found!")
            return []

        query_features = self.feature_extractor.extract_features(query_image_path)
        logger.info(f"Query feature shape: {np.shape(query_features)}")
        query = np.ascontiguousarray(
            np.asarray(query_features, dtype=np.float32).reshape(1, -1))

        distances, indices = self.index.search(query, k)
        logger.info(f"Raw search results - distances: {distances[0]}")
        logger.info(f"Raw search results - indices: {indices[0]}")
        logger.info(f"Searched {min(self.nprobe, self.n_regions)} out of {self.n_regions} regions")

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # FAISS pads with -1 when the probed clusters hold < k vectors.
                continue
            entry = self.metadata.get(str(int(idx)))
            if entry is None:
                logger.warning(f"Index {idx} not found in metadata")
                continue
            results.append((entry['path'], float(dist)))
            logger.info(f"Match found: {entry['path']} with distance {dist:.3f}")

        # Sort results by distance (smaller is better).
        results.sort(key=lambda item: item[1])
        if not results:
            logger.warning("No matches found!")
        else:
            logger.info(f"Found {len(results)} matches")
        return results

    def save(self, index_path: str, metadata_path: str) -> None:
        """Write the index (as a CPU index) and metadata JSON to disk.

        The live index is left in place; a GPU index is copied to CPU only
        for serialization.
        """
        index_to_write = (faiss.index_gpu_to_cpu(self.index)
                          if self._on_gpu else self.index)
        faiss.write_index(index_to_write, index_path)
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f)
        logger.info(f"Saved index with {self.index.ntotal} vectors")
        logger.info(f"Saved index to {index_path} and metadata to {metadata_path}")

    def load(self, index_path: str, metadata_path: str) -> None:
        """Load a FAISS index and its metadata JSON from disk.

        For IVF indexes the configured ``nprobe`` is applied and ``n_regions``
        is synced to the loaded index's ``nlist``. If ``use_gpu`` was
        requested, the loaded index is moved to GPU (best-effort).
        """
        logger.info(f"Loading index from {index_path}")
        self.index = faiss.read_index(index_path)
        self._on_gpu = False
        self.is_trained = bool(self.index.is_trained)

        if isinstance(self.index, faiss.IndexIVF):
            self.index.nprobe = self.nprobe
            self.n_regions = self.index.nlist
            logger.info(f"Set nprobe to {self.nprobe} for loaded IVF index")

        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        if getattr(self, 'use_gpu', False):
            self._move_to_gpu()

        logger.info(f"Loaded index with {self.index.ntotal} vectors")
        logger.info(f"Metadata contains {len(self.metadata)} entries")