@@ -15,7 +15,9 @@ use arrow_array::{
1515use arrow_schema:: { ArrowError , DataType , Field , FieldRef , Schema , TimeUnit } ;
1616use chrono:: { DateTime , Timelike , Utc } ;
1717use futures:: TryStreamExt ;
18- use lance:: dataset:: mem_wal:: { DatasetMemWalExt , ShardWriterConfig } ;
18+ use lance:: dataset:: mem_wal:: {
19+ DatasetMemWalExt , LsmScanner , ShardManifestStore , ShardSnapshot , ShardWriterConfig ,
20+ } ;
1921use lance:: dataset:: optimize:: { compact_files, CompactionMetrics , CompactionOptions } ;
2022use lance:: dataset:: { builder:: DatasetBuilder , Dataset , WriteMode , WriteParams } ;
2123use lance:: index:: DatasetIndexExt ;
@@ -34,6 +36,7 @@ use crate::record::{ContextRecord, SearchResult, StateMetadata};
3436/// Embedding length used for the semantic index column.
3537const DEFAULT_EMBEDDING_DIM : i32 = 1536 ;
3638const DEFAULT_SEARCH_LIMIT : usize = 10 ;
39+ const DEFAULT_MANIFEST_SCAN_BATCH_SIZE : usize = 16 ;
3740const ID_INDEX_NAME : & str = "id_idx" ;
3841
3942/// Configuration for background compaction.
@@ -280,18 +283,19 @@ impl ContextStore {
280283 limit : Option < usize > ,
281284 offset : Option < usize > ,
282285 ) -> LanceResult < Vec < ContextRecord > > {
283- let mut scanner = self . dataset . scan ( ) ;
284- if let Some ( limit) = limit {
285- scanner. limit ( Some ( limit as i64 ) , offset. map ( |o| o as i64 ) ) ?;
286- } else if let Some ( offset) = offset {
287- scanner. limit ( None , Some ( offset as i64 ) ) ?;
288- }
289-
286+ let scanner = self . lsm_scanner ( ) . await ?;
290287 let mut stream = scanner. try_into_stream ( ) . await ?;
291288 let mut results = Vec :: new ( ) ;
292289 while let Some ( batch) = stream. try_next ( ) . await ? {
293290 results. extend ( batch_to_records ( & batch) ?) ;
294291 }
292+
293+ if let Some ( offset) = offset {
294+ results = results. into_iter ( ) . skip ( offset) . collect ( ) ;
295+ }
296+ if let Some ( limit) = limit {
297+ results. truncate ( limit) ;
298+ }
295299 Ok ( results)
296300 }
297301
@@ -315,18 +319,51 @@ impl ContextStore {
315319 return Ok ( Vec :: new ( ) ) ;
316320 }
317321
318- let query_array = Float32Array :: from ( query. to_vec ( ) ) ;
322+ let mut results: Vec < SearchResult > = self
323+ . list ( None , None )
324+ . await ?
325+ . into_iter ( )
326+ . filter_map ( |record| {
327+ let distance = l2_distance ( query, record. embedding . as_ref ( ) ?) ;
328+ Some ( SearchResult { record, distance } )
329+ } )
330+ . collect ( ) ;
331+ results. sort_by ( |left, right| left. distance . total_cmp ( & right. distance ) ) ;
332+ results. truncate ( top_k) ;
333+ Ok ( results)
334+ }
319335
320- let mut scanner = self . dataset . scan ( ) ;
321- scanner. nearest ( "embedding" , & query_array, top_k) ?;
322- scanner. limit ( Some ( top_k as i64 ) , Some ( 0 ) ) ?;
336+ async fn lsm_scanner ( & self ) -> LanceResult < LsmScanner > {
337+ let object_store = self . dataset . object_store ( None ) . await ?;
338+ let branch_location = self . dataset . branch_location ( ) ;
339+ let shard_ids = self . dataset . list_mem_wal_latest_shard_ids ( ) . await ?;
340+
341+ let mut shard_snapshots = Vec :: with_capacity ( shard_ids. len ( ) ) ;
342+ for shard_id in shard_ids {
343+ let manifest_store = ShardManifestStore :: new (
344+ object_store. clone ( ) ,
345+ & branch_location. path ,
346+ shard_id,
347+ DEFAULT_MANIFEST_SCAN_BATCH_SIZE ,
348+ ) ;
349+ let Some ( manifest) = manifest_store. read_latest ( ) . await ? else {
350+ continue ;
351+ } ;
323352
324- let mut stream = scanner. try_into_stream ( ) . await ?;
325- let mut results = Vec :: new ( ) ;
326- while let Some ( batch) = stream. try_next ( ) . await ? {
327- results. extend ( batch_to_search_results ( & batch) ?) ;
353+ let mut snapshot = ShardSnapshot :: new ( shard_id)
354+ . with_spec_id ( manifest. shard_spec_id )
355+ . with_current_generation ( manifest. current_generation ) ;
356+ for flushed in manifest. flushed_generations {
357+ snapshot = snapshot. with_flushed_generation ( flushed. generation , flushed. path ) ;
358+ }
359+ shard_snapshots. push ( snapshot) ;
328360 }
329- Ok ( results)
361+
362+ Ok ( LsmScanner :: new (
363+ Arc :: new ( self . dataset . clone ( ) ) ,
364+ shard_snapshots,
365+ vec ! [ "id" . to_string( ) ] ,
366+ ) )
330367 }
331368
332369 /// Manually trigger compaction to merge small fragments.
@@ -837,34 +874,6 @@ impl Drop for ContextStore {
837874 }
838875}
839876
840- fn batch_to_search_results ( batch : & RecordBatch ) -> LanceResult < Vec < SearchResult > > {
841- let records = batch_to_records ( batch) ?;
842-
843- let distance_column = batch. column_by_name ( "_distance" ) . ok_or_else ( || {
844- LanceError :: from ( ArrowError :: InvalidArgumentError (
845- "search results missing _distance column" . to_string ( ) ,
846- ) )
847- } ) ?;
848- let distance_array = distance_column
849- . as_ref ( )
850- . as_any ( )
851- . downcast_ref :: < Float32Array > ( )
852- . ok_or_else ( || {
853- LanceError :: from ( ArrowError :: InvalidArgumentError (
854- "_distance column has unexpected data type" . to_string ( ) ,
855- ) )
856- } ) ?;
857-
858- Ok ( records
859- . into_iter ( )
860- . enumerate ( )
861- . map ( |( i, record) | SearchResult {
862- record,
863- distance : distance_array. value ( i) ,
864- } )
865- . collect ( ) )
866- }
867-
868877/// Convert a record batch to context records.
869878fn batch_to_records ( batch : & RecordBatch ) -> LanceResult < Vec < ContextRecord > > {
870879 let id_array = column_as :: < StringArray > ( batch, "id" ) ?;
@@ -1071,6 +1080,17 @@ fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec
10711080 Ok ( embedding)
10721081}
10731082
/// Euclidean (L2) distance between `left` and `right`.
///
/// Returns the square root of the summed squared element-wise differences. Two empty
/// slices are at distance zero.
///
/// When the slices differ in length the vectors are not comparable, so
/// `f32::INFINITY` is returned. Pairing only the common prefix would under-report
/// the distance (an empty slice would look like a perfect match for anything), and a
/// caller sorting ascending by distance would rank such entries first.
fn l2_distance(left: &[f32], right: &[f32]) -> f32 {
    // `zip` stops at the shorter input, so a dimension mismatch must be rejected up front.
    if left.len() != right.len() {
        return f32::INFINITY;
    }

    left.iter()
        .zip(right)
        .map(|(left, right)| {
            let delta = left - right;
            delta * delta
        })
        .sum::<f32>()
        .sqrt()
}
1093+
10741094fn column_as < ' a , A > ( batch : & ' a RecordBatch , name : & str ) -> LanceResult < & ' a A >
10751095where
10761096 A : Array + ' static ,
@@ -1144,16 +1164,15 @@ mod tests {
11441164 store. add ( & [ first. clone ( ) , second. clone ( ) ] ) . await . unwrap ( ) ;
11451165
11461166 let query = make_embedding ( 1.0 ) ;
1147- let _results = store. search ( & query, Some ( 2 ) ) . await . unwrap ( ) ;
1148-
1149- // TODO: MemWAL reads are not yet visible via standard scan.
1150- // assert_eq!(results.len(), 2);
1151- // assert_eq!(results[0].record.id, second.id);
1152- // assert!(
1153- // results[0].distance <= results[1].distance,
1154- // "results not ordered by distance: {:?}",
1155- // results
1156- // );
1167+ let results = store. search ( & query, Some ( 2 ) ) . await . unwrap ( ) ;
1168+
1169+ assert_eq ! ( results. len( ) , 2 ) ;
1170+ assert_eq ! ( results[ 0 ] . record. id, second. id) ;
1171+ assert ! (
1172+ results[ 0 ] . distance <= results[ 1 ] . distance,
1173+ "results not ordered by distance: {:?}" ,
1174+ results
1175+ ) ;
11571176 } ) ;
11581177 }
11591178
@@ -1231,13 +1250,12 @@ mod tests {
12311250 let store = ContextStore :: open ( & uri) . await . unwrap ( ) ;
12321251
12331252 // Verify we can list them back
1234- let _results = store. list ( None , None ) . await . unwrap ( ) ;
1235- // TODO: MemWAL reads are not yet visible via standard scan.
1236- // assert_eq!(results.len(), 2);
1253+ let results = store. list ( None , None ) . await . unwrap ( ) ;
1254+ assert_eq ! ( results. len( ) , 2 ) ;
12371255
1238- // let ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
1239- // assert!(ids.contains(&"r1".to_string()));
1240- // assert!(ids.contains(&"r2".to_string()));
1256+ let ids: Vec < String > = results. iter ( ) . map ( |r| r. id . clone ( ) ) . collect ( ) ;
1257+ assert ! ( ids. contains( & "r1" . to_string( ) ) ) ;
1258+ assert ! ( ids. contains( & "r2" . to_string( ) ) ) ;
12411259 } ) ;
12421260 }
12431261
0 commit comments