@@ -49,10 +49,11 @@ static inline uint32_t powoftwo(uint32_t x)
4949 return x ;
5050}
5151
52- EMBEDDINGS_API Embeddings * EMBEDDINGS_CALL fileopen (const wchar_t * pwszpath , DWORD access , DWORD dwCreationDisposition , uint32_t dwBlobSize )
52+ EMBEDDINGS_API Embeddings * EMBEDDINGS_CALL fileopen (
53+ const wchar_t * pwszpath , DWORD dwAccess , DWORD dwCreationDisposition , uint32_t dwBlobSize )
5354{
5455 Embeddings * db = malloc (sizeof (Embeddings ));
55- _dbglog (">> fileopen(path='%ls' blob=%u access=0x%08X, disposition=0x%08X);\n" , pwszpath , dwBlobSize , access , dwCreationDisposition );
56+ _dbglog (">> fileopen(path='%ls' blob=%u access=0x%08X, disposition=0x%08X);\n" , pwszpath , dwBlobSize , dwAccess , dwCreationDisposition );
5657 memset (db , 0 , sizeof (* db ));
5758 DWORD flags ;
5859 assert (PATH >= MAX_PATH );
@@ -66,7 +67,7 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
6667 }
6768 flags = FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE | FILE_FLAG_SEQUENTIAL_SCAN ;
6869 dwCreationDisposition = CREATE_ALWAYS ;
69- access = FILE_READ_DATA | FILE_APPEND_DATA | FILE_WRITE_DATA ;
70+ dwAccess = FILE_READ_DATA | FILE_APPEND_DATA | FILE_WRITE_DATA ;
7071 pwszpath = db -> wszPath ;
7172 } else {
7273 if (!GetFullPathNameW (pwszpath , PATH , db -> wszPath , NULL )) {
@@ -80,14 +81,14 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
8081 _dbglog ("path='%ls' blob=%u access=0x%08X, disposition=0x%08X\n" ,
8182 pwszpath ,
8283 dwBlobSize ,
83- access ,
84+ dwAccess ,
8485 dwCreationDisposition );
8586 GetSystemInfo (& db -> os );
8687 db -> os .dwPageSize = db -> os .dwPageSize ? db -> os .dwPageSize : 4096 ;
8788 db -> os .dwAllocationGranularity = db -> os .dwAllocationGranularity ? db -> os .dwAllocationGranularity : 65536 ;
8889 if (dwBlobSize > MAXBLOB ) {
8990 free (db );
90- fprintf (stderr , "Maximum blob Size is %lu.\n" , MAXBLOB );
91+ fprintf (stderr , "The specified blob size %lu is invalid. Maximum blob size is %lu.\n" , dwBlobSize , MAXBLOB );
9192 return NULL ;
9293 }
9394 if (dwBlobSize > 0 ) {
@@ -124,10 +125,10 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
124125 fprintf (stderr , "Invalid header size.\n" );
125126 return NULL ;
126127 }
127- db -> access = access ;
128+ db -> access = dwAccess ;
128129 db -> dwCreationDisposition = dwCreationDisposition ;
129130 db -> hWrite = CreateFileW (pwszpath ,
130- access ,
131+ dwAccess ,
131132 FILE_SHARE_READ | FILE_SHARE_WRITE ,
132133 NULL ,
133134 dwCreationDisposition ,
@@ -392,7 +393,41 @@ static inline float _normf(const float* a, DWORD n) {
392393
393394const float EPSILON = 1e-6f ;
394395
395- EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch (
396+ void cosine (const float * query , uint32_t len ,
397+ float qnorm ,
398+ uint8_t * buff , float min ,
399+ size_t * num ,
400+ uint32_t topk ,
401+ Score * heap ,
402+ BOOL bNorm )
403+ {
404+ const uiid * id = (const uiid * )buff ;
405+ const float * blob = (const float * )(buff + sizeof (uiid ));
406+ float norm = bNorm
407+ ? _normf (blob , len )
408+ : 1 ;
409+ if (norm > EPSILON ) {
410+ double dot = _dotf (blob , query , len );
411+ float score = (float )(dot / ((double )qnorm * (double )norm ));
412+ if (score >= min ) {
413+ if (* num < topk ) {
414+ // start accumulating until we fill the heap
415+ _uiidcpy (& heap [* num ].id , id );
416+ heap [* num ].score = score ;
417+ (* num ) = (* num ) + 1 ;
418+ qsort (heap , * num , sizeof (Score ), _score );
419+ }
420+ else if (score > heap [topk - 1 ].score ) {
421+ // evict the lowest score
422+ _uiidcpy (& heap [topk - 1 ].id , id );
423+ heap [topk - 1 ].score = score ;
424+ qsort (heap , * num , sizeof (Score ), _score );
425+ }
426+ }
427+ }
428+ }
429+
430+ EMBEDDINGS_API int32_t EMBEDDINGS_CALL filesearch (
396431 Embeddings * db ,
397432 const float * query , uint32_t len ,
398433 uint32_t topk ,
@@ -409,9 +444,17 @@ EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
409444 fprintf (stderr , "The specified query pointer is NULL.\n" );
410445 return -1 ;
411446 }
447+ float qnorm = bNorm
448+ ? _normf (query , len )
449+ : 1 ;
450+ _dbglog ("qnorm = %f;\n" , qnorm );
451+ if (qnorm < EPSILON ) {
452+ fprintf (stderr , "Query vector norm too small (%.8g).\n" , qnorm );
453+ return -1 ;
454+ }
412455 if (len == 0 ) {
413456 fprintf (stderr , "The specified query length is zero.\n" );
414- return FALSE ;
457+ return -1 ;
415458 }
416459 if (!db -> hWrite || db -> hWrite == INVALID_HANDLE_VALUE ) {
417460 fprintf (stderr , "The specified database is closed or invalid.\n" );
@@ -439,93 +482,83 @@ EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
439482 fprintf (stderr , "Failed to duplicate file handle for search (system error %lu).\n" , GetLastError ());
440483 return -1 ;
441484 }
442- size_t cc = __alignup (sizeof (uiid ) + db -> header .blobSize , db -> header .alignment );
443- uint8_t * buff = (uint8_t * )_aligned_malloc (cc , db -> header .alignment );
444- if (!buff ) {
445- fprintf (stderr , "Memory allocation failed while preparing the read buffer.\n" );
446- CloseHandle (hRead );
447- return -1 ;
448- }
449485 LARGE_INTEGER offset = { MAXHEAD };
450486 if (!SetFilePointerEx (hRead , offset , NULL , FILE_BEGIN )) {
451487 fprintf (stderr , "Failed to seek to the first record (system error %lu).\n" , GetLastError ());
452- _aligned_free (buff );
453488 CloseHandle (hRead );
454489 return -1 ;
455490 }
456491 Score * heap = (Score * )calloc (topk , sizeof (Score ));
457492 if (!heap ) {
458493 fprintf (stderr , "Memory allocation failed while preparing the top-k heap.\n" );
459- _aligned_free (buff );
460494 CloseHandle (hRead );
461495 return -1 ;
462496 }
463- float qnorm = bNorm
464- ? _normf (query , len )
465- : 1 ;
466- _dbglog ("qnorm = %f;\n" , qnorm );
467- if (qnorm < EPSILON ) {
468- fprintf (stderr , "Query vector norm too small (%.8g).\n" , qnorm );
497+ uint32_t stride = __alignup (sizeof (uiid ) + db -> header .blobSize , db -> header .alignment );
498+ uint8_t * carry = (uint8_t * )malloc (stride );
499+ if (!carry ) {
500+ fprintf (stderr , "Memory allocation failed while preparing the read buffers.\n" );
469501 free (heap );
470- _aligned_free (buff );
471502 CloseHandle (hRead );
472503 return -1 ;
473504 }
474- DWORD count = 0 ;
505+ const uint32_t MAX = 1024 ;
506+ uint8_t * big = (uint8_t * )_aligned_malloc ((size_t )(MAX * stride ), db -> header .alignment );
507+ if (!big ) {
508+ fprintf (stderr , "Memory allocation failed while preparing the read buffers.\n" );
509+ free (heap );
510+ free (carry );
511+ CloseHandle (hRead );
512+ return -1 ;
513+ }
514+ size_t num = 0 ; size_t leftoverBytes = 0 ;
475515 for (;;) {
476- DWORD bytesRead = 0 ;
477- BOOL ok = ReadFile (hRead , buff , (DWORD )cc , & bytesRead , NULL );
478- if (!ok ) {
479- // _dbglog("EOF\n");
480- break ;
516+ if (leftoverBytes ) {
517+ memcpy (big , carry , leftoverBytes );
481518 }
482- if (bytesRead < cc ) {
483- // _dbglog("partial read expected: %lu, read: %lu\n", (unsigned long)cc, (unsigned long)bytesRead);
484- break ;
485- }
486- const uiid * id = (const uiid * )buff ;
487- const float * blob = (const float * )(buff + sizeof (uiid ));
488- float norm = bNorm
489- ? _normf (blob , len )
490- : 1 ;
491- // _dbglog("norm = %f;\n", norm);
492- if (norm < EPSILON )
493- continue ;
494- double dot = _dotf (blob , query , len );
495- // _dbglog("dot = %f;\n", dot);
496- float score = (float )(dot / ((double )qnorm * (double )norm ));
497- // _dbglog("score = %f;\n", score);
498- if (score < min ) {
499- continue ; /* prune below threshold */
500- }
501- if (count < topk ) {
502- _uiidcpy (& heap [count ].id , id );
503- heap [count ].score = score ;
504- ++ count ;
505- if (count == topk ) {
506- qsort (heap , count , sizeof (Score ), _score );
507- }
519+ DWORD bytesRead = 0 ;
520+ BOOL ok = ReadFile (hRead , big + leftoverBytes , (DWORD )(MAX * stride - leftoverBytes ), & bytesRead , NULL );
521+ if (!ok || bytesRead == 0 ) {
522+ break ; // EOF
508523 }
509- else if (score > heap [topk - 1 ].score ) {
510- _uiidcpy (& heap [topk - 1 ].id , id );
511- heap [topk - 1 ].score = score ;
512- qsort (heap , count , sizeof (Score ), _score );
524+ size_t total = leftoverBytes + bytesRead ;
525+ size_t offset = 0 ;
526+ while (offset + stride <= total ) {
527+ cosine (
528+ query ,
529+ len ,
530+ qnorm ,
531+ (uint8_t * )big + offset ,
532+ min ,
533+ & num ,
534+ topk ,
535+ heap ,
536+ bNorm );
537+ offset += stride ;
538+ }
539+ leftoverBytes = total - offset ;
540+ if (leftoverBytes ) {
541+ memcpy (carry , big + offset , leftoverBytes );
513542 }
543+
514544 }
515- assert (count <= topk );
516- /* Return in descending order */
545+ assert (num <= topk );
517546 memset (scores , 0 , topk * sizeof (Score ));
518- for (DWORD i = 0 ; i < count ; ++ i ) {
547+ for (DWORD i = 0 ; i < num ; ++ i ) {
519548 _uiidcpy (& scores [i ].id , & heap [i ].id );
520549 scores [i ].score = heap [i ].score ;
521550 }
551+ _aligned_free (big );
552+ free (carry );
522553 free (heap );
523- _aligned_free (buff );
524554 CloseHandle (hRead );
525- _dbglog ("filesearch() = %d ;\n" , count );
526- return count ;
555+ _dbglog ("filesearch() = %u ;\n" , ( unsigned int ) num );
556+ return num ;
527557}
528558
559+
560+ /* Cursor API is desined for offline processing. It should not be used on a live index for upserting. */
561+
529562EMBEDDINGS_API void EMBEDDINGS_CALL cursorclose (Cursor * cur )
530563{
531564 _dbglog ("Cursor_close();\n" );
@@ -1227,9 +1260,7 @@ static PyObject* PyEmbeddings_Search(PyEmbeddingsObject* self, PyObject* args, P
12271260 PyObject * len_obj = NULL ;
12281261 DWORD len = 0 , topk = 0 ;
12291262 float threshold = 0.0f ;
1230- /* 0 = use as-is, 1 = L2-normalize */
1231- int norm = 0 ;
1232-
1263+ int norm = 1 ; // Normalize by default
12331264 if (!PyArg_ParseTupleAndKeywords (args , kwds , "y*|OIfp:search" , kwlist ,
12341265 & buf , & len_obj , & topk , & threshold , & norm )) {
12351266 return NULL ;
@@ -1289,7 +1320,7 @@ static PyObject* PyEmbeddings_Search(PyEmbeddingsObject* self, PyObject* args, P
12891320 return NULL ;
12901321 }
12911322
1292- int32_t count = cosinesearch (self -> db ,
1323+ int32_t count = filesearch (self -> db ,
12931324 (const float * )buf .buf ,
12941325 len ,
12951326 topk ,
0 commit comments