Skip to content

Commit e3139b4

Browse files
arkrishn94Alex Razumov (from Dev Box)
andauthored
[flat index] Flat Search Interface (#983)
This PR introduces a trait interface and a light index to support brute-force search for providers that can be used as/are a flat-index. There is an associated RFC that walks through the interface and associated implementation in `diskann` as a new `flat` module. Rendered RFC [link](https://github.com/microsoft/DiskANN/blob/u/adkrishnan/flat-index/rfcs/00983-flat-search.md). ## Motivation The repo has no first-class surface for brute-force search. This PR adds a small trait hierarchy that gives flat search the same provider-agnostic shape that graph search has, so any backend (in-memory, quantized, disk, remote) can plug in once and reuse a shared algorithm. ## Traits (`flat/strategy.rs`) **`DistancesUnordered<C>`** — the single trait a backend must implement. Fuses iteration and scoring into one method: the implementation drives a full scan, scoring each element with a precomputed query computer `C`, and invokes a callback with `(id, distance)` pairs. Key associated types: - `ElementRef<'a>` -- the reference shape `C` scores against. - `Id` -- the id type yielded to the callback (decoupled from `HasId` so visitors can yield any id shape). - `C : for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32>` -- the precomputer query computer. ```rust pub trait DistancesUnordered<C>: Send + Sync where C: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32>, { type ElementRef<'a>; type Id; type Error: ToRanked + Debug + Send + Sync + 'static; fn distances_unordered<F>( &mut self, computer: &C, f: F, ) -> impl SendFuture<Result<(), Self::Error>> where F: Send + FnMut(Self::Id, f32); } ``` **`SearchStrategy<P, T>`** — factory that creates a `DistancesUnordered` visitor from a provider + context, and builds the per-query computer. Mirrors the graph-side strategy pattern. Two fallible methods: - `create_visitor` — borrows provider + context, returns a `Visitor` - `build_query_computer` — preprocesses the query `T` into a `QueryComputer` ```rust pub trait SearchStrategy<P, T>: Send + Sync where P: DataProvider, { type ElementRef<'a>; type Id; type QueryComputer: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32>; type QueryComputerError: StandardError; type Visitor<'a>: for<'b> DistancesUnordered< Self::QueryComputer, ElementRef<'b> = Self::ElementRef<'b>, Id = Self::Id, >; type Error: StandardError; fn create_visitor<'a>(&'a self, provider: &'a P, context: &'a P::Context) -> Result<Self::Visitor<'a>, Self::Error>; fn build_query_computer(&self, query: T) -> Result<Self::QueryComputer, Self::QueryComputerError>; } ``` ## Index (`flat/index.rs`) **`FlatIndex<P>`** — thin `'static` wrapper around a `DataProvider`. Currently we have implemented the naive kNN search algorithm for the flat index. `knn_search` asks the strategy for a visitor, builds the query computer, drives `distances_unordered` through a priority queue, and writes results via `SearchPostProcess`. ## Test infrastructure (`flat/test/`) A self-contained test provider with dimension-validated `Strategy`, transient-error injection, and a `KnnOracleRun` harness that compares `knn_search` results against a brute-force reference with baseline caching for regression detection. ## Future work - **Post-processing** — `knn_search` currently uses the graph-side `SearchPostProcess` trait to write results into the output buffer. #1067 will introduce a flat-specific post-processing step that decouples flat search from the graph module's output machinery. - **Vector ids** - The vector id over which `DistancesUnordered` acts is an associated type, instead of the `InternalId` of the provider. This is due to overly restrictive trait bounds on `VectorId` trait. We plan on relaxing this allowing for more generic id types. --------- Co-authored-by: Alex Razumov (from Dev Box) <alrazu@microsoft.com>
1 parent 4f70a82 commit e3139b4

13 files changed

Lines changed: 2439 additions & 0 deletions

File tree

diskann/src/flat/index.rs

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/*
2+
* Copyright (c) Microsoft Corporation.
3+
* Licensed under the MIT license.
4+
*/
5+
6+
//! [`FlatIndex`] — the index wrapper for a [`DataProvider`](crate::provider::DataProvider)
7+
//! over which we do flat search.
8+
use std::num::NonZeroUsize;
9+
10+
use diskann_utils::future::SendFuture;
11+
12+
use crate::{
13+
ANNResult,
14+
error::{ErrorExt, IntoANNResult},
15+
flat::{DistancesUnordered, SearchStrategy},
16+
graph::SearchOutputBuffer,
17+
neighbor::{Neighbor, NeighborPriorityQueue, NeighborPriorityQueueIdType},
18+
provider::DataProvider,
19+
};
20+
21+
/// Statistics collected during a flat search.
22+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
23+
pub struct SearchStats {
24+
/// The total number of distance computations performed during the scan.
25+
pub cmps: u32,
26+
27+
/// The total number of results written to the output buffer.
28+
pub result_count: u32,
29+
}
30+
31+
/// A thin wrapper around a [`DataProvider`] used for flat search.
32+
#[derive(Debug)]
33+
pub struct FlatIndex<P: DataProvider> {
34+
/// The backing provider.
35+
provider: P,
36+
}
37+
38+
impl<P: DataProvider> FlatIndex<P> {
39+
/// Construct a new [`FlatIndex`] around `provider`.
40+
pub fn new(provider: P) -> Self {
41+
Self { provider }
42+
}
43+
44+
/// Borrow the underlying provider.
45+
pub fn provider(&self) -> &P {
46+
&self.provider
47+
}
48+
49+
/// Brute-force k-nearest-neighbor flat search.
50+
///
51+
/// Streams every element produced by the strategy's visitor through the query
52+
/// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and
53+
/// writes the `(id, distance)` survivors into `output` in best-first order.
54+
pub fn knn_search<S, T, OB>(
55+
&self,
56+
k: NonZeroUsize,
57+
strategy: &S,
58+
context: &P::Context,
59+
query: T,
60+
output: &mut OB,
61+
) -> impl SendFuture<ANNResult<SearchStats>>
62+
where
63+
S: SearchStrategy<P, T>,
64+
S::Id: NeighborPriorityQueueIdType,
65+
T: Send + Sync,
66+
OB: SearchOutputBuffer<S::Id> + Send + ?Sized,
67+
{
68+
async move {
69+
let mut visitor = strategy
70+
.create_visitor(&self.provider, context)
71+
.into_ann_result()?;
72+
73+
let computer = strategy.build_query_computer(query).into_ann_result()?;
74+
75+
let k = k.get();
76+
let mut queue = NeighborPriorityQueue::new(k);
77+
let mut cmps: u32 = 0;
78+
79+
visitor
80+
.distances_unordered(&computer, |id, dist| {
81+
cmps += 1;
82+
queue.insert(Neighbor::new(id, dist));
83+
})
84+
.await
85+
.escalate("flat scan must complete to produce correct k-NN results")?;
86+
87+
let result_count =
88+
output.extend(queue.iter().take(k).map(|n| (n.id, n.distance))) as u32;
89+
90+
Ok(SearchStats { cmps, result_count })
91+
}
92+
}
93+
}
94+
95+
///////////
96+
// Tests //
97+
///////////
98+
99+
#[cfg(test)]
100+
mod tests {
101+
use crate::flat::{
102+
FlatIndex,
103+
test::{
104+
harness::KnnOracleRun,
105+
provider::{self as flat_provider},
106+
},
107+
};
108+
use crate::graph::test::synthetic::Grid;
109+
110+
fn fixture(grid: Grid, size: usize) -> (FlatIndex<flat_provider::Provider>, usize) {
111+
let provider = flat_provider::Provider::grid(grid, size).unwrap();
112+
let len = provider.len();
113+
(FlatIndex::new(provider), len)
114+
}
115+
116+
/// `knn_search` returns a `Send` future, and a shared `&FlatIndex` can serve
117+
/// many concurrent searches on a multi-threaded runtime, each producing the
118+
/// correct top-k independently.
119+
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
120+
async fn multithreaded_knn_search() {
121+
use std::sync::Arc;
122+
123+
let (index, len) = fixture(Grid::Two, 4);
124+
let index = Arc::new(index);
125+
126+
// Mix of corner, axis-aligned, and off-grid queries; k spans 1..=len.
127+
let cases: &[(&[f32], usize)] = &[
128+
(&[-1.0, -1.0], 1),
129+
(&[1.0, 1.0], len),
130+
(&[-1.0, 1.0], len / 2),
131+
(&[1.0, -1.0], len - 1),
132+
(&[0.0, 0.0], 3),
133+
(&[3.0, 3.0], len),
134+
(&[-2.0, 0.5], 2),
135+
(&[0.5, -0.5], len),
136+
];
137+
138+
let mut set = tokio::task::JoinSet::new();
139+
for (query, k) in cases {
140+
let index = Arc::clone(&index);
141+
let query: Vec<f32> = query.to_vec();
142+
let k = *k;
143+
set.spawn(async move {
144+
let outcome = KnnOracleRun::run(
145+
&index,
146+
&flat_provider::Strategy::new(index.provider().dim()),
147+
&query,
148+
k,
149+
)
150+
.await
151+
.expect("knn_search failed");
152+
(query, k, outcome)
153+
});
154+
}
155+
156+
while let Some(joined) = set.join_next().await {
157+
let (query, k, outcome) = joined.expect("task panicked");
158+
assert_eq!(
159+
outcome.top_k, outcome.ground_truth,
160+
"query = {query:?}, k = {k}: top-k must match brute force",
161+
);
162+
assert_eq!(outcome.stats.cmps as usize, len);
163+
assert_eq!(outcome.stats.result_count as usize, k.min(len));
164+
}
165+
}
166+
167+
////////////
168+
// Errors //
169+
////////////
170+
171+
/// A transient error from the visitor's scan must escalate up through `knn_search`.
172+
#[test]
173+
fn transient_scan_error() {
174+
// The flat scan touches every id, so any transient id is guaranteed to be hit.
175+
for transient_ids in [&[0u32][..], &[3][..], &[1, 2, 5][..]] {
176+
let strategy =
177+
flat_provider::Strategy::with_transient(2, transient_ids.iter().copied());
178+
let (index, _) = fixture(Grid::Two, 3);
179+
let err = KnnOracleRun::run_sync(&index, &strategy, &[1.0, 0.0], 4)
180+
.expect_err("transient error during full scan must escalate");
181+
182+
let msg = format!("{err}");
183+
assert!(
184+
transient_ids
185+
.iter()
186+
.any(|id| msg.contains(&format!("id {id}"))),
187+
"transients = {transient_ids:?}: expected error to name one of the \
188+
transient ids, got: {msg}",
189+
);
190+
}
191+
}
192+
193+
/// Run `knn_search` via the harness, assert it fails, and check the error
194+
/// message contains `expected_msg`.
195+
fn assert_search_error(strategy: &flat_provider::Strategy, query: &[f32], expected_msg: &str) {
196+
let (index, _) = fixture(Grid::Two, 3);
197+
let err = KnnOracleRun::run_sync(&index, strategy, query, 4)
198+
.expect_err("expected knn_search to fail");
199+
200+
let msg = format!("{err}");
201+
assert!(
202+
msg.contains(expected_msg),
203+
"expected error containing {expected_msg:?}, got: {msg}",
204+
);
205+
}
206+
207+
#[test]
208+
fn strategy_constructor_errors() {
209+
// Strategy/provider expect dim=2, query has dim=3.
210+
assert_search_error(
211+
&flat_provider::Strategy::new(2),
212+
&[0.0, 0.0, 0.0],
213+
"dimension mismatch",
214+
);
215+
216+
// Strategy expects dim=5, provider has dim=2.
217+
assert_search_error(
218+
&flat_provider::Strategy::new(5),
219+
&[0.0, 0.0],
220+
"dimension mismatch",
221+
);
222+
}
223+
}

diskann/src/flat/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Microsoft Corporation.
3+
* Licensed under the MIT license.
4+
*/
5+
6+
//! Sequential ("flat") search.
7+
//!
8+
//! This module is the streaming counterpart to the random-access
9+
//! [`crate::provider::Accessor`] family. It is designed for backends whose natural access
10+
//! pattern is a one-pass scan over their data -- for example append-only buffered stores or
11+
//! on-disk shards streamed via I/O.
12+
//!
13+
//! # Architecture
14+
//!
15+
//! The module mirrors the layering used by graph search:
16+
//!
17+
//! | Graph (random access) | Flat (sequential) | Shared? |
18+
//! | :------------------------------------ | :----------------------------------------- |:--------- |
19+
//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes |
20+
//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No |
21+
//! | [`crate::graph::glue::ExpandBeam`] | [`DistancesUnordered`] | No |
22+
//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No |
23+
//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No |
24+
//!
25+
pub mod index;
26+
pub mod strategy;
27+
28+
pub use index::{FlatIndex, SearchStats};
29+
pub use strategy::{DistancesUnordered, SearchStrategy};
30+
31+
#[cfg(test)]
32+
mod test;

0 commit comments

Comments
 (0)