|
17 | 17 | #include <faiss/impl/ResultHandler.h> |
18 | 18 | #include <faiss/impl/VisitedTable.h> |
19 | 19 |
|
20 | | -#ifdef __AVX2__ |
21 | | -#include <immintrin.h> |
22 | | - |
23 | | -#include <limits> |
24 | | -#include <type_traits> |
25 | | -#endif |
26 | | - |
27 | 20 | namespace faiss { |
28 | 21 |
|
29 | 22 | /************************************************************** |
@@ -597,7 +590,6 @@ void HNSW::add_with_locks( |
597 | 590 | * Searching |
598 | 591 | **************************************************************/ |
599 | 592 |
|
600 | | -using MinimaxHeap = HNSW::MinimaxHeap; |
601 | 593 | using Node = HNSW::Node; |
602 | 594 | using C = HNSW::C; |
603 | 595 |
|
@@ -1190,7 +1182,6 @@ HNSWStats greedy_update_nearest( |
1190 | 1182 | } |
1191 | 1183 |
|
1192 | 1184 | namespace { |
1193 | | -using MinimaxHeap = HNSW::MinimaxHeap; |
1194 | 1185 | using Node = HNSW::Node; |
1195 | 1186 | using C = HNSW::C; |
1196 | 1187 |
|
@@ -1388,257 +1379,4 @@ void HNSW::permute_entries(const idx_t* map) { |
1388 | 1379 | neighbors = std::move(new_neighbors); |
1389 | 1380 | } |
1390 | 1381 |
|
1391 | | -/************************************************************** |
1392 | | - * MinimaxHeap |
1393 | | - **************************************************************/ |
1394 | | - |
1395 | | -void HNSW::MinimaxHeap::push(storage_idx_t i, float v) { |
1396 | | - if (k == n) { |
1397 | | - if (v >= dis[0]) { |
1398 | | - return; |
1399 | | - } |
1400 | | - if (ids[0] != -1) { |
1401 | | - --nvalid; |
1402 | | - } |
1403 | | - faiss::heap_pop<HC>(k--, dis.data(), ids.data()); |
1404 | | - } |
1405 | | - faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i); |
1406 | | - ++nvalid; |
1407 | | -} |
1408 | | - |
1409 | | -float HNSW::MinimaxHeap::max() const { |
1410 | | - return dis[0]; |
1411 | | -} |
1412 | | - |
1413 | | -int HNSW::MinimaxHeap::size() const { |
1414 | | - return nvalid; |
1415 | | -} |
1416 | | - |
1417 | | -void HNSW::MinimaxHeap::clear() { |
1418 | | - nvalid = k = 0; |
1419 | | -} |
1420 | | - |
1421 | | -#ifdef __AVX512F__ |
1422 | | - |
1423 | | -int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1424 | | - assert(k > 0); |
1425 | | - static_assert( |
1426 | | - std::is_same<storage_idx_t, int32_t>::value, |
1427 | | - "This code expects storage_idx_t to be int32_t"); |
1428 | | - |
1429 | | - int32_t min_idx = -1; |
1430 | | - float min_dis = std::numeric_limits<float>::infinity(); |
1431 | | - |
1432 | | - __m512i min_indices = _mm512_set1_epi32(-1); |
1433 | | - __m512 min_distances = |
1434 | | - _mm512_set1_ps(std::numeric_limits<float>::infinity()); |
1435 | | - __m512i current_indices = _mm512_setr_epi32( |
1436 | | - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); |
1437 | | - __m512i offset = _mm512_set1_epi32(16); |
1438 | | - |
1439 | | - // The following loop tracks the rightmost index with the min distance. |
1440 | | - // -1 index values are ignored. |
1441 | | - const size_t k16 = (k / 16) * 16; |
1442 | | - for (size_t iii = 0; iii < k16; iii += 16) { |
1443 | | - __m512i indices = |
1444 | | - _mm512_loadu_si512((const __m512i*)(ids.data() + iii)); |
1445 | | - __m512 distances = _mm512_loadu_ps(dis.data() + iii); |
1446 | | - |
1447 | | - // This mask filters out -1 values among indices. |
1448 | | - __mmask16 m1mask = |
1449 | | - _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); |
1450 | | - |
1451 | | - __mmask16 dmask = |
1452 | | - _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); |
1453 | | - __mmask16 finalmask = m1mask | dmask; |
1454 | | - |
1455 | | - const __m512i min_indices_new = _mm512_mask_blend_epi32( |
1456 | | - finalmask, current_indices, min_indices); |
1457 | | - const __m512 min_distances_new = |
1458 | | - _mm512_mask_blend_ps(finalmask, distances, min_distances); |
1459 | | - |
1460 | | - min_indices = min_indices_new; |
1461 | | - min_distances = min_distances_new; |
1462 | | - |
1463 | | - current_indices = _mm512_add_epi32(current_indices, offset); |
1464 | | - } |
1465 | | - |
1466 | | - // leftovers |
1467 | | - if (k16 != static_cast<size_t>(k)) { |
1468 | | - const __mmask16 kmask = (1 << (k - k16)) - 1; |
1469 | | - |
1470 | | - __m512i indices = _mm512_mask_loadu_epi32( |
1471 | | - _mm512_set1_epi32(-1), kmask, ids.data() + k16); |
1472 | | - __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16); |
1473 | | - |
1474 | | - // This mask filters out -1 values among indices. |
1475 | | - __mmask16 m1mask = |
1476 | | - _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); |
1477 | | - |
1478 | | - __mmask16 dmask = |
1479 | | - _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); |
1480 | | - __mmask16 finalmask = m1mask | dmask; |
1481 | | - |
1482 | | - const __m512i min_indices_new = _mm512_mask_blend_epi32( |
1483 | | - finalmask, current_indices, min_indices); |
1484 | | - const __m512 min_distances_new = |
1485 | | - _mm512_mask_blend_ps(finalmask, distances, min_distances); |
1486 | | - |
1487 | | - min_indices = min_indices_new; |
1488 | | - min_distances = min_distances_new; |
1489 | | - } |
1490 | | - |
1491 | | - // grab min distance |
1492 | | - min_dis = _mm512_reduce_min_ps(min_distances); |
1493 | | - // blend |
1494 | | - __mmask16 mindmask = |
1495 | | - _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis)); |
1496 | | - // pick the max one |
1497 | | - min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices); |
1498 | | - |
1499 | | - if (min_idx == -1) { |
1500 | | - return -1; |
1501 | | - } |
1502 | | - |
1503 | | - if (vmin_out) { |
1504 | | - *vmin_out = min_dis; |
1505 | | - } |
1506 | | - int ret = ids[min_idx]; |
1507 | | - ids[min_idx] = -1; |
1508 | | - --nvalid; |
1509 | | - return ret; |
1510 | | -} |
1511 | | - |
1512 | | -#elif __AVX2__ |
1513 | | - |
1514 | | -int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1515 | | - assert(k > 0); |
1516 | | - static_assert( |
1517 | | - std::is_same<storage_idx_t, int32_t>::value, |
1518 | | - "This code expects storage_idx_t to be int32_t"); |
1519 | | - |
1520 | | - int32_t min_idx = -1; |
1521 | | - float min_dis = std::numeric_limits<float>::infinity(); |
1522 | | - |
1523 | | - size_t iii = 0; |
1524 | | - |
1525 | | - __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1); |
1526 | | - __m256 min_distances = |
1527 | | - _mm256_set1_ps(std::numeric_limits<float>::infinity()); |
1528 | | - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
1529 | | - __m256i offset = _mm256_set1_epi32(8); |
1530 | | - |
1531 | | - // The baseline version is available in non-AVX2 branch. |
1532 | | - |
1533 | | - // The following loop tracks the rightmost index with the min distance. |
1534 | | - // -1 index values are ignored. |
1535 | | - const size_t k8 = (k / 8) * 8; |
1536 | | - for (; iii < k8; iii += 8) { |
1537 | | - __m256i indices = |
1538 | | - _mm256_loadu_si256((const __m256i*)(ids.data() + iii)); |
1539 | | - __m256 distances = _mm256_loadu_ps(dis.data() + iii); |
1540 | | - |
1541 | | - // This mask filters out -1 values among indices. |
1542 | | - __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices); |
1543 | | - |
1544 | | - __m256i dmask = _mm256_castps_si256( |
1545 | | - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS)); |
1546 | | - __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask)); |
1547 | | - |
1548 | | - const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps( |
1549 | | - _mm256_castsi256_ps(current_indices), |
1550 | | - _mm256_castsi256_ps(min_indices), |
1551 | | - finalmask)); |
1552 | | - |
1553 | | - const __m256 min_distances_new = |
1554 | | - _mm256_blendv_ps(distances, min_distances, finalmask); |
1555 | | - |
1556 | | - min_indices = min_indices_new; |
1557 | | - min_distances = min_distances_new; |
1558 | | - |
1559 | | - current_indices = _mm256_add_epi32(current_indices, offset); |
1560 | | - } |
1561 | | - |
1562 | | - // Vectorizing is doable, but is not practical |
1563 | | - int32_t vidx8[8]; |
1564 | | - float vdis8[8]; |
1565 | | - _mm256_storeu_ps(vdis8, min_distances); |
1566 | | - _mm256_storeu_si256((__m256i*)vidx8, min_indices); |
1567 | | - |
1568 | | - for (size_t j = 0; j < 8; j++) { |
1569 | | - if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) { |
1570 | | - min_idx = vidx8[j]; |
1571 | | - min_dis = vdis8[j]; |
1572 | | - } |
1573 | | - } |
1574 | | - |
1575 | | - // process last values. Vectorizing is doable, but is not practical |
1576 | | - for (; iii < static_cast<size_t>(k); iii++) { |
1577 | | - if (ids[iii] != -1 && dis[iii] <= min_dis) { |
1578 | | - min_dis = dis[iii]; |
1579 | | - min_idx = iii; |
1580 | | - } |
1581 | | - } |
1582 | | - |
1583 | | - if (min_idx == -1) { |
1584 | | - return -1; |
1585 | | - } |
1586 | | - |
1587 | | - if (vmin_out) { |
1588 | | - *vmin_out = min_dis; |
1589 | | - } |
1590 | | - int ret = ids[min_idx]; |
1591 | | - ids[min_idx] = -1; |
1592 | | - --nvalid; |
1593 | | - return ret; |
1594 | | -} |
1595 | | - |
1596 | | -#else |
1597 | | - |
1598 | | -// baseline non-vectorized version |
1599 | | -int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1600 | | - assert(k > 0); |
1601 | | - // returns min. This is an O(n) operation |
1602 | | - int i = k - 1; |
1603 | | - while (i >= 0) { |
1604 | | - if (ids[i] != -1) { |
1605 | | - break; |
1606 | | - } |
1607 | | - i--; |
1608 | | - } |
1609 | | - if (i == -1) { |
1610 | | - return -1; |
1611 | | - } |
1612 | | - int imin = i; |
1613 | | - float vmin = dis[i]; |
1614 | | - i--; |
1615 | | - while (i >= 0) { |
1616 | | - if (ids[i] != -1 && dis[i] < vmin) { |
1617 | | - vmin = dis[i]; |
1618 | | - imin = i; |
1619 | | - } |
1620 | | - i--; |
1621 | | - } |
1622 | | - if (vmin_out) { |
1623 | | - *vmin_out = vmin; |
1624 | | - } |
1625 | | - int ret = ids[imin]; |
1626 | | - ids[imin] = -1; |
1627 | | - --nvalid; |
1628 | | - |
1629 | | - return ret; |
1630 | | -} |
1631 | | -#endif |
1632 | | - |
1633 | | -int HNSW::MinimaxHeap::count_below(float thresh) { |
1634 | | - int n_below = 0; |
1635 | | - for (int i = 0; i < k; i++) { |
1636 | | - if (dis[i] < thresh) { |
1637 | | - n_below++; |
1638 | | - } |
1639 | | - } |
1640 | | - |
1641 | | - return n_below; |
1642 | | -} |
1643 | | - |
1644 | 1382 | } // namespace faiss |
0 commit comments