Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions diskann-benchmark/src/backend/exhaustive/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,19 @@ mod imp {
// Compressing
let start = std::time::Instant::now();
let store = {
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.compression_threads.get())
.build()?;

let compression_progress =
make_progress_bar("compressing", data.nrows(), output.draw_target())?;
let store = Store::new(
data.as_view(),
diskann_quantization::spherical::iface::Impl::<NBITS>::new(quantizer)?,
&compression_progress,
)?;
let store = threadpool.install(|| {
Store::new(
data.as_view(),
diskann_quantization::spherical::iface::Impl::<NBITS>::new(quantizer)?,
&compression_progress,
)
})?;
compression_progress.finish();
store
};
Expand Down
201 changes: 113 additions & 88 deletions diskann-quantization/src/spherical/quantizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ use super::{
};
use crate::{
AsFunctor, CompressIntoWith,
algorithms::{
heap::SliceHeap,
transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
},
algorithms::transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
bits::{PermutationStrategy, Representation, Unsigned},
num::Positive,
Expand Down Expand Up @@ -988,106 +985,134 @@ fn maximize_cosine_similarity(
num_bits: NonZeroUsize,
allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
// Lint: This is a private method and all the callers have an invariant that they check
// for non-empty inputs.
#[allow(clippy::expect_used)]
let _: NonZeroUsize =
NonZeroUsize::new(v.len()).expect("calling code should not allow the slice to be empty");

// Initially, the lattice element has the value `0.5` for all dimensions.
// This means the initial inner product between `v` and the rounded term is simply
// `0.5 * sum(abs.(v))`. The absolute value is used because the lattice element is
// always in the direction of the components in `v`.
let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
let mut current_square_norm = 0.25 * (v.len() as f64);

// Book keeping for the current value of the rounded vector.
// The true numeric value is 0.5 less than this (in the direction of `v`), but we use
// integers for a smaller memory footprint.
let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

// Compute the critical values and store them on a heap.
// Enumerate every critical value into a single flat buffer, which is then sorted.
//
// Critical values for dimension `i` are the scaling factors at which `rounded[i]`
// advances from one integer level to the next.
//
// The binary heap will keep track of the minimum critical value. Multiplying `v` by the
// minimum critical value `s` means that `s * v` will only change `rounded` from its
// current value at a single index (the position associated with `s`).
// For `num_bits >= 2`, the encoding for dimension `i` saturates
// once `rounded[i] == stop`, so dimension `i` contributes `stop - 1` critical values.
//
// For `num_bits == 1` the encoding only has the levels `0` and
// `1`, so each dimension contributes a single critical value at the transition from
// `r = 1` to `r = 2`; `visits_per_dim` is clamped to 1 to capture that case.
//
// The `eps` term breaks ties between dimensions that would otherwise transition at the
// exact same scaling factor.
let stop: usize = 1usize << (num_bits.get() - 1);
let visits_per_dim: usize = stop.max(2) - 1;
let total = v.len() * visits_per_dim;
Comment on lines +1015 to +1017
let eps = 0.0001f32;
let one_and_change = 1.0 + eps;
let mut base = Poly::from_iter(
v.iter().enumerate().map(|(position, value)| {
let value = one_and_change / value.abs();
Pair {
value,
position: position as u32,

let mut crits = Poly::<[Pair], _>::new_uninit_slice(total, allocator)?;
{
let buf = crits.as_mut();
let mut k = 0usize;
for (position, value) in v.iter().enumerate() {
// Hoist the reciprocal so the inner loop only multiplies.
let inv = 1.0f32 / value.abs();
for r in 1..=visits_per_dim {
// SAFETY: `k` is bounded above by `v.len() * visits_per_dim = total`,
// which is exactly the length of `buf`.
unsafe {
buf.get_unchecked_mut(k).write(Pair {
value: (r as f32 + eps) * inv,
position: position as u32,
});
}
k += 1;
}
}),
allocator,
)?;
}
debug_assert_eq!(k, total);
}
// SAFETY: The loop above initialized exactly `total` entries, matching the slice's
// length.
let mut crits = unsafe { crits.assume_init() };

// Sort critical values in ascending order so that walking the slice corresponds to
// sweeping `s` from `0` to `+inf`. `Pair`'s `Ord` impl is reversed, so the comparator
// below orders by `value` directly to obtain ascending order.
crits.sort_unstable_by(|a, b| {
a.value
.partial_cmp(&b.value)
.unwrap_or(std::cmp::Ordering::Equal)
});
Comment on lines +1045 to +1052

// Lint: This is a private method and all the callers have an invariant that they check
// for non-empty inputs.
#[allow(clippy::expect_used)]
let mut critical_values =
SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");
// Book keeping for the current value of the rounded vector.
// The true numeric value is 0.5 less than this (in the direction of `v`), but we use
// integers for a smaller memory footprint.
let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

let mut max_similarity = f64::NEG_INFINITY;
// `max_ip_sq / max_sn` is the squared best cosine similarity seen so far.
// See the cosine-similarity comparison below for why squared quantities are tracked
// instead of the cosine similarity directly.
let mut max_ip_sq = f64::NEG_INFINITY;
let mut max_sn = 1.0f64;
let mut optimal_scale = f32::default();
let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

loop {
let mut should_break = false;
critical_values.update_root(|pair| {
let Pair { value, position } = *pair;
if value == f32::MAX {
should_break = true;
return;
}

let r = &mut rounded[position as usize];
let vp = &v[position as usize];
for &Pair { value, position } in crits.iter() {
// SAFETY: `position` is in `0..v.len()` by construction above and is never
// modified.
let r = unsafe { rounded.get_unchecked_mut(position as usize) };
// SAFETY: Same as above.
let vp = unsafe { *v.get_unchecked(position as usize) };

let old_r = *r;
// By the nature of cricital values, only `r` will change in `rounded` when
// multiplying by `value`. And that change will be to increase by 1.
*r += 1;

// The inner product estimate simply increases by `vp.abs()` because:
//
// * `r` is the only value in `rounded` that changes.
// * `r` is increased by 1.
current_ip += vp.abs() as f64;

// This uses the formula
// ```math
// (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2
// = 2x + 1
// ```
// substitute `x = y - 1/2` to obtain the true value associated with rounded and
// we get
// ```math
// 2 ( y - 1/2 ) + 1 = 2y - 1 + 1
// = 2y
// ```
// Therefore, the change in the estimate for the square norm of `rounded` is
// `2 * old_r`.
current_square_norm += (2 * old_r) as f64;

// Compute the current cosine similarity and update max if needed.
let similarity = current_ip / current_square_norm.sqrt();
if similarity > max_similarity {
max_similarity = similarity;
optimal_scale = value;
}
let old_r = *r;
// By the nature of critical values, only `r` will change in `rounded` when
// multiplying by `value`. And that change will be to increase by 1.
*r += 1;

// Compute the scaling factor that will change this dimension to the next value.
if *r < stop {
*pair = Pair {
value: (*r as f32 + eps) / vp.abs(),
position,
};
} else {
*pair = Pair {
value: f32::MAX,
position,
};
}
});
if should_break {
break;
// The inner product estimate simply increases by `vp.abs()` because:
//
// * `r` is the only value in `rounded` that changes.
// * `r` is increased by 1.
current_ip += vp.abs() as f64;

// This uses the formula
// ```math
// (x + 1)^2 - x^2 = x^2 + 2x + 1 - x^2
// = 2x + 1
// ```
// substitute `x = y - 1/2` to obtain the true value associated with rounded and we
// get
// ```math
// 2 ( y - 1/2 ) + 1 = 2y - 1 + 1
// = 2y
// ```
// Therefore, the change in the estimate for the square norm of `rounded` is
// `2 * old_r`.
current_square_norm += (2 * old_r) as f64;

// Compare cosine similarities without taking square roots. The cosine similarity
// is `ip / sqrt(sn)`, so the comparison we want to make is
// ```math
// ip / sqrt(sn) > max_ip / sqrt(max_sn)
// ```
// Both sides are non-negative: `current_ip` starts at `0.5 * sum(|v|)` and only
// grows, and `current_square_norm` is strictly positive. Squaring is monotonic on
// non-negative reals, so the comparison is equivalent to
// ```math
// ip^2 * max_sn > max_ip^2 * sn
// ```
// which avoids the `sqrt` per critical value.
let ip_sq = current_ip * current_ip;
if ip_sq * max_sn > max_ip_sq * current_square_norm {
max_ip_sq = ip_sq;
max_sn = current_square_norm;
optimal_scale = value;
}
}

Expand Down
Loading