Skip to content

Commit 6f84b94

Browse files
committed
Big perf win by mini-batching during search.
1 parent f26c0bd commit 6f84b94

File tree

5 files changed

+125
-77
lines changed

5 files changed

+125
-77
lines changed

examples/dotnet/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This is still a work in progress.

examples/dotnet/dotnet.csproj

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
<ItemGroup>
5959
<None Include="App.config" />
6060
</ItemGroup>
61-
<ItemGroup />
61+
<ItemGroup>
62+
<None Include="README.md" />
63+
</ItemGroup>
6264
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
6365
</Project>

examples/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
print(numpy.frombuffer(query[1], dtype=numpy.float32))
6262

63-
results = db.search(query[1], topk=topk)
63+
results = db.search(query[1], topk=topk, norm=True)
6464

6565
for id, score in results:
6666
print("id:", id, "score:", score)

src/embeddings.c

Lines changed: 102 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@ static inline uint32_t powoftwo(uint32_t x)
4949
return x;
5050
}
5151

52-
EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWORD access, DWORD dwCreationDisposition, uint32_t dwBlobSize)
52+
EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(
53+
const wchar_t* pwszpath, DWORD dwAccess, DWORD dwCreationDisposition, uint32_t dwBlobSize)
5354
{
5455
Embeddings* db = malloc(sizeof(Embeddings));
55-
_dbglog(">> fileopen(path='%ls' blob=%u access=0x%08X, disposition=0x%08X);\n", pwszpath, dwBlobSize, access, dwCreationDisposition);
56+
_dbglog(">> fileopen(path='%ls' blob=%u access=0x%08X, disposition=0x%08X);\n", pwszpath, dwBlobSize, dwAccess, dwCreationDisposition);
5657
memset(db, 0, sizeof(*db));
5758
DWORD flags;
5859
assert(PATH >= MAX_PATH);
@@ -66,7 +67,7 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
6667
}
6768
flags = FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE | FILE_FLAG_SEQUENTIAL_SCAN;
6869
dwCreationDisposition = CREATE_ALWAYS;
69-
access = FILE_READ_DATA | FILE_APPEND_DATA | FILE_WRITE_DATA;
70+
dwAccess = FILE_READ_DATA | FILE_APPEND_DATA | FILE_WRITE_DATA;
7071
pwszpath = db->wszPath;
7172
} else {
7273
if (!GetFullPathNameW(pwszpath, PATH, db->wszPath, NULL)) {
@@ -80,14 +81,14 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
8081
_dbglog("path='%ls' blob=%u access=0x%08X, disposition=0x%08X\n",
8182
pwszpath,
8283
dwBlobSize,
83-
access,
84+
dwAccess,
8485
dwCreationDisposition);
8586
GetSystemInfo(&db->os);
8687
db->os.dwPageSize = db->os.dwPageSize ? db->os.dwPageSize : 4096;
8788
db->os.dwAllocationGranularity = db->os.dwAllocationGranularity ? db->os.dwAllocationGranularity : 65536;
8889
if (dwBlobSize > MAXBLOB) {
8990
free(db);
90-
fprintf(stderr, "Maximum blob Size is %lu.\n", MAXBLOB);
91+
fprintf(stderr, "The specified blob size %lu is invalid. Maximum blob size is %lu.\n", dwBlobSize, MAXBLOB);
9192
return NULL;
9293
}
9394
if (dwBlobSize > 0) {
@@ -124,10 +125,10 @@ EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* pwszpath, DWO
124125
fprintf(stderr, "Invalid header size.\n");
125126
return NULL;
126127
}
127-
db->access = access;
128+
db->access = dwAccess;
128129
db->dwCreationDisposition = dwCreationDisposition;
129130
db->hWrite = CreateFileW(pwszpath,
130-
access,
131+
dwAccess,
131132
FILE_SHARE_READ | FILE_SHARE_WRITE,
132133
NULL,
133134
dwCreationDisposition,
@@ -392,7 +393,41 @@ static inline float _normf(const float* a, DWORD n) {
392393

393394
const float EPSILON = 1e-6f;
394395

395-
EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
396+
void cosine(const float* query, uint32_t len,
397+
float qnorm,
398+
uint8_t* buff, float min,
399+
size_t* num,
400+
uint32_t topk,
401+
Score* heap,
402+
BOOL bNorm)
403+
{
404+
const uiid* id = (const uiid*)buff;
405+
const float* blob = (const float*)(buff + sizeof(uiid));
406+
float norm = bNorm
407+
? _normf(blob, len)
408+
: 1;
409+
if (norm > EPSILON) {
410+
double dot = _dotf(blob, query, len);
411+
float score = (float)(dot / ((double)qnorm * (double)norm));
412+
if (score >= min) {
413+
if (*num < topk) {
414+
// start accumulating until we fill the heap
415+
_uiidcpy(&heap[*num].id, id);
416+
heap[*num].score = score;
417+
(*num) = (*num) + 1;
418+
qsort(heap, *num, sizeof(Score), _score);
419+
}
420+
else if (score > heap[topk - 1].score) {
421+
// evict the lowest score
422+
_uiidcpy(&heap[topk - 1].id, id);
423+
heap[topk - 1].score = score;
424+
qsort(heap, *num, sizeof(Score), _score);
425+
}
426+
}
427+
}
428+
}
429+
430+
EMBEDDINGS_API int32_t EMBEDDINGS_CALL filesearch(
396431
Embeddings* db,
397432
const float* query, uint32_t len,
398433
uint32_t topk,
@@ -409,9 +444,17 @@ EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
409444
fprintf(stderr, "The specified query pointer is NULL.\n");
410445
return -1;
411446
}
447+
float qnorm = bNorm
448+
? _normf(query, len)
449+
: 1;
450+
_dbglog("qnorm = %f;\n", qnorm);
451+
if (qnorm < EPSILON) {
452+
fprintf(stderr, "Query vector norm too small (%.8g).\n", qnorm);
453+
return -1;
454+
}
412455
if (len == 0) {
413456
fprintf(stderr, "The specified query length is zero.\n");
414-
return FALSE;
457+
return -1;
415458
}
416459
if (!db->hWrite || db->hWrite == INVALID_HANDLE_VALUE) {
417460
fprintf(stderr, "The specified database is closed or invalid.\n");
@@ -439,93 +482,83 @@ EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
439482
fprintf(stderr, "Failed to duplicate file handle for search (system error %lu).\n", GetLastError());
440483
return -1;
441484
}
442-
size_t cc = __alignup(sizeof(uiid) + db->header.blobSize, db->header.alignment);
443-
uint8_t* buff = (uint8_t*)_aligned_malloc(cc, db->header.alignment);
444-
if (!buff) {
445-
fprintf(stderr, "Memory allocation failed while preparing the read buffer.\n");
446-
CloseHandle(hRead);
447-
return -1;
448-
}
449485
LARGE_INTEGER offset = { MAXHEAD };
450486
if (!SetFilePointerEx(hRead, offset, NULL, FILE_BEGIN)) {
451487
fprintf(stderr, "Failed to seek to the first record (system error %lu).\n", GetLastError());
452-
_aligned_free(buff);
453488
CloseHandle(hRead);
454489
return -1;
455490
}
456491
Score* heap = (Score*)calloc(topk, sizeof(Score));
457492
if (!heap) {
458493
fprintf(stderr, "Memory allocation failed while preparing the top-k heap.\n");
459-
_aligned_free(buff);
460494
CloseHandle(hRead);
461495
return -1;
462496
}
463-
float qnorm = bNorm
464-
? _normf(query, len)
465-
: 1;
466-
_dbglog("qnorm = %f;\n", qnorm);
467-
if (qnorm < EPSILON) {
468-
fprintf(stderr, "Query vector norm too small (%.8g).\n", qnorm);
497+
uint32_t stride = __alignup(sizeof(uiid) + db->header.blobSize, db->header.alignment);
498+
uint8_t* carry = (uint8_t*)malloc(stride);
499+
if (!carry) {
500+
fprintf(stderr, "Memory allocation failed while preparing the read buffers.\n");
469501
free(heap);
470-
_aligned_free(buff);
471502
CloseHandle(hRead);
472503
return -1;
473504
}
474-
DWORD count = 0;
505+
const uint32_t MAX = 1024;
506+
uint8_t* big = (uint8_t*)_aligned_malloc((size_t)(MAX * stride), db->header.alignment);
507+
if (!big) {
508+
fprintf(stderr, "Memory allocation failed while preparing the read buffers.\n");
509+
free(heap);
510+
free(carry);
511+
CloseHandle(hRead);
512+
return -1;
513+
}
514+
size_t num = 0; size_t leftoverBytes = 0;
475515
for (;;) {
476-
DWORD bytesRead = 0;
477-
BOOL ok = ReadFile(hRead, buff, (DWORD)cc, &bytesRead, NULL);
478-
if (!ok) {
479-
// _dbglog("EOF\n");
480-
break;
516+
if (leftoverBytes) {
517+
memcpy(big, carry, leftoverBytes);
481518
}
482-
if (bytesRead < cc) {
483-
// _dbglog("partial read expected: %lu, read: %lu\n", (unsigned long)cc, (unsigned long)bytesRead);
484-
break;
485-
}
486-
const uiid* id = (const uiid*)buff;
487-
const float* blob = (const float*)(buff + sizeof(uiid));
488-
float norm = bNorm
489-
? _normf(blob, len)
490-
: 1;
491-
// _dbglog("norm = %f;\n", norm);
492-
if (norm < EPSILON)
493-
continue;
494-
double dot = _dotf(blob, query, len);
495-
// _dbglog("dot = %f;\n", dot);
496-
float score = (float)(dot / ((double)qnorm * (double)norm));
497-
// _dbglog("score = %f;\n", score);
498-
if (score < min) {
499-
continue; /* prune below threshold */
500-
}
501-
if (count < topk) {
502-
_uiidcpy(&heap[count].id, id);
503-
heap[count].score = score;
504-
++count;
505-
if (count == topk) {
506-
qsort(heap, count, sizeof(Score), _score);
507-
}
519+
DWORD bytesRead = 0;
520+
BOOL ok = ReadFile(hRead, big + leftoverBytes, (DWORD)(MAX * stride - leftoverBytes), &bytesRead, NULL);
521+
if (!ok || bytesRead == 0) {
522+
break; // EOF
508523
}
509-
else if (score > heap[topk - 1].score) {
510-
_uiidcpy(&heap[topk - 1].id, id);
511-
heap[topk - 1].score = score;
512-
qsort(heap, count, sizeof(Score), _score);
524+
size_t total = leftoverBytes + bytesRead;
525+
size_t offset = 0;
526+
while (offset + stride <= total) {
527+
cosine(
528+
query,
529+
len,
530+
qnorm,
531+
(uint8_t*)big + offset,
532+
min,
533+
&num,
534+
topk,
535+
heap,
536+
bNorm);
537+
offset += stride;
538+
}
539+
leftoverBytes = total - offset;
540+
if (leftoverBytes) {
541+
memcpy(carry, big + offset, leftoverBytes);
513542
}
543+
514544
}
515-
assert(count <= topk);
516-
/* Return in descending order */
545+
assert(num <= topk);
517546
memset(scores, 0, topk * sizeof(Score));
518-
for (DWORD i = 0; i < count; ++i) {
547+
for (DWORD i = 0; i < num; ++i) {
519548
_uiidcpy(&scores[i].id, &heap[i].id);
520549
scores[i].score = heap[i].score;
521550
}
551+
_aligned_free(big);
552+
free(carry);
522553
free(heap);
523-
_aligned_free(buff);
524554
CloseHandle(hRead);
525-
_dbglog("filesearch() = %d;\n", count);
526-
return count;
555+
_dbglog("filesearch() = %u;\n", (unsigned int)num);
556+
return num;
527557
}
528558

559+
560+
/* Cursor API is designed for offline processing. It should not be used on a live index for upserting. */
561+
529562
EMBEDDINGS_API void EMBEDDINGS_CALL cursorclose(Cursor* cur)
530563
{
531564
_dbglog("Cursor_close();\n");
@@ -1227,9 +1260,7 @@ static PyObject* PyEmbeddings_Search(PyEmbeddingsObject* self, PyObject* args, P
12271260
PyObject* len_obj = NULL;
12281261
DWORD len = 0, topk = 0;
12291262
float threshold = 0.0f;
1230-
/* 0 = use as-is, 1 = L2-normalize */
1231-
int norm = 0;
1232-
1263+
int norm = 1; // Normalize by default
12331264
if (!PyArg_ParseTupleAndKeywords(args, kwds, "y*|OIfp:search", kwlist,
12341265
&buf, &len_obj, &topk, &threshold, &norm)) {
12351266
return NULL;
@@ -1289,7 +1320,7 @@ static PyObject* PyEmbeddings_Search(PyEmbeddingsObject* self, PyObject* args, P
12891320
return NULL;
12901321
}
12911322

1292-
int32_t count = cosinesearch(self->db,
1323+
int32_t count = filesearch(self->db,
12931324
(const float*)buf.buf,
12941325
len,
12951326
topk,

src/embeddings.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
extern "C" {
2626
#endif
2727

28+
typedef enum DTYPE {
29+
DTYPE_FLOAT32 = 0, /* 4 bytes per component */
30+
DTYPE_FLOAT16 = 1, /* 2 bytes per component (IEEE 754 half) */
31+
DTYPE_INT8 = 2 /* per-vector: [float scale][dim x int8_t] */
32+
} DTYPE;
33+
2834
#pragma pack(push, 1)
2935
typedef struct uiid {
3036
unsigned char bytes[16];
@@ -42,7 +48,7 @@ extern "C" {
4248
} FileHeader;
4349
#pragma pack(pop)
4450

45-
#define PATH 1024
51+
#define PATH 1024
4652

4753
#pragma pack(push, 1)
4854
typedef struct Embeddings {
@@ -55,8 +61,12 @@ extern "C" {
5561
} Embeddings;
5662
#pragma pack(pop)
5763

58-
EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(const wchar_t* szPath, DWORD access, DWORD dwCreationDisposition, uint32_t dwBlobSize);
59-
EMBEDDINGS_API BOOL EMBEDDINGS_CALL fileappend(Embeddings* db, uiid id, const void* blob, DWORD blobSize, BOOL bFlush);
64+
EMBEDDINGS_API Embeddings* EMBEDDINGS_CALL fileopen(
65+
const wchar_t* szPath, DWORD dwAccess, DWORD dwCreationDisposition,
66+
uint32_t dwBlobSize);
67+
EMBEDDINGS_API BOOL EMBEDDINGS_CALL fileappend(
68+
Embeddings* db, uiid id,
69+
const void* blob, DWORD blobSize, BOOL bFlush);
6070
EMBEDDINGS_API BOOL EMBEDDINGS_CALL fileflush(Embeddings* db);
6171
EMBEDDINGS_API void EMBEDDINGS_CALL fileclose(Embeddings* db);
6272
EMBEDDINGS_API uint32_t EMBEDDINGS_CALL fileversion(Embeddings* db);
@@ -68,14 +78,16 @@ extern "C" {
6878
} Score;
6979
#pragma pack(pop)
7080

71-
EMBEDDINGS_API int32_t EMBEDDINGS_CALL cosinesearch(
81+
EMBEDDINGS_API int32_t EMBEDDINGS_CALL filesearch(
7282
Embeddings* db,
7383
const float* query, uint32_t len,
7484
uint32_t topk,
7585
Score* scores,
7686
float min,
7787
BOOL bNorm);
7888

89+
void cosine(const float* query, uint32_t len, float qnorm, uint8_t* buff, float min, size_t* pnum, uint32_t topk, Score* heap, BOOL bNorm);
90+
7991
#pragma pack(push, 1)
8092
typedef struct Cursor {
8193
HANDLE hReadWrite;
@@ -89,6 +101,8 @@ extern "C" {
89101
} Cursor;
90102
#pragma pack(pop)
91103

104+
/* Cursor API is designed for offline processing. It should not be used on a live index for upserting. */
105+
92106
EMBEDDINGS_API Cursor* EMBEDDINGS_CALL cursoropen(Embeddings* db, BOOL bReadOnly);
93107
EMBEDDINGS_API void EMBEDDINGS_CALL cursorclose(Cursor* cur);
94108
EMBEDDINGS_API BOOL EMBEDDINGS_CALL cursorreset(Cursor* cur);

0 commit comments

Comments
 (0)