Skip to content

Commit c230ff6

Browse files
committed
refactor
1 parent 4f9d1dd commit c230ff6

File tree

1 file changed

+99
-86
lines changed
  • python-threatexchange/threatexchange/signal_type/pdq

1 file changed

+99
-86
lines changed

python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py

+99-86
Original file line numberDiff line numberDiff line change
@@ -24,116 +24,129 @@
2424

2525
PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]
2626

27+
def pick_n_centroids(n: int) -> int:
    """Pick the number of IVF centroids for a dataset of *n* vectors.

    Uses floor(sqrt(n)), a common Faiss rule of thumb, with a floor of 1 so
    tiny (or empty) datasets still get a valid centroid count.

    Args:
        n: number of vectors the index will hold; negative values are
           treated as 0.

    Returns:
        A centroid count >= 1.
    """
    # math.isqrt is an exact integer sqrt — avoids float rounding errors
    # that int(n ** 0.5) can produce for very large n.
    import math  # local import: top-of-file import block is outside this diff view

    return max(math.isqrt(max(n, 0)), 1)
30+
31+
class _PDQHashIndex:
    """
    A wrapper around the faiss index for pickle serialization.

    Faiss index objects are not picklable directly; __getstate__/__setstate__
    round-trip through faiss's own (de)serialization.
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def search(
        self,
        queries: t.Sequence[str],
        threshold: int,
    ) -> t.List[t.List[t.Tuple[int, float]]]:
        """
        Search the index for all matches within `threshold` of each query.

        Args:
            queries: PDQ hash strings.
            threshold: maximum (inclusive) distance — hence the +1 on the
                range_search radius, which is exclusive.

        Returns:
            One list per query, each of (faiss id, distance) pairs.
        """
        query_vectors = convert_pdq_strings_to_ndarray(queries)
        # range_search returns flat result arrays; `limits` holds the
        # [start, end) slice boundaries for each query's results.
        limits, result_distances, result_ids = self.faiss_index.range_search(
            query_vectors, threshold + 1
        )

        results = []
        for i in range(len(queries)):
            start, end = limits[i], limits[i + 1]
            # .item() converts numpy integer ids to plain Python ints.
            matches = [faiss_id.item() for faiss_id in result_ids[start:end]]
            results.append(list(zip(matches, result_distances[start:end])))
        return results

    def add(self, vectors: np.ndarray) -> None:
        """Add vectors to the index."""
        self.faiss_index.add(vectors)

    def train(self, vectors: np.ndarray) -> None:
        """Train the index if needed (flat indices need no training)."""
        if isinstance(self.faiss_index, faiss.IndexIVFFlat):
            self.faiss_index.train(vectors)

    def __getstate__(self):
        # Serialize through faiss, since the index itself is not picklable.
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
71+
2772

2873
class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

    Keeps two side structures alongside the Faiss index: a dedupe map from
    hash string to Faiss row id (Faiss cannot usefully hold duplicate
    vectors), and a per-row list of all caller entries sharing that hash.
    """

    def __init__(
        self,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:
        """
        Args:
            threshold: maximum (inclusive) PDQ distance for query() matches.
            faiss_index: optional pre-configured Faiss index; defaults to a
                brute-force IndexFlatL2 over the full PDQ bit width.
        """
        super().__init__()
        if faiss_index is None:
            faiss_index = faiss.IndexFlatL2(BITS_IN_PDQ)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        # hash string -> Faiss row id (row ids are assigned 0, 1, 2, ...)
        self._deduper: t.Dict[str, int] = {}
        # Faiss row id -> every entry registered under that hash
        self._idx_to_entries: t.List[t.List[IndexT]] = []

    def query(self, query: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """Look up entries within self.threshold of the queried hash."""
        results = self.faiss_index.search([query], self.threshold)
        # Single query in => results[0] is its (row id, distance) list.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(distf)),
                entry,
            )
            for idx, distf in results[0]
            for entry in self._idx_to_entries[idx]
        ]

    def _dedupe_and_add(
        self, signal_str: str, entry: IndexT, add_to_faiss: bool = True
    ) -> None:
        """Record entry for signal_str, adding its vector to Faiss at most once.

        add_to_faiss=False lets build() register all entries first and
        bulk-add vectors afterwards (an IVF index must be trained before
        vectors can be added).
        """
        existing_id = self._deduper.get(signal_str)
        if existing_id is None:
            next_id = len(self._deduper)  # Faiss assigns row ids sequentially from 0
            self._deduper[signal_str] = next_id
            self._idx_to_entries.append([entry])
            if add_to_faiss:
                self.faiss_index.add(convert_pdq_strings_to_ndarray([signal_str]))
        else:
            # Duplicate hash: Faiss cannot handle duplication, so only the
            # entry bookkeeping grows.
            self._idx_to_entries[existing_id].append(entry)

    def add(self, signal_str: str, entry: IndexT) -> None:
        """Add a single entry to the index."""
        self._dedupe_and_add(signal_str, entry)

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        """Add multiple (hash, entry) pairs to the index."""
        for signal_str, entry in entries:
            self._dedupe_and_add(signal_str, entry)

    @classmethod
    def build(cls: t.Type[Self], entries: t.Iterable[t.Tuple[str, IndexT]]) -> Self:
        """
        Build an index, choosing Faiss options based on the dataset size.

        Small datasets use the default brute-force flat index; larger ones
        use an IVF index with sqrt(n) centroids.
        """
        entry_list = list(entries)
        xn = len(entry_list)
        if xn < 1024:  # If small enough, just use brute force
            # Fix: __init__ takes no `entries` kwarg — construct, then add.
            # (Previously `cls(entries=entry_list)` raised TypeError here.)
            small = cls()
            small.add_all(entry_list)
            return small

        centroids = pick_n_centroids(xn)
        # Named local keeps the quantizer referenced alongside the index.
        quantizer = faiss.IndexFlatL2(BITS_IN_PDQ)
        index = faiss.IndexIVFFlat(quantizer, BITS_IN_PDQ, centroids)
        index.nprobe = 16  # 16-64 should be high enough accuracy for 1-10M
        index.cp.min_points_per_centroid = 1  # Allow small clusters

        ret = cls(faiss_index=index)
        # Register entries without touching Faiss yet: the IVF index must be
        # trained before vectors can be added.
        for signal_str, entry in entry_list:
            ret._dedupe_and_add(signal_str, entry, add_to_faiss=False)

        # Dict preserves insertion order, so row i of xb is Faiss id i.
        xb = convert_pdq_strings_to_ndarray(list(ret._deduper))
        ret.faiss_index.train(xb)
        ret.faiss_index.add(xb)
        return ret

    def __len__(self) -> int:
        return len(self._idx_to_entries)
137150

138151

139152
class PDQSignalTypeIndex2(SignalTypeIndex[IndexT]):

0 commit comments

Comments
 (0)