|
24 | 24 |
|
25 | 25 | PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]
|
26 | 26 |
|
def pick_n_centroids(n: int) -> int:
    """Choose how many IVF centroids to use for a dataset of ``n`` hashes.

    Scales as the (truncated) square root of the dataset size, with a floor
    of one centroid so tiny or empty datasets still yield a valid count.
    """
    root = int(n ** 0.5)
    if root < 1:
        root = 1
    return root
| 30 | + |
class _PDQHashIndex:
    """
    A wrapper around the faiss index for pickle serialization.

    Faiss indexes are not directly picklable, so __getstate__/__setstate__
    round-trip the index through faiss's own (de)serialization helpers.
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def search(
        self,
        queries: t.Sequence[str],
        threshold: int,
    ) -> t.List[t.List[t.Tuple[int, float]]]:
        """
        Search the index for all hashes within `threshold` of each query.

        Returns one list per query (in query order), each containing
        (faiss_id, distance) pairs for every match within the threshold,
        inclusive.
        """
        qs = convert_pdq_strings_to_ndarray(queries)
        # range_search's radius bound is exclusive; +1 makes `threshold`
        # itself an inclusive match distance.
        limits, D, I = self.faiss_index.range_search(qs, threshold + 1)

        results = []
        for i in range(len(queries)):
            # limits[i]:limits[i+1] delimits query i's slice of the flat
            # result arrays; .item() converts numpy ints to Python ints.
            matches = [result.item() for result in I[limits[i] : limits[i + 1]]]
            distances = [dist for dist in D[limits[i] : limits[i + 1]]]
            results.append(list(zip(matches, distances)))
        return results

    def add(self, vectors: np.ndarray) -> None:
        """Add vectors to the index."""
        self.faiss_index.add(vectors)

    def train(self, vectors: np.ndarray) -> None:
        """Train the index if it needs training (no-op otherwise)."""
        # Generalized from an isinstance(IndexIVFFlat) check: any faiss
        # index that requires training reports is_trained == False, while
        # flat indexes are always trained, so this preserves old behavior.
        if not self.faiss_index.is_trained:
            self.faiss_index.train(vectors)

    def __getstate__(self):
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
| 71 | + |
27 | 72 |
|
class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

    Hashes are deduplicated: each unique hash string gets exactly one faiss
    vector, and all entries sharing that hash are stored together in
    `_idx_to_entries` under the vector's faiss id.
    """

    def __init__(
        self,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:
        """
        Args:
            threshold: maximum distance (inclusive) for query() matches.
            faiss_index: optional pre-configured faiss index; defaults to a
                brute-force IndexFlatL2 over the PDQ bit vector.
        """
        super().__init__()
        if faiss_index is None:
            faiss_index = faiss.IndexFlatL2(BITS_IN_PDQ)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        # hash string -> faiss id; ids are assigned densely from 0 in
        # insertion order, matching positions in _idx_to_entries.
        self._deduper: t.Dict[str, int] = {}
        # faiss id -> every entry stored under that hash
        self._idx_to_entries: t.List[t.List[IndexT]] = []

    def query(self, query: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """Return matches for all entries within self.threshold of `query`."""
        results = self.faiss_index.search([query], self.threshold)
        # results[0] is the (id, distance) list for our single query; fan
        # each matched hash out to all entries deduped under it.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(distf)),
                entry,
            )
            for idx, distf in results[0]
            for entry in self._idx_to_entries[idx]
        ]

    def _dedupe_and_add(
        self, signal_str: str, entry: IndexT, *, add_to_faiss: bool = True
    ) -> None:
        """Record `entry` under `signal_str`, adding the hash to faiss at
        most once (faiss cannot handle duplicate vectors sharing an id).

        Args:
            add_to_faiss: set False during build() so hashes can be batch
                trained/added to the faiss index afterwards.
        """
        existing_id = self._deduper.get(signal_str)
        if existing_id is None:
            next_id = len(self._deduper)  # faiss ids count up from 0
            self._deduper[signal_str] = next_id
            self._idx_to_entries.append([entry])
            if add_to_faiss:
                self.faiss_index.add(convert_pdq_strings_to_ndarray([signal_str]))
        else:
            self._idx_to_entries[existing_id].append(entry)

    def add(self, signal_str: str, entry: IndexT) -> None:
        """Add a single entry to the index."""
        self._dedupe_and_add(signal_str, entry)

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        """Add multiple entries to the index."""
        for signal_str, entry in entries:
            self._dedupe_and_add(signal_str, entry)

    @classmethod
    def build(cls: t.Type[Self], entries: t.Iterable[t.Tuple[str, IndexT]]) -> Self:
        """
        Faiss has many potential options that we can choose based on the size of the index.
        """
        entry_list = list(entries)
        xn = len(entry_list)
        if xn < 1024:  # If small enough, just use brute force
            # Bugfix: __init__ takes no `entries` kwarg, so the previous
            # `cls(entries=entry_list)` raised TypeError; build then add.
            small = cls()
            small.add_all(entry_list)
            return small

        centroids = pick_n_centroids(xn)
        index = faiss.IndexIVFFlat(
            faiss.IndexFlatL2(BITS_IN_PDQ), BITS_IN_PDQ, centroids
        )
        index.nprobe = 16  # 16-64 should be high enough accuracy for 1-10M
        index.cp.min_points_per_centroid = 1  # Allow small clusters

        ret = cls(faiss_index=index)
        # Defer faiss insertion so we can train on the full deduped batch.
        for signal_str, entry in entry_list:
            ret._dedupe_and_add(signal_str, entry, add_to_faiss=False)

        # dict preserves insertion order, so position i in _deduper
        # corresponds to faiss id i.
        xb = convert_pdq_strings_to_ndarray(list(ret._deduper))
        ret.faiss_index.train(xb)
        ret.faiss_index.add(xb)
        return ret

    def __len__(self) -> int:
        # Number of unique hashes indexed (not total entries).
        return len(self._idx_to_entries)
137 | 150 |
|
138 | 151 |
|
139 | 152 | class PDQSignalTypeIndex2(SignalTypeIndex[IndexT]):
|
|
0 commit comments