13
13
)
14
14
from threatexchange .signal_type .pdq .pdq_index import PDQIndex
15
15
16
+
16
17
def _get_hash_generator (seed : int = 42 ):
17
18
random .seed (seed )
18
19
@@ -21,6 +22,7 @@ def get_n_hashes(n: int):
21
22
22
23
return get_n_hashes
23
24
25
+
24
26
def _brute_force_match (
25
27
base : t .List [str ], query : str , threshold : int = PDQ_CONFIDENT_MATCH_THRESHOLD
26
28
) -> t .Set [t .Tuple [int , int ]]:
@@ -32,28 +34,32 @@ def _brute_force_match(
32
34
matches .add ((i , distance ))
33
35
return matches
34
36
37
+
35
38
def test_flat_index_small_dataset ():
36
39
"""Test that flat index is used for small datasets."""
37
40
get_random_hashes = _get_hash_generator ()
38
- base_hashes = get_random_hashes (100 ) # Below IVF threshold here.
41
+ base_hashes = get_random_hashes (100 ) # Below IVF threshold here.
39
42
index = PDQSignalTypeIndex2 .build ([(h , base_hashes .index (h )) for h in base_hashes ])
40
43
41
44
assert isinstance (index ._index ._index .faiss_index , faiss .IndexFlatL2 )
42
45
46
+
43
47
def test_ivf_index_large_dataset ():
44
48
"""Test that IVF index is used for large datasets."""
45
49
get_random_hashes = _get_hash_generator ()
46
- base_hashes = get_random_hashes (2000 )
50
+ base_hashes = get_random_hashes (2000 )
47
51
index = PDQSignalTypeIndex2 .build ([(h , base_hashes .index (h )) for h in base_hashes ])
48
52
49
53
assert isinstance (index ._index ._index .faiss_index , faiss .IndexIVFFlat )
50
54
55
+
51
56
def test_empty_index_query ():
52
57
"""Test querying an empty index."""
53
58
index = PDQSignalTypeIndex2 ()
54
59
results = index .query (PdqSignal .get_random_signal ())
55
60
assert len (results ) == 0
56
61
62
+
57
63
def test_single_hash_query ():
58
64
"""Test querying with a single hash."""
59
65
hash_str = PdqSignal .get_random_signal ()
@@ -65,6 +71,7 @@ def test_single_hash_query():
65
71
assert results [0 ].metadata == "test_entry"
66
72
assert results [0 ].similarity_info .distance == 0
67
73
74
+
68
75
def test_add_all_and_query ():
69
76
"""Test adding multiple hashes and querying."""
70
77
get_random_hashes = _get_hash_generator ()
@@ -73,8 +80,9 @@ def test_add_all_and_query():
73
80
index .add_all ([(h , base_hashes .index (h )) for h in base_hashes ])
74
81
75
82
results = index .query (base_hashes [0 ])
76
- assert len (results ) >= 1
77
- assert any (r .metadata == 0 for r in results )
83
+ assert len (results ) >= 1
84
+ assert any (r .metadata == 0 for r in results )
85
+
78
86
79
87
def test_build_and_query ():
80
88
"""Test building index and querying."""
@@ -83,8 +91,9 @@ def test_build_and_query():
83
91
index = PDQSignalTypeIndex2 .build ([(h , base_hashes .index (h )) for h in base_hashes ])
84
92
85
93
results = index .query (base_hashes [0 ])
86
- assert len (results ) >= 1
87
- assert any (r .metadata == 0 for r in results )
94
+ assert len (results ) >= 1
95
+ assert any (r .metadata == 0 for r in results )
96
+
88
97
89
98
def test_len ():
90
99
"""Test length reporting."""
@@ -96,6 +105,7 @@ def test_len():
96
105
index .add_all ([(h , base_hashes .index (h )) for h in base_hashes ])
97
106
assert len (index ) == 100
98
107
108
+
99
109
def test_serialize_deserialize_empty_index ():
100
110
"""Test serialization/deserialization of empty index."""
101
111
index = PDQSignalTypeIndex2 ()
@@ -108,10 +118,11 @@ def test_serialize_deserialize_empty_index():
108
118
assert len (deserialized ) == 0
109
119
assert deserialized ._index is None
110
120
121
+
111
122
def test_serialize_deserialize_flat_index ():
112
123
"""Test serialization/deserialization of flat index."""
113
124
get_random_hashes = _get_hash_generator ()
114
- base_hashes = get_random_hashes (100 )
125
+ base_hashes = get_random_hashes (100 )
115
126
index = PDQSignalTypeIndex2 .build ([(h , base_hashes .index (h )) for h in base_hashes ])
116
127
buffer = io .BytesIO ()
117
128
@@ -126,10 +137,11 @@ def test_serialize_deserialize_flat_index():
126
137
assert len (results ) >= 1
127
138
assert any (r .metadata == 0 for r in results )
128
139
140
+
129
141
def test_serialize_deserialize_ivf_index ():
130
142
"""Test serialization/deserialization of IVF index."""
131
143
get_random_hashes = _get_hash_generator ()
132
- base_hashes = get_random_hashes (2000 )
144
+ base_hashes = get_random_hashes (2000 )
133
145
index = PDQSignalTypeIndex2 .build ([(h , base_hashes .index (h )) for h in base_hashes ])
134
146
buffer = io .BytesIO ()
135
147
@@ -144,6 +156,7 @@ def test_serialize_deserialize_ivf_index():
144
156
assert len (results ) >= 1
145
157
assert any (r .metadata == 0 for r in results )
146
158
159
+
147
160
def test_compatibility_with_old_index ():
148
161
"""Test that we can read indices serialized with the old PDQIndex class."""
149
162
@@ -168,4 +181,4 @@ def test_compatibility_with_old_index():
168
181
assert len (old_results ) >= 1
169
182
assert len (new_results ) >= 1
170
183
assert any (r .metadata == 0 for r in old_results )
171
- assert any (r .metadata == 0 for r in new_results )
184
+ assert any (r .metadata == 0 for r in new_results )
0 commit comments