Skip to content

Commit 4f9d1dd

Browse files
committed
lint: black
1 parent d3d23d2 commit 4f9d1dd

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -208,4 +208,4 @@ def deserialize(cls, f: t.BinaryIO) -> "PDQSignalTypeIndex2":
208208

209209
ret = cls()
210210
ret._index = pickle.load(f)
211-
return ret
211+
return ret

python-threatexchange/threatexchange/signal_type/tests/test_pdq_signal_type_index2.py

+22 -9
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
)
1414
from threatexchange.signal_type.pdq.pdq_index import PDQIndex
1515

16+
1617
def _get_hash_generator(seed: int = 42):
1718
random.seed(seed)
1819

@@ -21,6 +22,7 @@ def get_n_hashes(n: int):
2122

2223
return get_n_hashes
2324

25+
2426
def _brute_force_match(
2527
base: t.List[str], query: str, threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD
2628
) -> t.Set[t.Tuple[int, int]]:
@@ -32,28 +34,32 @@ def _brute_force_match(
3234
matches.add((i, distance))
3335
return matches
3436

37+
3538
def test_flat_index_small_dataset():
3639
"""Test that flat index is used for small datasets."""
3740
get_random_hashes = _get_hash_generator()
38-
base_hashes = get_random_hashes(100) # Below IVF threshold here.
41+
base_hashes = get_random_hashes(100) # Below IVF threshold here.
3942
index = PDQSignalTypeIndex2.build([(h, base_hashes.index(h)) for h in base_hashes])
4043

4144
assert isinstance(index._index._index.faiss_index, faiss.IndexFlatL2)
4245

46+
4347
def test_ivf_index_large_dataset():
4448
"""Test that IVF index is used for large datasets."""
4549
get_random_hashes = _get_hash_generator()
46-
base_hashes = get_random_hashes(2000)
50+
base_hashes = get_random_hashes(2000)
4751
index = PDQSignalTypeIndex2.build([(h, base_hashes.index(h)) for h in base_hashes])
4852

4953
assert isinstance(index._index._index.faiss_index, faiss.IndexIVFFlat)
5054

55+
5156
def test_empty_index_query():
5257
"""Test querying an empty index."""
5358
index = PDQSignalTypeIndex2()
5459
results = index.query(PdqSignal.get_random_signal())
5560
assert len(results) == 0
5661

62+
5763
def test_single_hash_query():
5864
"""Test querying with a single hash."""
5965
hash_str = PdqSignal.get_random_signal()
@@ -65,6 +71,7 @@ def test_single_hash_query():
6571
assert results[0].metadata == "test_entry"
6672
assert results[0].similarity_info.distance == 0
6773

74+
6875
def test_add_all_and_query():
6976
"""Test adding multiple hashes and querying."""
7077
get_random_hashes = _get_hash_generator()
@@ -73,8 +80,9 @@ def test_add_all_and_query():
7380
index.add_all([(h, base_hashes.index(h)) for h in base_hashes])
7481

7582
results = index.query(base_hashes[0])
76-
assert len(results) >= 1
77-
assert any(r.metadata == 0 for r in results)
83+
assert len(results) >= 1
84+
assert any(r.metadata == 0 for r in results)
85+
7886

7987
def test_build_and_query():
8088
"""Test building index and querying."""
@@ -83,8 +91,9 @@ def test_build_and_query():
8391
index = PDQSignalTypeIndex2.build([(h, base_hashes.index(h)) for h in base_hashes])
8492

8593
results = index.query(base_hashes[0])
86-
assert len(results) >= 1
87-
assert any(r.metadata == 0 for r in results)
94+
assert len(results) >= 1
95+
assert any(r.metadata == 0 for r in results)
96+
8897

8998
def test_len():
9099
"""Test length reporting."""
@@ -96,6 +105,7 @@ def test_len():
96105
index.add_all([(h, base_hashes.index(h)) for h in base_hashes])
97106
assert len(index) == 100
98107

108+
99109
def test_serialize_deserialize_empty_index():
100110
"""Test serialization/deserialization of empty index."""
101111
index = PDQSignalTypeIndex2()
@@ -108,10 +118,11 @@ def test_serialize_deserialize_empty_index():
108118
assert len(deserialized) == 0
109119
assert deserialized._index is None
110120

121+
111122
def test_serialize_deserialize_flat_index():
112123
"""Test serialization/deserialization of flat index."""
113124
get_random_hashes = _get_hash_generator()
114-
base_hashes = get_random_hashes(100)
125+
base_hashes = get_random_hashes(100)
115126
index = PDQSignalTypeIndex2.build([(h, base_hashes.index(h)) for h in base_hashes])
116127
buffer = io.BytesIO()
117128

@@ -126,10 +137,11 @@ def test_serialize_deserialize_flat_index():
126137
assert len(results) >= 1
127138
assert any(r.metadata == 0 for r in results)
128139

140+
129141
def test_serialize_deserialize_ivf_index():
130142
"""Test serialization/deserialization of IVF index."""
131143
get_random_hashes = _get_hash_generator()
132-
base_hashes = get_random_hashes(2000)
144+
base_hashes = get_random_hashes(2000)
133145
index = PDQSignalTypeIndex2.build([(h, base_hashes.index(h)) for h in base_hashes])
134146
buffer = io.BytesIO()
135147

@@ -144,6 +156,7 @@ def test_serialize_deserialize_ivf_index():
144156
assert len(results) >= 1
145157
assert any(r.metadata == 0 for r in results)
146158

159+
147160
def test_compatibility_with_old_index():
148161
"""Test that we can read indices serialized with the old PDQIndex class."""
149162

@@ -168,4 +181,4 @@ def test_compatibility_with_old_index():
168181
assert len(old_results) >= 1
169182
assert len(new_results) >= 1
170183
assert any(r.metadata == 0 for r in old_results)
171-
assert any(r.metadata == 0 for r in new_results)
184+
assert any(r.metadata == 0 for r in new_results)

0 commit comments

Comments
 (0)