@@ -267,57 +267,118 @@ pub const TurboQuant = struct {
     }
 
     /// Pre-rotate a query vector: y_q = Π · q. Caller owns the returned slice.
-    /// Call this ONCE per search, then use scanL2Rotated/scanDotRotated per vector.
+    /// Call this ONCE per search, then use buildL2Table/buildDotTable + scanWithTable per vector.
     pub fn rotateQuery(self: *const TurboQuant, allocator: Allocator, query: []const f32) ![]f32 {
         const d: usize = @intCast(self.dims);
         const yq = try allocator.alloc(f32, d);
         matvecMul(self.rotation, query, yq, d);
         return yq;
     }
 
-    /// Fast L2 distance from pre-rotated query to a quantized vector.
-    /// O(d) with no rotation — just codebook lookups + SIMD subtraction.
-    pub fn scanL2Rotated(self: *const TurboQuant, rotated_query: []const f32, quantized: []const u8) f32 {
+    /// Build an ADC (Asymmetric Distance Computation) lookup table for L2.
+    /// dist_table[j * nc + k] = (rotated_query[j] - centroid[k])^2
+    /// Call ONCE per query, then scanWithTable is a trivial table lookup per coordinate.
+    pub fn buildL2Table(self: *const TurboQuant, allocator: Allocator, rotated_query: []const f32) ![]f32 {
         const d: usize = @intCast(self.dims);
-        const bw = self.bit_width;
+        const nc: usize = self.num_centroids;
         const cb = self.codebook;
-        var j: usize = 0;
-        var v_sum: V8 = @splat(0.0);
-        while (j + 8 <= d) : (j += 8) {
-            const yq_v: V8 = rotated_query[j..][0..8].*;
-            var cb_v: V8 = undefined;
-            inline for (0..8) |k| {
-                cb_v[k] = cb[unpackBits(quantized, j + k, bw)];
+        const table = try allocator.alloc(f32, d * nc);
+        for (0..d) |j| {
+            const qj = rotated_query[j];
+            for (0..nc) |k| {
+                const diff = qj - cb[k];
+                table[j * nc + k] = diff * diff;
             }
-            const diff = yq_v - cb_v;
-            v_sum += diff * diff;
         }
-        var sum = @reduce(.Add, v_sum);
-        while (j < d) : (j += 1) {
+        return table;
+    }
+
+    /// Build an ADC lookup table for dot product.
+    /// dot_table[j * nc + k] = rotated_query[j] * centroid[k]
+    pub fn buildDotTable(self: *const TurboQuant, allocator: Allocator, rotated_query: []const f32) ![]f32 {
+        const d: usize = @intCast(self.dims);
+        const nc: usize = self.num_centroids;
+        const cb = self.codebook;
+        const table = try allocator.alloc(f32, d * nc);
+        for (0..d) |j| {
+            const qj = rotated_query[j];
+            for (0..nc) |k| {
+                table[j * nc + k] = qj * cb[k];
+            }
+        }
+        return table;
+    }
+
+    /// Ultra-fast scan: just table lookups per coordinate. O(d) with tiny constant.
+    /// For 4-bit: 2 coords per byte, each is a table lookup + accumulate.
+    pub fn scanWithTable(self: *const TurboQuant, table: []const f32, quantized: []const u8) f32 {
+        const d: usize = @intCast(self.dims);
+        const nc: usize = self.num_centroids;
+        const bw = self.bit_width;
+        var sum: f32 = 0;
+
+        if (bw == 4) {
+            // Fast path for 4-bit: 2 nibbles per byte, one mask/shift each instead of a generic unpackBits call
+            var j: usize = 0;
+            var byte_idx: usize = 0;
+            while (j + 2 <= d) : ({
+                j += 2;
+                byte_idx += 1;
+            }) {
+                const b = quantized[byte_idx];
+                const lo: usize = b & 0x0F;
+                const hi: usize = (b >> 4) & 0x0F;
+                sum += table[j * nc + lo];
+                sum += table[(j + 1) * nc + hi];
+            }
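+            // Odd d: one trailing coordinate remains in the low nibble of the last byte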
+            if (j < d) {
+                sum += table[j * nc + @as(usize, quantized[byte_idx] & 0x0F)];
+            }
+        } else if (bw == 2) {
+            // Fast path for 2-bit: 4 values per byte
+            var j: usize = 0;
+            var byte_idx: usize = 0;
+            while (j + 4 <= d) : ({
+                j += 4;
+                byte_idx += 1;
+            }) {
+                const b = quantized[byte_idx];
+                sum += table[j * nc + @as(usize, b & 0x03)];
+                sum += table[(j + 1) * nc + @as(usize, (b >> 2) & 0x03)];
+                sum += table[(j + 2) * nc + @as(usize, (b >> 4) & 0x03)];
+                sum += table[(j + 3) * nc + @as(usize, (b >> 6) & 0x03)];
+            }
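+            // Fall back to the generic unpacker for the last d % 4 coordinates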
+            while (j < d) : (j += 1) {
+                sum += table[j * nc + @as(usize, unpackBits(quantized, j, bw))];
+            }
+        } else {
+            // Generic path
+            for (0..d) |j| {
+                sum += table[j * nc + @as(usize, unpackBits(quantized, j, bw))];
+            }
+        }
+        return sum;
+    }
+
+    // Keep the old methods for backward compat / single-vector queries
+    pub fn scanL2Rotated(self: *const TurboQuant, rotated_query: []const f32, quantized: []const u8) f32 {
+        const d: usize = @intCast(self.dims);
+        const bw = self.bit_width;
+        const cb = self.codebook;
+        var sum: f32 = 0;
+        for (0..d) |j| {
             const diff = rotated_query[j] - cb[unpackBits(quantized, j, bw)];
             sum += diff * diff;
         }
         return sum;
     }
 
-    /// Fast dot product from pre-rotated query to a quantized vector.
-    /// O(d) with no rotation — just codebook lookups + SIMD multiply.
     pub fn scanDotRotated(self: *const TurboQuant, rotated_query: []const f32, quantized: []const u8) f32 {
         const d: usize = @intCast(self.dims);
         const bw = self.bit_width;
         const cb = self.codebook;
-        var j: usize = 0;
-        var v_sum: V8 = @splat(0.0);
-        while (j + 8 <= d) : (j += 8) {
-            const yq_v: V8 = rotated_query[j..][0..8].*;
-            var cb_v: V8 = undefined;
-            inline for (0..8) |k| {
-                cb_v[k] = cb[unpackBits(quantized, j + k, bw)];
-            }
-            v_sum += yq_v * cb_v;
-        }
-        var sum = @reduce(.Add, v_sum);
-        while (j < d) : (j += 1) {
+        var sum: f32 = 0;
+        for (0..d) |j| {
             sum += rotated_query[j] * cb[unpackBits(quantized, j, bw)];
         }
         return sum;
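
The intended per-query flow: rotate once, build one table, then every candidate scan is pure lookups. `scanWithTable` works with either table, since both reduce to summing `d` table entries. Below is a minimal sketch of that loop; the `rankCandidates` helper and its parameter names are illustrative, not part of this commit:

```zig
const std = @import("std");

/// Hypothetical driver showing how the new API composes (sketch only).
fn rankCandidates(
    tq: *const TurboQuant,
    allocator: std.mem.Allocator,
    query: []const f32,
    codes: []const []const u8, // one packed code slice per candidate
    out_dists: []f32,
) !void {
    // 1. Rotate once per query: y_q = Π · q.
    const yq = try tq.rotateQuery(allocator, query);
    defer allocator.free(yq);

    // 2. Build the ADC table once per query: O(d * nc) work.
    const table = try tq.buildL2Table(allocator, yq);
    defer allocator.free(table);

    // 3. Scan each candidate with pure table lookups: O(d) each,
    //    no rotation and no centroid math in the inner loop.
    for (codes, 0..) |code, i| {
        out_dists[i] = tq.scanWithTable(table, code);
    }
}
```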
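Sizing note: the table holds `d * nc` f32 values. Assuming `nc = 2^bit_width` (the 4-bit fast path masks indices with `0x0F`, i.e. 16 centroids), a hypothetical 768-dim index spends 768 * 16 * 4 bytes = 48 KiB per query at 4 bits, and 12 KiB at 2 bits, so the whole table stays cache-resident across the candidate scan.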