Skip to content

Commit 5b304cb

Browse files
b8zhong and Dcallies
committed
Initial implementation of IVF faiss indices
lint: black; Update PDQ index and signal implementation; remove unnecessary test file; remove unused import; Update python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py Co-authored-by: David Callies <[email protected]>; Some more PR comments Co-authored-by: David Callies <[email protected]>
1 parent db7960f commit 5b304cb

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import faiss
99
import numpy as np
1010

11-
1211
from threatexchange.signal_type.index import (
1312
IndexMatchUntyped,
1413
SignalSimilarityInfoWithIntDistance,
@@ -33,6 +32,29 @@ class PDQIndex2(SignalTypeIndex[IndexT]):
3332
Purpose of this class: to replace the original index in pytx 2.0
3433
"""
3534

35+
# Entry count at or above which build() trains an IVF index instead of using a flat one.
IVF_THRESHOLD = 1000
36+
37+
@classmethod
def build(
    cls: t.Type["PDQIndex2"], entries: t.Iterable[t.Tuple[str, IndexT]]
) -> "PDQIndex2":
    """
    Build an index from a set of entries.

    Selects between a flat and an IVF faiss index based on the number of
    entries: fewer than ``cls.IVF_THRESHOLD`` entries use a plain
    ``IndexFlatL2``; otherwise an ``IndexIVFFlat`` is created on top of a
    flat quantizer and trained on the entries' hashes before being handed
    to the constructor.

    Args:
        entries: iterable of ``(pdq_hash_string, metadata)`` pairs. It is
            materialized into a list once, so one-shot iterables are safe.

    Returns:
        A new instance constructed with the chosen faiss index and the
        materialized entries.
    """
    entries_list = list(entries)

    # The flat index is either the final index (small datasets) or the
    # coarse quantizer of the IVF index (large datasets); build it once.
    flat_index = faiss.IndexFlatL2(BITS_IN_PDQ)
    if len(entries_list) < cls.IVF_THRESHOLD:
        return cls(index=flat_index, entries=entries_list)

    # NOTE(review): n // 2 cells leaves roughly two training vectors per
    # cell; confirm against faiss index-selection guidance that this gives
    # acceptable recall with the configured nprobe.
    nlist = len(entries_list) // 2
    ivf_index = faiss.IndexIVFFlat(flat_index, BITS_IN_PDQ, nlist)

    # IVF indices must be trained before vectors can be added to them.
    training_vectors = convert_pdq_strings_to_ndarray(
        [pdq_hash for pdq_hash, _ in entries_list]
    )
    ivf_index.train(training_vectors)

    return cls(index=ivf_index, entries=entries_list)
57+
3658
def __init__(
3759
self,
3860
index: t.Optional[faiss.Index] = None,

python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py

+28
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,31 @@ def test_one_entry_sample_index():
178178

179179
results = index.query(unmatching_test_hash)
180180
assert len(results) == 0
181+
182+
183+
def test_flat_index_selection():
    """build() with 100 entries should be backed by a flat L2 faiss index."""
    make_hashes = _get_hash_generator()
    entries = [
        (pdq_hash, position) for position, pdq_hash in enumerate(make_hashes(100))
    ]
    built = PDQIndex2.build(entries)
    assert isinstance(built._index.faiss_index, faiss.IndexFlatL2)
189+
190+
191+
def test_ivf_index_selection():
    """build() with 2000 entries should be backed by an IVF flat faiss index."""
    make_hashes = _get_hash_generator()
    entries = [
        (pdq_hash, position) for position, pdq_hash in enumerate(make_hashes(2000))
    ]
    built = PDQIndex2.build(entries)
    assert isinstance(built._index.faiss_index, faiss.IndexIVFFlat)
197+
198+
199+
def test_build_method():
    """An index from build() should return an entry when queried with that entry's own hash."""
    make_hashes = _get_hash_generator()
    hashes = make_hashes(500)
    entries = [(pdq_hash, position) for position, pdq_hash in enumerate(hashes)]
    built = PDQIndex2.build(entries)

    # The first hash was stored with metadata 0, so it must be among the matches.
    matches = built.query(hashes[0])
    assert len(matches) >= 1
    assert any(match.metadata == 0 for match in matches)

0 commit comments

Comments
 (0)