Skip to content

Commit 553c78b

Browse files
mdouzemeta-codesync[bot]
authored andcommitted
HNSW SIMD dispatch (#5015)
Summary: Pull Request resolved: #5015 Extract `HNSW::MinimaxHeap` into a standalone `faiss::MinimaxHeap` struct in a new `impl/hnsw/` subdirectory, and convert `pop_min` to use FAISS SIMD dynamic dispatch. ### Changes **New files:** - `impl/hnsw/MinimaxHeap.h` — Standalone `faiss::MinimaxHeap` struct (no longer nested in `HNSW`). Trivial methods (`max`, `size`, `clear`) inlined. Declares `template<SIMDLevel> pop_min_tpl` and runtime-dispatched `pop_min`. - `impl/hnsw/MinimaxHeap.cpp` — `push`, `count_below`, scalar `pop_min_tpl<NONE>`, and `pop_min` dispatch wrapper using `with_selected_simd_levels`. - `impl/hnsw/avx2.cpp` — `pop_min_tpl<AVX2>` specialization, guarded by `COMPILE_SIMD_AVX2`. - `impl/hnsw/avx512.cpp` — `pop_min_tpl<AVX512>` specialization, guarded by `COMPILE_SIMD_AVX512`. - `impl/hnsw/.nobuck` — Prevents autodeps from generating a BUCK file (sources are part of top-level faiss target). **Modified files:** - `impl/HNSW.h` — Includes new header, removed nested struct, free functions use `MinimaxHeap&`. - `impl/HNSW.cpp` — Removed all `MinimaxHeap` implementations. - `IndexHNSW.cpp` — Removed stale `using MinimaxHeap = HNSW::MinimaxHeap` alias. - `tests/test_hnsw.cpp` — Updated `faiss::HNSW::MinimaxHeap` → `faiss::MinimaxHeap`. - `CMakeLists.txt` — Added new source/header files and SIMD file registrations. - `xplat.bzl` — Added new source/header files and SIMD_FILES entries. Reviewed By: algoriddle Differential Revision: D98920170 fbshipit-source-id: 7c4d04e7eb6647b359efb2fa2f86c9f2a084d189
1 parent 499b488 commit 553c78b

11 files changed

Lines changed: 383 additions & 300 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# =============================================================================
1111
set(FAISS_SIMD_AVX2_SRC
1212
impl/fast_scan/impl-avx2.cpp
13+
impl/hnsw/avx2.cpp
1314
impl/pq_code_distance/pq_code_distance-avx2.cpp
1415
impl/scalar_quantizer/sq-avx2.cpp
1516
impl/approx_topk/avx2.cpp
@@ -19,6 +20,7 @@ set(FAISS_SIMD_AVX2_SRC
1920
)
2021
set(FAISS_SIMD_AVX512_SRC
2122
impl/fast_scan/impl-avx512.cpp
23+
impl/hnsw/avx512.cpp
2224
impl/pq_code_distance/pq_code_distance-avx512.cpp
2325
impl/scalar_quantizer/sq-avx512.cpp
2426
utils/simd_impl/distances_avx512.cpp
@@ -108,6 +110,7 @@ set(FAISS_SRC
108110
impl/IDSelector.cpp
109111
impl/FaissException.cpp
110112
impl/HNSW.cpp
113+
impl/hnsw/MinimaxHeap.cpp
111114
impl/NSG.cpp
112115
impl/PolysemousTraining.cpp
113116
impl/ProductQuantizer.cpp
@@ -232,6 +235,7 @@ set(FAISS_HEADERS
232235
impl/FaissAssert.h
233236
impl/FaissException.h
234237
impl/HNSW.h
238+
impl/hnsw/MinimaxHeap.h
235239
impl/LocalSearchQuantizer.h
236240
impl/ProductAdditiveQuantizer.h
237241
impl/fast_scan/LookupTableScaler.h

faiss/IndexHNSW.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
namespace faiss {
3535

36-
using MinimaxHeap = HNSW::MinimaxHeap;
3736
using storage_idx_t = HNSW::storage_idx_t;
3837
using NodeDistFarther = HNSW::NodeDistFarther;
3938

faiss/impl/HNSW.cpp

Lines changed: 0 additions & 262 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@
1717
#include <faiss/impl/ResultHandler.h>
1818
#include <faiss/impl/VisitedTable.h>
1919

20-
#ifdef __AVX2__
21-
#include <immintrin.h>
22-
23-
#include <limits>
24-
#include <type_traits>
25-
#endif
26-
2720
namespace faiss {
2821

2922
/**************************************************************
@@ -597,7 +590,6 @@ void HNSW::add_with_locks(
597590
* Searching
598591
**************************************************************/
599592

600-
using MinimaxHeap = HNSW::MinimaxHeap;
601593
using Node = HNSW::Node;
602594
using C = HNSW::C;
603595

@@ -1190,7 +1182,6 @@ HNSWStats greedy_update_nearest(
11901182
}
11911183

11921184
namespace {
1193-
using MinimaxHeap = HNSW::MinimaxHeap;
11941185
using Node = HNSW::Node;
11951186
using C = HNSW::C;
11961187

@@ -1388,257 +1379,4 @@ void HNSW::permute_entries(const idx_t* map) {
13881379
neighbors = std::move(new_neighbors);
13891380
}
13901381

1391-
/**************************************************************
1392-
* MinimaxHeap
1393-
**************************************************************/
1394-
1395-
void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
1396-
if (k == n) {
1397-
if (v >= dis[0]) {
1398-
return;
1399-
}
1400-
if (ids[0] != -1) {
1401-
--nvalid;
1402-
}
1403-
faiss::heap_pop<HC>(k--, dis.data(), ids.data());
1404-
}
1405-
faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
1406-
++nvalid;
1407-
}
1408-
1409-
float HNSW::MinimaxHeap::max() const {
1410-
return dis[0];
1411-
}
1412-
1413-
int HNSW::MinimaxHeap::size() const {
1414-
return nvalid;
1415-
}
1416-
1417-
void HNSW::MinimaxHeap::clear() {
1418-
nvalid = k = 0;
1419-
}
1420-
1421-
#ifdef __AVX512F__
1422-
1423-
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1424-
assert(k > 0);
1425-
static_assert(
1426-
std::is_same<storage_idx_t, int32_t>::value,
1427-
"This code expects storage_idx_t to be int32_t");
1428-
1429-
int32_t min_idx = -1;
1430-
float min_dis = std::numeric_limits<float>::infinity();
1431-
1432-
__m512i min_indices = _mm512_set1_epi32(-1);
1433-
__m512 min_distances =
1434-
_mm512_set1_ps(std::numeric_limits<float>::infinity());
1435-
__m512i current_indices = _mm512_setr_epi32(
1436-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1437-
__m512i offset = _mm512_set1_epi32(16);
1438-
1439-
// The following loop tracks the rightmost index with the min distance.
1440-
// -1 index values are ignored.
1441-
const size_t k16 = (k / 16) * 16;
1442-
for (size_t iii = 0; iii < k16; iii += 16) {
1443-
__m512i indices =
1444-
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1445-
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
1446-
1447-
// This mask filters out -1 values among indices.
1448-
__mmask16 m1mask =
1449-
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1450-
1451-
__mmask16 dmask =
1452-
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1453-
__mmask16 finalmask = m1mask | dmask;
1454-
1455-
const __m512i min_indices_new = _mm512_mask_blend_epi32(
1456-
finalmask, current_indices, min_indices);
1457-
const __m512 min_distances_new =
1458-
_mm512_mask_blend_ps(finalmask, distances, min_distances);
1459-
1460-
min_indices = min_indices_new;
1461-
min_distances = min_distances_new;
1462-
1463-
current_indices = _mm512_add_epi32(current_indices, offset);
1464-
}
1465-
1466-
// leftovers
1467-
if (k16 != static_cast<size_t>(k)) {
1468-
const __mmask16 kmask = (1 << (k - k16)) - 1;
1469-
1470-
__m512i indices = _mm512_mask_loadu_epi32(
1471-
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
1472-
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1473-
1474-
// This mask filters out -1 values among indices.
1475-
__mmask16 m1mask =
1476-
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1477-
1478-
__mmask16 dmask =
1479-
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1480-
__mmask16 finalmask = m1mask | dmask;
1481-
1482-
const __m512i min_indices_new = _mm512_mask_blend_epi32(
1483-
finalmask, current_indices, min_indices);
1484-
const __m512 min_distances_new =
1485-
_mm512_mask_blend_ps(finalmask, distances, min_distances);
1486-
1487-
min_indices = min_indices_new;
1488-
min_distances = min_distances_new;
1489-
}
1490-
1491-
// grab min distance
1492-
min_dis = _mm512_reduce_min_ps(min_distances);
1493-
// blend
1494-
__mmask16 mindmask =
1495-
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1496-
// pick the max one
1497-
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1498-
1499-
if (min_idx == -1) {
1500-
return -1;
1501-
}
1502-
1503-
if (vmin_out) {
1504-
*vmin_out = min_dis;
1505-
}
1506-
int ret = ids[min_idx];
1507-
ids[min_idx] = -1;
1508-
--nvalid;
1509-
return ret;
1510-
}
1511-
1512-
#elif __AVX2__
1513-
1514-
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1515-
assert(k > 0);
1516-
static_assert(
1517-
std::is_same<storage_idx_t, int32_t>::value,
1518-
"This code expects storage_idx_t to be int32_t");
1519-
1520-
int32_t min_idx = -1;
1521-
float min_dis = std::numeric_limits<float>::infinity();
1522-
1523-
size_t iii = 0;
1524-
1525-
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1526-
__m256 min_distances =
1527-
_mm256_set1_ps(std::numeric_limits<float>::infinity());
1528-
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1529-
__m256i offset = _mm256_set1_epi32(8);
1530-
1531-
// The baseline version is available in non-AVX2 branch.
1532-
1533-
// The following loop tracks the rightmost index with the min distance.
1534-
// -1 index values are ignored.
1535-
const size_t k8 = (k / 8) * 8;
1536-
for (; iii < k8; iii += 8) {
1537-
__m256i indices =
1538-
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1539-
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
1540-
1541-
// This mask filters out -1 values among indices.
1542-
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1543-
1544-
__m256i dmask = _mm256_castps_si256(
1545-
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1546-
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1547-
1548-
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1549-
_mm256_castsi256_ps(current_indices),
1550-
_mm256_castsi256_ps(min_indices),
1551-
finalmask));
1552-
1553-
const __m256 min_distances_new =
1554-
_mm256_blendv_ps(distances, min_distances, finalmask);
1555-
1556-
min_indices = min_indices_new;
1557-
min_distances = min_distances_new;
1558-
1559-
current_indices = _mm256_add_epi32(current_indices, offset);
1560-
}
1561-
1562-
// Vectorizing is doable, but is not practical
1563-
int32_t vidx8[8];
1564-
float vdis8[8];
1565-
_mm256_storeu_ps(vdis8, min_distances);
1566-
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
1567-
1568-
for (size_t j = 0; j < 8; j++) {
1569-
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1570-
min_idx = vidx8[j];
1571-
min_dis = vdis8[j];
1572-
}
1573-
}
1574-
1575-
// process last values. Vectorizing is doable, but is not practical
1576-
for (; iii < static_cast<size_t>(k); iii++) {
1577-
if (ids[iii] != -1 && dis[iii] <= min_dis) {
1578-
min_dis = dis[iii];
1579-
min_idx = iii;
1580-
}
1581-
}
1582-
1583-
if (min_idx == -1) {
1584-
return -1;
1585-
}
1586-
1587-
if (vmin_out) {
1588-
*vmin_out = min_dis;
1589-
}
1590-
int ret = ids[min_idx];
1591-
ids[min_idx] = -1;
1592-
--nvalid;
1593-
return ret;
1594-
}
1595-
1596-
#else
1597-
1598-
// baseline non-vectorized version
1599-
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1600-
assert(k > 0);
1601-
// returns min. This is an O(n) operation
1602-
int i = k - 1;
1603-
while (i >= 0) {
1604-
if (ids[i] != -1) {
1605-
break;
1606-
}
1607-
i--;
1608-
}
1609-
if (i == -1) {
1610-
return -1;
1611-
}
1612-
int imin = i;
1613-
float vmin = dis[i];
1614-
i--;
1615-
while (i >= 0) {
1616-
if (ids[i] != -1 && dis[i] < vmin) {
1617-
vmin = dis[i];
1618-
imin = i;
1619-
}
1620-
i--;
1621-
}
1622-
if (vmin_out) {
1623-
*vmin_out = vmin;
1624-
}
1625-
int ret = ids[imin];
1626-
ids[imin] = -1;
1627-
--nvalid;
1628-
1629-
return ret;
1630-
}
1631-
#endif
1632-
1633-
int HNSW::MinimaxHeap::count_below(float thresh) {
1634-
int n_below = 0;
1635-
for (int i = 0; i < k; i++) {
1636-
if (dis[i] < thresh) {
1637-
n_below++;
1638-
}
1639-
}
1640-
1641-
return n_below;
1642-
}
1643-
16441382
} // namespace faiss

faiss/impl/HNSW.h

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <faiss/Index.h>
1717
#include <faiss/impl/DistanceComputer.h>
1818
#include <faiss/impl/FaissAssert.h>
19+
#include <faiss/impl/hnsw/MinimaxHeap.h>
1920
#include <faiss/impl/maybe_owned_vector.h>
2021
#include <faiss/impl/platform_macros.h>
2122
#include <faiss/utils/Heap.h>
@@ -64,33 +65,6 @@ struct HNSW {
6465

6566
typedef std::pair<float, storage_idx_t> Node;
6667

67-
/** Heap structure that allows fast access and updates.
68-
*/
69-
struct MinimaxHeap {
70-
int n;
71-
int k;
72-
int nvalid;
73-
74-
std::vector<storage_idx_t> ids;
75-
std::vector<float> dis;
76-
typedef faiss::CMax<float, storage_idx_t> HC;
77-
78-
explicit MinimaxHeap(int n_in)
79-
: n(n_in), k(0), nvalid(0), ids(n_in), dis(n_in) {}
80-
81-
void push(storage_idx_t i, float v);
82-
83-
float max() const;
84-
85-
int size() const;
86-
87-
void clear();
88-
89-
int pop_min(float* vmin_out = nullptr);
90-
91-
int count_below(float thresh);
92-
};
93-
9468
/// to sort pairs of (id, distance) from nearest to farthest or the reverse
9569
struct NodeDistCloser {
9670
float d;
@@ -280,7 +254,7 @@ int search_from_candidates(
280254
const HNSW& hnsw,
281255
DistanceComputer& qdis,
282256
ResultHandler& res,
283-
HNSW::MinimaxHeap& candidates,
257+
MinimaxHeap& candidates,
284258
VisitedTable& vt,
285259
HNSWStats& stats,
286260
int level,
@@ -296,7 +270,7 @@ int search_from_candidates_panorama(
296270
const IndexHNSW* index,
297271
DistanceComputer& qdis,
298272
ResultHandler& res,
299-
HNSW::MinimaxHeap& candidates,
273+
MinimaxHeap& candidates,
300274
VisitedTable& vt,
301275
HNSWStats& stats,
302276
int level,

faiss/impl/hnsw/.nobuck

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# This directory's sources are compiled as part of the top-level faiss target.
2+
# No separate BUCK file is needed here.

0 commit comments

Comments
 (0)