Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 173 additions & 2 deletions rust/index/src/sparse/maxscore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,36 @@ struct TermState<'a> {
window_score: f32,
}

// ── Budget pruning ──────────────────────────────────────────────────

/// Remove candidates whose score <= cutoff. Both parallel arrays are
/// compacted in-place. Dispatches to SIMD on supported architectures.
fn filter_competitive(cand_docs: &mut Vec<u32>, cand_scores: &mut Vec<f32>, cutoff: f32) {
    debug_assert_eq!(cand_docs.len(), cand_scores.len());

    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("sse2") {
        // SAFETY: SSE2 presence verified at runtime; both slices are the
        // same length (debug_assert above), and the kernel's write index
        // never overtakes its read index, so no write is out of bounds.
        unsafe { filter_competitive_sse2(cand_docs, cand_scores, cutoff) };
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: every aarch64 CPU has NEON; same index invariants apply.
        unsafe { filter_competitive_neon(cand_docs, cand_scores, cutoff) };
        return;
    }

    // Portable fallback (also taken on x86_64 without SSE2).
    #[allow(unreachable_code)]
    filter_competitive_scalar(cand_docs, cand_scores, cutoff);
}

fn filter_competitive_scalar(cand_docs: &mut Vec<u32>, cand_scores: &mut Vec<f32>, cutoff: f32) {
let n = cand_docs.len();
let mut write = 0;
for i in 0..n {
Expand All @@ -690,6 +714,116 @@ fn filter_competitive(cand_docs: &mut Vec<u32>, cand_scores: &mut Vec<f32>, cuto
cand_scores.truncate(write);
}

/// SSE2 kernel: compare four scores at a time against the cutoff with
/// `_mm_cmpgt_ps`, collapse the result to a 4-bit mask via
/// `_mm_movemask_ps`, then compact surviving (doc, score) pairs forward.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn filter_competitive_sse2(
    cand_docs: &mut Vec<u32>,
    cand_scores: &mut Vec<f32>,
    cutoff: f32,
) {
    use std::arch::x86_64::*;

    let total = cand_docs.len();
    let simd_end = total & !3; // largest multiple of 4 <= total
    let threshold = _mm_set1_ps(cutoff);
    let mut out = 0usize;

    // SAFETY: every read index below is < total, and `out` advances only
    // when an element survives, so `out` never overtakes the read index —
    // writes cannot clobber elements that are still to be read.
    let mut base = 0usize;
    while base < simd_end {
        let scores_v = _mm_loadu_ps(cand_scores.as_ptr().add(base));
        let mut mask = _mm_movemask_ps(_mm_cmpgt_ps(scores_v, threshold)) as u32;
        // Pop survivors lowest-lane-first so document order is preserved.
        while mask != 0 {
            let lane = mask.trailing_zeros() as usize;
            mask &= mask - 1;
            *cand_docs.get_unchecked_mut(out) = *cand_docs.get_unchecked(base + lane);
            *cand_scores.get_unchecked_mut(out) = *cand_scores.get_unchecked(base + lane);
            out += 1;
        }
        base += 4;
    }

    // Tail that does not fill a full vector: plain scalar comparison.
    for i in simd_end..total {
        if *cand_scores.get_unchecked(i) > cutoff {
            *cand_docs.get_unchecked_mut(out) = *cand_docs.get_unchecked(i);
            *cand_scores.get_unchecked_mut(out) = *cand_scores.get_unchecked(i);
            out += 1;
        }
    }

    cand_docs.truncate(out);
    cand_scores.truncate(out);
}

/// NEON kernel: compare four scores at a time with `vcgtq_f32`, spill the
/// per-lane masks to a small stack array, and compact survivors forward.
#[cfg(target_arch = "aarch64")]
unsafe fn filter_competitive_neon(
    cand_docs: &mut Vec<u32>,
    cand_scores: &mut Vec<f32>,
    cutoff: f32,
) {
    use std::arch::aarch64::*;

    let total = cand_docs.len();
    let simd_end = total & !3; // largest multiple of 4 <= total
    let threshold = vdupq_n_f32(cutoff);
    let mut out = 0usize;
    let mut lane_masks = [0u32; 4];

    // SAFETY: all read indices are < total, and `out` advances only on
    // surviving elements, so writes never overtake reads.
    let mut base = 0usize;
    while base < simd_end {
        let scores_v = vld1q_f32(cand_scores.as_ptr().add(base));
        // All-ones per lane where score > cutoff, all-zeros otherwise.
        vst1q_u32(lane_masks.as_mut_ptr(), vcgtq_f32(scores_v, threshold));
        for (lane, &keep) in lane_masks.iter().enumerate() {
            if keep != 0 {
                *cand_docs.get_unchecked_mut(out) = *cand_docs.get_unchecked(base + lane);
                *cand_scores.get_unchecked_mut(out) = *cand_scores.get_unchecked(base + lane);
                out += 1;
            }
        }
        base += 4;
    }

    // Scalar tail for the final partial vector.
    for i in simd_end..total {
        if *cand_scores.get_unchecked(i) > cutoff {
            *cand_docs.get_unchecked_mut(out) = *cand_docs.get_unchecked(i);
            *cand_scores.get_unchecked_mut(out) = *cand_scores.get_unchecked(i);
            out += 1;
        }
    }

    cand_docs.truncate(out);
    cand_scores.truncate(out);
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -711,6 +845,43 @@ mod tests {
assert!(docs.is_empty());
}

#[test]
fn filter_competitive_simd_matches_scalar() {
    // Exercise both the full-vector path and every remainder class
    // (lengths that are not a multiple of the SIMD width).
    let sizes = [1usize, 3, 4, 5, 7, 8, 9, 15, 16, 17, 100];
    for &n in &sizes {
        let base_docs: Vec<u32> = (0..n as u32).collect();
        let base_scores: Vec<f32> = (1..=n).map(|i| 0.1 * i as f32).collect();
        let cutoff = 0.5;

        let mut ref_docs = base_docs.clone();
        let mut ref_scores = base_scores.clone();
        filter_competitive_scalar(&mut ref_docs, &mut ref_scores, cutoff);

        let mut got_docs = base_docs.clone();
        let mut got_scores = base_scores.clone();
        filter_competitive(&mut got_docs, &mut got_scores, cutoff);

        assert_eq!(ref_docs, got_docs, "docs mismatch at n={n}");
        assert_eq!(ref_scores, got_scores, "scores mismatch at n={n}");
    }
}

#[test]
fn filter_competitive_all_pass() {
    // With a cutoff below every score nothing may be dropped: assert the
    // full contents and order of BOTH parallel arrays, not just a length
    // (the original check could not catch reordering or score corruption).
    let mut docs = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let mut scores = vec![1.0; 8];
    filter_competitive(&mut docs, &mut scores, 0.0);
    assert_eq!(docs, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    assert_eq!(scores, vec![1.0; 8]);
}

#[test]
fn filter_competitive_none_pass() {
    // Cutoff above every score: both parallel arrays must be emptied.
    // (The original only checked `docs`, leaving a desynchronized `scores`
    // vector undetected.)
    let mut docs = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let mut scores = vec![0.1; 8];
    filter_competitive(&mut docs, &mut scores, 1.0);
    assert!(docs.is_empty());
    assert!(scores.is_empty());
}

#[test]
fn cursor_from_blocks_single() {
let block = SparsePostingBlock::from_sorted_entries(&[(0, 0.5), (10, 0.9)]).unwrap();
Expand Down
126 changes: 125 additions & 1 deletion rust/types/src/sparse_posting_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,14 +737,108 @@ impl Directory {
}
}

// ── f16 → f32 bulk conversion ────────────────────────────────────────

/// Convert a slice of little-endian f16 bytes into f32 values.
/// Dispatches to SIMD on supported architectures, scalar fallback otherwise.
///
/// Converts `min(out.len(), f16_bytes.len() / 2)` values; any trailing
/// output slots are left untouched, matching the scalar path's behavior.
pub fn convert_f16_to_f32(f16_bytes: &[u8], out: &mut [f32]) {
    // The SIMD kernels consume two input bytes per output element. If the
    // caller hands us fewer bytes than that, take the scalar path, which
    // stops at the shorter of the two buffers instead of reading out of
    // bounds (the vector loads would otherwise run past `f16_bytes`).
    if f16_bytes.len() < out.len() * 2 {
        convert_f16_to_f32_scalar(f16_bytes, out);
        return;
    }
    #[cfg(target_arch = "aarch64")]
    {
        convert_f16_to_f32_neon(f16_bytes, out);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("f16c") {
            // SAFETY: f16c feature detected at runtime; `f16_bytes` holds at
            // least `out.len() * 2` bytes (checked above), so every vector
            // load and store is in bounds.
            unsafe { convert_f16_to_f32_f16c(f16_bytes, out) };
            return;
        }
    }
    #[allow(unreachable_code)]
    convert_f16_to_f32_scalar(f16_bytes, out);
}

/// Scalar f16→f32 conversion via the `half` crate.
///
/// Stops at the shorter of the two buffers: trailing output slots are left
/// untouched and a trailing odd input byte is ignored.
pub fn convert_f16_to_f32_scalar(f16_bytes: &[u8], out: &mut [f32]) {
    let halves = f16_bytes
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));
    for (dst, bits) in out.iter_mut().zip(halves) {
        *dst = f16::from_bits(bits).to_f32();
    }
}

/// NEON f16→f32 via bit manipulation. Processes 8 values at a time.
///
/// Uses integer shift+mask+bias to convert f16 bit patterns to f32 bit
/// patterns. Handles all normal f16 values correctly; zeros and subnormals
/// are flushed to (signed) zero, which is acceptable for SPLADE/BM25
/// weights that are always positive normals. f16 infinities/NaNs (which
/// never occur in stored weights) are not mapped to their f32 equivalents.
#[cfg(target_arch = "aarch64")]
fn convert_f16_to_f32_neon(f16_bytes: &[u8], out: &mut [f32]) {
    use std::arch::aarch64::*;

    // Never read past the input: convert only as many values as BOTH
    // buffers can supply. This keeps the function sound for any caller
    // instead of relying on an unenforced length precondition.
    let n = out.len().min(f16_bytes.len() / 2);
    let chunks = n / 8;

    // SAFETY: NEON is always available on aarch64. `chunks * 8 <= n <=
    // out.len()` bounds the output stores, and `chunks * 16 <= n * 2 <=
    // f16_bytes.len()` bounds the input loads.
    unsafe {
        let sign_mask = vdupq_n_u32(0x8000);
        let nosign_mask = vdupq_n_u32(0x7FFF);
        let exp_mask = vdupq_n_u32(0x7C00);
        let bias = vdupq_n_u32(0x3800_0000); // (127 - 15) << 23

        for c in 0..chunks {
            let base = c * 8;
            let byte_base = base * 2;

            let h8 = vld1q_u16(f16_bytes.as_ptr().add(byte_base) as *const u16);
            let lo = vmovl_u16(vget_low_u16(h8));
            let hi = vmovl_u16(vget_high_u16(h8));

            macro_rules! cvt {
                ($h:expr, $off:expr) => {{
                    let sign = vshlq_n_u32::<16>(vandq_u32($h, sign_mask));
                    let nosign = vshlq_n_u32::<13>(vandq_u32($h, nosign_mask));
                    let bits = vorrq_u32(sign, vaddq_u32(nosign, bias));
                    // Flush zero/subnormal inputs (f16 exponent == 0) to a
                    // signed zero: the bias trick alone would map f16 0.0 to
                    // ~3.05e-5 (bit pattern 0x3800_0000), silently corrupting
                    // exact zeros.
                    let tiny = vceqq_u32(vandq_u32($h, exp_mask), vdupq_n_u32(0));
                    let bits = vbslq_u32(tiny, sign, bits);
                    vst1q_f32(
                        out.as_mut_ptr().add(base + $off),
                        vreinterpretq_f32_u32(bits),
                    );
                }};
            }
            cvt!(lo, 0);
            cvt!(hi, 4);
        }
    }

    // Scalar remainder (exact for all values, including subnormals).
    let rem_start = chunks * 8;
    convert_f16_to_f32_scalar(&f16_bytes[rem_start * 2..], &mut out[rem_start..]);
}

/// x86_64 F16C: `_mm256_cvtph_ps` converts 8×f16 → 8×f32 in one instruction.
///
/// Converts `min(out.len(), f16_bytes.len() / 2)` values; trailing output
/// slots are left untouched.
///
/// # Safety
///
/// The caller must ensure the F16C target feature is available on the
/// running CPU (e.g. via `is_x86_feature_detected!("f16c")`); calling this
/// on a CPU without F16C is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "f16c")]
unsafe fn convert_f16_to_f32_f16c(f16_bytes: &[u8], out: &mut [f32]) {
    use std::arch::x86_64::*;

    // Clamp to what both buffers can supply so the vector loads/stores can
    // never run past either slice, regardless of caller-provided lengths.
    let n = out.len().min(f16_bytes.len() / 2);
    let chunks = n / 8;

    // SAFETY: `chunks * 8 <= n <= out.len()` bounds the stores and
    // `chunks * 16 <= n * 2 <= f16_bytes.len()` bounds the loads.
    for c in 0..chunks {
        let base = c * 8;
        let byte_base = base * 2;
        let h8 = _mm_loadu_si128(f16_bytes.as_ptr().add(byte_base) as *const __m128i);
        let f8 = _mm256_cvtph_ps(h8);
        _mm256_storeu_ps(out.as_mut_ptr().add(base), f8);
    }

    // Scalar remainder for the final partial vector.
    let rem_start = chunks * 8;
    convert_f16_to_f32_scalar(&f16_bytes[rem_start * 2..], &mut out[rem_start..]);
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1397,6 +1491,36 @@ mod tests {
assert_eq!(out[1], 0.0); // not overwritten: chunks_exact skips trailing
}

// ── SIMD vs scalar f16 conversion consistency ─────────────────────

#[test]
fn convert_f16_simd_matches_scalar() {
    // Cover full-vector counts and every remainder class around them
    // (sizes that are not multiples of 8).
    let sizes = [1usize, 3, 7, 8, 9, 15, 16, 17, 31, 63, 64, 100, 256, 1000];
    for &n in &sizes {
        let encoded: Vec<u8> = (1..=n)
            .flat_map(|i| f16::from_f32(0.01 * i as f32).to_le_bytes())
            .collect();

        let mut reference = vec![0.0f32; n];
        convert_f16_to_f32_scalar(&encoded, &mut reference);

        let mut dispatched = vec![0.0f32; n];
        convert_f16_to_f32(&encoded, &mut dispatched);

        for (i, (r, d)) in reference.iter().zip(dispatched.iter()).enumerate() {
            assert!(
                (r - d).abs() <= f32::EPSILON,
                "mismatch at n={n}, i={i}: scalar={} simd={}",
                r,
                d,
            );
        }
    }
}

// ── Zero-copy methods at various block sizes (incl. remainder paths) ──

#[test]
Expand Down
Loading