Skip to content

Commit 6a9936c

Browse files
alibeklfc authored and
meta-codesync[bot] committed
Hoist per-iteration vector allocations in IndexIVFRaBitQFastScan::compute_LUT
Summary:
Move `rotated_q` and `centroid_buf` vector allocations from inside the inner loop to the `#pragma omp parallel` block scope. Previously each of the `n * nprobe` iterations allocated and freed these vectors on the heap. Now each thread allocates once and reuses capacity across all its iterations.

Split `#pragma omp parallel for` into `#pragma omp parallel` + `#pragma omp for` to create a per-thread scope for the buffers.

Reviewed By: mnorris11

Differential Revision: D102678232

fbshipit-source-id: 601022dd65943d11d53c7e5a91015553ad75870b
1 parent 6c70444 commit 6a9936c

1 file changed

Lines changed: 31 additions & 25 deletions

File tree

faiss/IndexIVFRaBitQFastScan.cpp

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -467,33 +467,39 @@ void IndexIVFRaBitQFastScan::compute_LUT(
467467
if (n * cq_nprobe > 0) {
468468
memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
469469
}
470+
// Use per-thread buffers instead of one O(n * nprobe * d) allocation.
471+
// rotated_q / centroid_buf keep their capacity across iterations so the
472+
// allocator is only hit once per thread.
473+
#pragma omp parallel if (n * cq_nprobe > 1000)
474+
{
475+
std::vector<float> rotated_q(d);
476+
std::vector<float> centroid_buf(d);
470477

471-
#pragma omp parallel for if (n * cq_nprobe > 1000)
472-
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
473-
idx_t i = ij / cq_nprobe;
474-
idx_t cij = cq.ids[ij];
475-
476-
if (cij >= 0) {
477-
std::vector<float> rotated_q(d);
478-
std::vector<float> centroid_buf(d);
479-
QueryFactorsData query_factors_data;
480-
481-
compute_residual_LUT(
482-
x + i * d,
483-
cij,
484-
query_factors_data,
485-
dis_tables.get() + ij * dim12,
486-
used_qb,
487-
used_centered,
488-
rotated_q,
489-
centroid_buf);
490-
491-
if (context.query_factors != nullptr) {
492-
context.query_factors[ij] = query_factors_data;
493-
}
478+
#pragma omp for
479+
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
480+
idx_t i = ij / cq_nprobe;
481+
idx_t cij = cq.ids[ij];
482+
483+
if (cij >= 0) {
484+
QueryFactorsData query_factors_data;
485+
486+
compute_residual_LUT(
487+
x + i * d,
488+
cij,
489+
query_factors_data,
490+
dis_tables.get() + ij * dim12,
491+
used_qb,
492+
used_centered,
493+
rotated_q,
494+
centroid_buf);
495+
496+
if (context.query_factors != nullptr) {
497+
context.query_factors[ij] = std::move(query_factors_data);
498+
}
494499

495-
} else {
496-
memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
500+
} else {
501+
memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
502+
}
497503
}
498504
}
499505
}

0 commit comments

Comments
 (0)