Skip to content

Commit aa3ce37

Browse files
alibeklfc authored and meta-codesync[bot] committed
Fix race condition, memory management, debug output, and hashtable lookup in sorting.cpp (facebookresearch#5078)
Summary:
Pull Request resolved: facebookresearch#5078

Four fixes in `faiss/utils/sorting.cpp`:

**1. OpenMP directive fix in `fvec_argsort_parallel`**

The initialization loop used `#pragma omp parallel` without the `for` directive. This caused every thread to execute the entire loop independently rather than distributing iterations. With `nt` threads, each `permA[i]` was written by all `nt` threads concurrently — a data race under the C++ memory model (multiple unsynchronized writes to the same non-atomic location), and O(n * nt) wasted work instead of O(n). Fixed by changing to `#pragma omp parallel for`.

In practice, all threads write the same value (`permA[i] = i`), so the output was always correct despite the UB. The fix eliminates the undefined behavior and the redundant work.

**2. RAII memory management in `fvec_argsort_parallel`**

Replaced `new size_t[n]` / `delete[] perm2` with `std::vector<size_t>`. The old code had no realistic exception path between allocation and deallocation (all intermediate code is either C functions or non-throwing OpenMP regions), but the manual `new`/`delete` pattern is fragile against future edits that might introduce a throwing path. The `std::vector` provides RAII lifetime management with no behavioral change.

**3. Removed debug `printf` in `fvec_argsort_parallel`**

A leftover `printf("merge %d %d, %d threads\n", ...)` in the parallel merge loop wrote to stdout during normal operation. Removed.

**4. Missing early termination in `hashtable_int64_to_int64_lookup`**

The linear probing loop did not check for empty slots (`tab[slot * 2] == -1`). In an open-addressing hash table with no deletion support, an empty slot is definitive proof that the key was not inserted — the insert function would have placed it there or earlier. Without this check, lookups for absent keys probed every slot in the bucket before the wrap-around termination at `slot == hk_i`.

The fix adds the standard empty-slot check, matching the structure of the insert function (`hashtable_int64_to_int64_add`). This is a performance optimization — the old code always returned the correct result (`-1` after a full bucket scan), just slower.

Reviewed By: mnorris11

Differential Revision: D100317917

fbshipit-source-id: aadfe33b1d76c34e04db7fe0c9b7ca53b4a30c71
1 parent b068fd9 commit aa3ce37

File tree

2 files changed: +87 -6 lines changed

faiss/utils/sorting.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -134,9 +134,9 @@ void fvec_argsort(size_t n, const float* vals, size_t* perm) {
134134
}
135135

136136
void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
137-
size_t* perm2 = new size_t[n];
137+
std::vector<size_t> perm2(n);
138138
// 2 result tables, during merging, flip between them
139-
size_t *permB = perm2, *permA = perm;
139+
size_t *permB = perm2.data(), *permA = perm;
140140

141141
int nt = omp_get_max_threads();
142142
{ // prepare correct permutation so that the result ends in perm
@@ -148,8 +148,8 @@ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
148148
}
149149
}
150150

151-
#pragma omp parallel
152-
for (size_t i = 0; i < n; i++) {
151+
#pragma omp parallel for
152+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
153153
permA[i] = i;
154154
}
155155

@@ -184,7 +184,6 @@ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
184184
} else {
185185
int t0 = s * sub_nt / sub_nseg1;
186186
int t1 = (s + 1) * sub_nt / sub_nseg1;
187-
printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0);
188187
parallel_merge(
189188
permA, permB, segs[s], segs[s + 1], t1 - t0, comp);
190189
}
@@ -197,7 +196,6 @@ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
197196
}
198197
assert(permA == perm);
199198
omp_set_nested(prev_nested);
200-
delete[] perm2;
201199
}
202200

203201
/*****************************************************************************
@@ -816,6 +814,10 @@ void hashtable_int64_to_int64_lookup(
816814
size_t k0 = bucket << (log2_capacity - log2_nbucket);
817815
size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
818816
for (;;) {
817+
if (tab[slot * 2] == -1) { // empty slot, key not in table
818+
vals[i] = -1;
819+
break;
820+
}
819821
if (tab[slot * 2] == k) { // found!
820822
vals[i] = tab[2 * slot + 1];
821823
break;

tests/test_sorting.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <random>
11+
#include <vector>
12+
13+
#include <faiss/utils/sorting.h>
14+
15+
TEST(TestSorting, argsort_parallel_matches_serial) {
16+
// n > 1M to exercise the parallel merge path
17+
size_t n = 2000000;
18+
19+
std::vector<float> vals(n);
20+
std::mt19937 rng(42);
21+
std::uniform_real_distribution<float> dist(-1000.0f, 1000.0f);
22+
for (size_t i = 0; i < n; i++) {
23+
vals[i] = dist(rng);
24+
}
25+
26+
std::vector<size_t> perm_serial(n);
27+
faiss::fvec_argsort(n, vals.data(), perm_serial.data());
28+
29+
std::vector<size_t> perm_parallel(n);
30+
faiss::fvec_argsort_parallel(n, vals.data(), perm_parallel.data());
31+
32+
// Permutations may differ on ties, but sorted values must match
33+
for (size_t i = 0; i < n; i++) {
34+
ASSERT_FLOAT_EQ(vals[perm_serial[i]], vals[perm_parallel[i]])
35+
<< "mismatch at position " << i;
36+
}
37+
}
38+
39+
TEST(TestSorting, hashtable_lookup) {
40+
int log2_capacity = 12;
41+
size_t capacity = (size_t)1 << log2_capacity;
42+
43+
std::vector<int64_t> tab(capacity * 2);
44+
faiss::hashtable_int64_to_int64_init(log2_capacity, tab.data());
45+
46+
size_t n = 200;
47+
std::vector<int64_t> keys(n), vals(n);
48+
for (size_t i = 0; i < n; i++) {
49+
keys[i] = static_cast<int64_t>(i * 3);
50+
vals[i] = static_cast<int64_t>(i + 1);
51+
}
52+
faiss::hashtable_int64_to_int64_add(
53+
log2_capacity, tab.data(), n, keys.data(), vals.data());
54+
55+
// Interleave present and absent keys
56+
size_t n_query = n * 2;
57+
std::vector<int64_t> query_keys(n_query);
58+
std::vector<int64_t> expected(n_query);
59+
for (size_t i = 0; i < n; i++) {
60+
query_keys[2 * i] = keys[i];
61+
expected[2 * i] = vals[i];
62+
query_keys[2 * i + 1] =
63+
keys[i] + 1; // not a multiple of 3, never inserted
64+
expected[2 * i + 1] = -1;
65+
}
66+
67+
std::vector<int64_t> result(n_query);
68+
faiss::hashtable_int64_to_int64_lookup(
69+
log2_capacity,
70+
tab.data(),
71+
n_query,
72+
query_keys.data(),
73+
result.data());
74+
75+
for (size_t i = 0; i < n_query; i++) {
76+
ASSERT_EQ(result[i], expected[i])
77+
<< "query key " << query_keys[i] << " at index " << i;
78+
}
79+
}

0 commit comments

Comments
 (0)