Skip to content

Commit b272642

Browse files
Sanhaoji2 and claude authored
Add debug api (#1053)
<!-- Thanks for contributing a pull request! Please ensure you have taken a look at the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md --> - [ ] Does this PR have a descriptive title that could go in our release notes? - [ ] Does this PR add any new dependencies? - [ ] Does this PR modify any existing APIs? - [ ] Is the change to the API backwards compatible? - [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones? #### Reference Issues/PRs <!-- Example: Fixes #1234. See also #3456. Please use keywords (e.g., Fixes) to create link to the issues or pull requests you resolved, so that they will automatically be closed when your pull request is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests --> #### What does this implement/fix? Briefly explain your changes. #### Any other comments? --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f07f6ba commit b272642

7 files changed

Lines changed: 513 additions & 63 deletions

File tree

include/abstract_index.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "index_config.h"
77
#include "index_build_params.h"
88
#include "percentile_stats.h"
9+
#include "debug_utils.h"
910
#include <any>
1011

1112
namespace diskann
@@ -90,6 +91,30 @@ class AbstractIndex
9091
float *distances,
9192
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
9293

94+
// Debug interface: retrieve the raw embedding at internal location index.
95+
// Caller must pre-allocate vec with room for at least as many elements as the index dimension.
96+
template <typename data_type>
97+
void get_embedding(uint32_t location, data_type *vec);
98+
99+
// Debug search: runs ANN search and records every traversed node in debug_info.
100+
template <typename data_type, typename IDType>
101+
std::pair<uint32_t, uint32_t> debug_search(
102+
const data_type *query, const size_t K, const uint32_t L,
103+
IDType *indices, float *distances,
104+
DebugTraversalInfo &debug_info,
105+
const uint32_t maxLperSeller = 0,
106+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
107+
108+
// Debug filtered search: same as debug_search with label filtering.
109+
template <typename data_type, typename IDType>
110+
std::pair<uint32_t, uint32_t> debug_search_with_filters(
111+
const data_type *query, const std::vector<std::string> &raw_labels,
112+
const size_t K, const uint32_t L,
113+
IDType *indices, float *distances,
114+
DebugTraversalInfo &debug_info,
115+
const uint32_t maxLperSeller = 0,
116+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
117+
93118
// insert points with labels, labels should be present for filtered index
94119
template <typename data_type, typename tag_type>
95120
int insert_point(const data_type *point, const tag_type tag, const std::vector<std::string> &labels);
@@ -148,5 +173,18 @@ class AbstractIndex
148173
const std::vector<std::string>& filter_labels) = 0;
149174
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
150175
virtual void _set_universal_label(const LabelType universal_label) = 0;
176+
virtual void _get_embedding(uint32_t location, DataType &vec) = 0;
177+
virtual std::pair<uint32_t, uint32_t> _debug_search(const DataType &query, const size_t K, const uint32_t L,
178+
std::any &indices, float *distances,
179+
DebugTraversalInfo &debug_info,
180+
const uint32_t maxLperSeller,
181+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) = 0;
182+
virtual std::pair<uint32_t, uint32_t> _debug_search_with_filters(const DataType &query,
183+
const std::vector<std::string> &raw_labels,
184+
const size_t K, const uint32_t L,
185+
std::any &indices, float *distances,
186+
DebugTraversalInfo &debug_info,
187+
const uint32_t maxLperSeller,
188+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) = 0;
151189
};
152190
} // namespace diskann

include/debug_utils.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license.
3+
4+
#pragma once
5+
6+
#include <cstdint>
7+
#include <vector>
8+
#include <limits>
9+
10+
namespace diskann
11+
{
12+
13+
// Classification of a node encountered during ANN graph traversal: whether it
14+
// ended up in the final result set and, if not, why it was left out.
// NOTE(review): nothing in this header stores or produces a FilterReason;
// DebugTraversalInfo keeps only a distance and a label_rejected flag, so the
// mapping from recorded entries to these values is presumably done by the
// caller -- confirm against the implementation.
15+
enum class FilterReason : uint8_t
16+
{
17+
InResult = 0, // Node was kept in the final top-K result
18+
DistanceTooLarge, // (= 1) Node was visited but its distance was too large for top-K
19+
LabelMismatch, // (= 2) Node was skipped because its label did not match the filter
20+
};
21+
22+
// Collects per-node traversal information during a debug ANN search.
23+
// Populated by iterate_to_fixed_point / cached_beam_search when a non-null
24+
// pointer is passed. The three vectors are parallel: entry i of each describes
25+
// the same encountered node, and each record_* method appends exactly one entry to all three.
26+
struct DebugTraversalInfo
27+
{
28+
std::vector<uint32_t> ids; // Internal location index of each encountered node
29+
std::vector<float> distances; // PQ/exact distance to query; std::numeric_limits<float>::max() (i.e. FLT_MAX) when label-rejected
30+
std::vector<uint8_t> label_rejected; // 1 if skipped due to label mismatch, 0 if evaluated
31+
32+
void clear() // Removes every recorded entry from all three vectors.
33+
{
34+
ids.clear();
35+
distances.clear();
36+
label_rejected.clear();
37+
}
38+
39+
void record_label_rejected(uint32_t id) // Appends a node skipped for label mismatch: sentinel distance, flag 1.
40+
{
41+
ids.push_back(id);
42+
distances.push_back(std::numeric_limits<float>::max()); // no distance is supplied for rejected nodes; store the max-float sentinel
43+
label_rejected.push_back(1);
44+
}
45+
46+
void record_visited(uint32_t id, float dist) // Appends a node together with the distance supplied by the caller, flag 0.
47+
{
48+
ids.push_back(id);
49+
distances.push_back(dist);
50+
label_rejected.push_back(0);
51+
}
52+
};
53+
54+
} // namespace diskann

include/index.h

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "quantized_distance.h"
3030
#include "pq_data_store.h"
31+
#include "debug_utils.h"
3132

3233
#define OVERHEAD_FACTOR 1.1
3334
#define EXPAND_IF_FULL 0
@@ -145,7 +146,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
145146
template <typename IDType>
146147
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
147148
IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0,
148-
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
149+
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
150+
DebugTraversalInfo *debug_info = nullptr);
149151

150152
template <typename IDType>
151153
std::pair<uint32_t, uint32_t> diverse_search(const T* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices,
@@ -157,6 +159,30 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
157159
float *distances, std::vector<T *> &res_vectors, bool use_filters,
158160
const std::vector<std::string>& filter_labels);
159161

162+
// Debug interface: retrieve the raw embedding stored at the given internal location index.
163+
// Caller must allocate vec with at least get_aligned_dim() elements of type T.
164+
DISKANN_DLLEXPORT void get_embedding(uint32_t location, T *vec) const;
165+
166+
// Debug search: runs ANN search and records every traversed node.
167+
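// debug_info is populated in traversal order. NOTE(review): debug_utils.h defines the FilterReason enum but no helper functions for it, so callers must classify entries themselves (e.g. from label_rejected and the returned indices).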
168+
template <typename IDType>
169+
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> debug_search(
170+
const T *query, const size_t K, const uint32_t L,
171+
IDType *indices, float *distances,
172+
DebugTraversalInfo &debug_info,
173+
const uint32_t maxLperSeller = 0,
174+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
175+
176+
// Debug filtered search: same as debug_search but applies label filtering.
177+
template <typename IDType>
178+
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> debug_search_with_filters(
179+
const T *query, const std::vector<LabelT> &filter_labels,
180+
const size_t K, const uint32_t L,
181+
IDType *indices, float *distances,
182+
DebugTraversalInfo &debug_info,
183+
const uint32_t maxLperSeller = 0,
184+
std::function<float(const std::uint8_t *, size_t)> rerank_fn = nullptr);
185+
160186
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
161187
std::any& indices, float* distances = nullptr,
162188
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;
@@ -166,7 +192,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
166192
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const std::vector<LabelT> &filter_labels,
167193
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
168194
IndexType *indices, float *distances,
169-
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);
195+
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
196+
DebugTraversalInfo *debug_info = nullptr);
170197

171198
// Will fail if tag already in the index or if tag=0.
172199
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
@@ -235,6 +262,22 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
235262
float *distances,
236263
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;
237264

265+
virtual void _get_embedding(uint32_t location, DataType &vec) override;
266+
267+
virtual std::pair<uint32_t, uint32_t> _debug_search(const DataType &query, const size_t K, const uint32_t L,
268+
std::any &indices, float *distances,
269+
DebugTraversalInfo &debug_info,
270+
const uint32_t maxLperSeller,
271+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) override;
272+
273+
virtual std::pair<uint32_t, uint32_t> _debug_search_with_filters(const DataType &query,
274+
const std::vector<std::string> &raw_labels,
275+
const size_t K, const uint32_t L,
276+
std::any &indices, float *distances,
277+
DebugTraversalInfo &debug_info,
278+
const uint32_t maxLperSeller,
279+
std::function<float(const std::uint8_t *, size_t)> rerank_fn) override;
280+
238281
virtual int _insert_point(const DataType &data_point, const TagType tag) override;
239282
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) override;
240283

@@ -293,7 +336,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
293336
// The query to use is placed in scratch->aligned_query
294337
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
295338
const std::vector<uint32_t> &init_ids, bool use_filter,
296-
const std::vector<LabelT> &filters, bool search_invocation, uint32_t maxLperSeller = 0);
339+
const std::vector<LabelT> &filters, bool search_invocation,
340+
uint32_t maxLperSeller = 0,
341+
DebugTraversalInfo *debug_info = nullptr);
297342

298343
void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
299344
InMemQueryScratch<T> *scratch, bool use_filter = false,

include/pq_flash_index.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "tsl/robin_set.h"
1818
#include "label_bitmask.h"
1919
#include "integer_label_vector.h"
20+
#include "debug_utils.h"
2021

2122
#define FULL_PRECISION_REORDER_MULTIPLIER 3
2223

@@ -81,10 +82,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
8182
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
8283
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
8384
const bool use_filter, const std::vector<LabelT> &filter_labels,
84-
const uint32_t io_limit, uint32_t maxLperSeller = 0,
85+
const uint32_t io_limit, uint32_t maxLperSeller = 0,
8586
const bool use_reorder_data = false,
8687
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
87-
QueryStats *stats = nullptr);
88+
QueryStats *stats = nullptr,
89+
DebugTraversalInfo *debug_info = nullptr);
8890

8991
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);
9092

@@ -117,6 +119,25 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
117119
DISKANN_DLLEXPORT std::vector<std::uint8_t> get_pq_vector(std::uint64_t vid);
118120
DISKANN_DLLEXPORT uint64_t get_num_points();
119121

122+
// Debug interface: retrieve full-precision embedding for a given internal node ID.
123+
// Caller must pre-allocate vec with room for at least get_data_dim() elements (the index dimension).
124+
DISKANN_DLLEXPORT void get_embedding(uint32_t id, T *vec);
125+
126+
// Debug search: runs ANN search and records every traversed node in debug_info (id, distance, and label-rejected flag).
127+
DISKANN_DLLEXPORT void debug_search(
128+
const T *query, const uint64_t k_search, const uint64_t l_search,
129+
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
130+
DebugTraversalInfo &debug_info,
131+
uint32_t maxLperSeller = 0);
132+
133+
// Debug filtered search: same as debug_search but applies label filtering.
134+
DISKANN_DLLEXPORT void debug_search_with_filters(
135+
const T *query, const uint64_t k_search, const uint64_t l_search,
136+
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
137+
const std::vector<LabelT> &filter_labels,
138+
DebugTraversalInfo &debug_info,
139+
uint32_t maxLperSeller = 0);
140+
120141
protected:
121142
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
122143
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);

0 commit comments

Comments
 (0)