Skip to content

Commit 60fa28f

Browse files
GaureeshAnvekarGaureesh Anvekar
andauthored
Hybrid search support - addition of bm25 keyword search (#61)
Co-authored-by: Gaureesh Anvekar <[email protected]>
1 parent e8a5ba6 commit 60fa28f

File tree

12 files changed

+7972
-1756
lines changed

12 files changed

+7972
-1756
lines changed

bin/vectra.js

100644100755
File mode changed.

package-lock.json

Lines changed: 6001 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,23 @@
3535
"openai": "^3.2.1",
3636
"turndown": "^7.1.2",
3737
"uuid": "^9.0.0",
38+
"wink-nlp": "^2.3.2",
3839
"yargs": "^17.7.2"
3940
},
40-
"resolutions": {
41-
},
41+
"resolutions": {},
4242
"devDependencies": {
43-
"@types/node": "^14.14.31",
44-
"@types/mocha": "^8.2.0",
4543
"@types/assert": "^1.5.3",
44+
"@types/mocha": "^8.2.0",
45+
"@types/node": "^14.14.31",
4646
"@types/turndown": "^5.0.1",
4747
"@types/uuid": "9.0.1",
4848
"@types/yargs": "17.0.24",
4949
"mocha": "10.2.0",
5050
"nyc": "^15.1.0",
5151
"shx": "^0.3.2",
5252
"ts-mocha": "10.0.0",
53-
"typescript": "^4.2.3"
53+
"typescript": "^4.2.3",
54+
"wink-bm25-text-search": "^3.1.2"
5455
},
5556
"scripts": {
5657
"build": "tsc -b",

samples/wikipedia/.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
vectra.keys
2-
index
2+
index
3+
index-files
4+
index-wiki

src/LocalDocumentIndex.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ export interface DocumentQueryOptions {
3030
* Optional. Filter to apply to the document metadata.
3131
*/
3232
filter?: MetadataFilter;
33+
34+
/**
35+
* Optional. Turn on bm25 keyword search to perform hybrid search - semantic + keyword
36+
*/
37+
isBm25?: boolean;
38+
3339
}
3440

3541
/**
@@ -378,7 +384,7 @@ export class LocalDocumentIndex extends LocalIndex<DocumentChunkMetadata> {
378384
}
379385

380386
// Query index for chunks
381-
const results = await this.queryItems(embeddings.output![0], options.maxChunks!, options.filter);
387+
const results = await this.queryItems(embeddings.output![0], query, options.maxChunks!, options.filter, options.isBm25);
382388

383389
// Group chunks by document
384390
const documentChunks: { [documentId: string]: QueryResult<DocumentChunkMetadata>[]; } = {};

src/LocalDocumentResult.ts

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ export class LocalDocumentResult extends LocalDocument {
6666
startPos: startPos + offset,
6767
endPos: startPos + offset + chunkLength - 1,
6868
score: chunk.score,
69-
tokenCount: chunkLength
69+
tokenCount: chunkLength,
70+
isBm25: false
7071
});
7172
offset += chunkLength;
7273
}
@@ -103,7 +104,8 @@ export class LocalDocumentResult extends LocalDocument {
103104
return {
104105
text: text,
105106
tokenCount: section.tokenCount,
106-
score: section.score
107+
score: section.score,
108+
isBm25: false,
107109
};
108110
});
109111
}
@@ -127,7 +129,8 @@ export class LocalDocumentResult extends LocalDocument {
127129
return [{
128130
text,
129131
tokenCount: length,
130-
score: 1.0
132+
score: 1.0,
133+
isBm25: false,
131134
}];
132135
}
133136

@@ -148,7 +151,8 @@ export class LocalDocumentResult extends LocalDocument {
148151
startPos,
149152
endPos,
150153
score: chunk.score,
151-
tokenCount: this._tokenizer.encode(chunkText).length
154+
tokenCount: this._tokenizer.encode(chunkText).length,
155+
isBm25: Boolean(chunk.item.metadata.isBm25),
152156
};
153157
}).filter(chunk => chunk.tokenCount <= maxTokens).sort((a, b) => a.startPos - b.startPos);
154158

@@ -163,36 +167,63 @@ export class LocalDocumentResult extends LocalDocument {
163167
return [{
164168
text: this._tokenizer.decode(tokens.slice(0, maxTokens)),
165169
tokenCount: maxTokens,
166-
score: topChunk.score
170+
score: topChunk.score,
171+
isBm25: false,
167172
}];
168173
}
169174

170-
// Generate sections
175+
// Generate semantic sections
171176
const sections: Section[] = [];
172177
for (let i = 0; i < chunks.length; i++) {
173178
const chunk = chunks[i];
174179
let section = sections[sections.length - 1];
175-
if (!section || section.tokenCount + chunk.tokenCount > maxTokens) {
176-
section = {
177-
chunks: [],
178-
score: 0,
179-
tokenCount: 0
180-
};
181-
sections.push(section);
180+
if (!chunk.isBm25) {
181+
if (!section || section.tokenCount + chunk.tokenCount > maxTokens) {
182+
section = {
183+
chunks: [],
184+
score: 0,
185+
tokenCount: 0
186+
};
187+
sections.push(section);
188+
}
189+
section.chunks.push(chunk);
190+
section.score += chunk.score;
191+
section.tokenCount += chunk.tokenCount;
182192
}
183-
section.chunks.push(chunk);
184-
section.score += chunk.score;
185-
section.tokenCount += chunk.tokenCount;
186193
}
187194

195+
// Generate bm25 sections
196+
const bm25Sections: Section[] = [];
197+
for (let i = 0; i < chunks.length; i++) {
198+
const chunk = chunks[i];
199+
let section = bm25Sections[bm25Sections.length - 1];
200+
if (chunk.isBm25) {
201+
if (!section || section.tokenCount + chunk.tokenCount > maxTokens) {
202+
section = {
203+
chunks: [],
204+
score: 0,
205+
tokenCount: 0
206+
};
207+
bm25Sections.push(section);
208+
}
209+
section.chunks.push(chunk);
210+
section.score += chunk.score;
211+
section.tokenCount += chunk.tokenCount;
212+
}
213+
}
188214
// Normalize section scores
189215
sections.forEach(section => section.score /= section.chunks.length);
216+
bm25Sections.forEach(section => section.score /= section.chunks.length);
190217

191218
// Sort sections by score and limit to maxSections
192219
sections.sort((a, b) => b.score - a.score);
220+
bm25Sections.sort((a, b) => b.score - a.score);
193221
if (sections.length > maxSections) {
194222
sections.splice(maxSections, sections.length - maxSections);
195223
}
224+
if (bm25Sections.length > maxSections) {
225+
bm25Sections.splice(maxSections, bm25Sections.length - maxSections);
226+
}
196227

197228
// Combine adjacent chunks of text
198229
sections.forEach(section => {
@@ -216,7 +247,8 @@ export class LocalDocumentResult extends LocalDocument {
216247
startPos: -1,
217248
endPos: -1,
218249
score: 0,
219-
tokenCount: this._tokenizer.encode('\n\n...\n\n').length
250+
tokenCount: this._tokenizer.encode('\n\n...\n\n').length,
251+
isBm25: false,
220252
};
221253
sections.forEach(section => {
222254
// Insert connectors between chunks
@@ -242,7 +274,8 @@ export class LocalDocumentResult extends LocalDocument {
242274
startPos: sectionStart - beforeBudget,
243275
endPos: sectionStart - 1,
244276
score: 0,
245-
tokenCount: beforeBudget
277+
tokenCount: beforeBudget,
278+
isBm25: false,
246279
};
247280
section.chunks.unshift(chunk);
248281
section.tokenCount += chunk.tokenCount;
@@ -258,7 +291,8 @@ export class LocalDocumentResult extends LocalDocument {
258291
startPos: sectionEnd + 1,
259292
endPos: sectionEnd + afterBudget,
260293
score: 0,
261-
tokenCount: afterBudget
294+
tokenCount: afterBudget,
295+
isBm25: false,
262296
};
263297
section.chunks.push(chunk);
264298
section.tokenCount += chunk.tokenCount;
@@ -268,16 +302,29 @@ export class LocalDocumentResult extends LocalDocument {
268302
});
269303
}
270304

271-
// Return final rendered sections
272-
return sections.map(section => {
305+
const semanticDocTextSections = sections.map(section => {
306+
let text = '';
307+
section.chunks.forEach(chunk => text += chunk.text);
308+
return {
309+
text: text,
310+
tokenCount: section.tokenCount,
311+
score: section.score,
312+
isBm25: false,
313+
};
314+
});
315+
const bm25DocTextSections = bm25Sections.map(section => {
273316
let text = '';
274317
section.chunks.forEach(chunk => text += chunk.text);
275318
return {
276319
text: text,
277320
tokenCount: section.tokenCount,
278-
score: section.score
321+
score: section.score,
322+
isBm25: true,
279323
};
280324
});
325+
326+
// Return final rendered sections
327+
return [...semanticDocTextSections, ...bm25DocTextSections];
281328
}
282329

283330
private encodeBeforeText(text: string, budget: number): number[] {
@@ -300,6 +347,7 @@ interface SectionChunk {
300347
endPos: number;
301348
score: number;
302349
tokenCount: number;
350+
isBm25: boolean;
303351
}
304352

305353
interface Section {

src/LocalIndex.ts

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ import * as path from 'path';
33
import { v4 } from 'uuid';
44
import { ItemSelector } from './ItemSelector';
55
import { IndexItem, IndexStats, MetadataFilter, MetadataTypes, QueryResult } from './types';
6-
6+
import { LocalDocument } from './LocalDocument';
7+
import { LocalDocumentIndex } from './LocalDocumentIndex';
8+
import bm25 from 'wink-bm25-text-search';
9+
import winkNLP from 'wink-nlp';
10+
import model from 'wink-eng-lite-web-model';
711
export interface CreateIndexConfig {
812
version: number;
913
deleteIfExists?: boolean;
@@ -24,6 +28,8 @@ export class LocalIndex<TMetadata extends Record<string,MetadataTypes> = Record<
2428

2529
private _data?: IndexData;
2630
private _update?: IndexData;
31+
//member fields for BM25
32+
private _bm25Engine: any;
2733

2834
/**
2935
* Creates a new instance of LocalIndex.
@@ -247,7 +253,7 @@ export class LocalIndex<TMetadata extends Record<string,MetadataTypes> = Record<
247253
* @param filter Optional. Filter to apply.
248254
* @returns Similar items to the vector that matche the supplied filter.
249255
*/
250-
public async queryItems<TItemMetadata extends TMetadata = TMetadata>(vector: number[], topK: number, filter?: MetadataFilter): Promise<QueryResult<TItemMetadata>[]> {
256+
public async queryItems<TItemMetadata extends TMetadata = TMetadata>(vector: number[], query: string, topK: number, filter?: MetadataFilter, isBm25?: boolean): Promise<QueryResult<TItemMetadata>[]> {
251257
await this.loadIndexData();
252258

253259
// Filter items
@@ -285,6 +291,36 @@ export class LocalIndex<TMetadata extends Record<string,MetadataTypes> = Record<
285291
}
286292
}
287293

294+
//Peform bm25 search only if enabled. Avoid duplicate chunks, which are already selected during semantic search.
295+
if (isBm25) {
296+
const itemSet = new Set();
297+
for (const item of top) itemSet.add(item.item.id);
298+
299+
this.setupbm25();
300+
301+
let currDoc;
302+
let currDocTxt;
303+
for (let i = 0; i < items.length; i++) {
304+
if (!itemSet.has(items[i].id)) {
305+
const item = items[i];
306+
currDoc = new LocalDocument((this as unknown) as LocalDocumentIndex, item.metadata.documentId.toString(), '');
307+
currDocTxt = await currDoc.loadText();
308+
const startPos = item.metadata.startPos;
309+
const endPos = item.metadata.endPos;
310+
const chunkText = currDocTxt.substring(Number(startPos), Number(endPos) + 1);
311+
this._bm25Engine.addDoc({body: chunkText}, i);
312+
}
313+
}
314+
this._bm25Engine.consolidate();
315+
var results = await this.bm25Search(query, items, topK);
316+
results.forEach((res: any) => {
317+
top.push({
318+
item: Object.assign({}, {...items[res[0]], metadata: {...items[res[0]].metadata, isBm25: true}}) as any,
319+
score: res[1]
320+
});
321+
});
322+
323+
}
288324
return top;
289325
}
290326

@@ -385,6 +421,37 @@ export class LocalIndex<TMetadata extends Record<string,MetadataTypes> = Record<
385421
return newItem;
386422
}
387423
}
424+
425+
private async setupbm25(): Promise<any> {
426+
this._bm25Engine = bm25();
427+
const nlp = winkNLP( model );
428+
const its = nlp.its;
429+
430+
const prepTask = function ( text: string ) {
431+
const tokens: any[] = [];
432+
nlp.readDoc(text)
433+
.tokens()
434+
// Use only words ignoring punctuations etc and from them remove stop words
435+
.filter( (t: any) => ( t.out(its.type) === 'word' && !t.out(its.stopWordFlag) ) )
436+
// Handle negation and extract stem of the word
437+
.each( (t: any) => tokens.push( (t.out(its.negationFlag)) ? '!' + t.out(its.stem) : t.out(its.stem) ) );
438+
439+
return tokens;
440+
};
441+
442+
this._bm25Engine.defineConfig( { fldWeights: { body: 1 } } );
443+
// Step II: Define PrepTasks pipe.
444+
this._bm25Engine.definePrepTasks( [ prepTask ] );
445+
}
446+
447+
private async bm25Search(searchQuery: string, items: any, topK: number): Promise<any> {
448+
var query = searchQuery;
449+
// `results` is an array of [ doc-id, score ], sorted by score
450+
var results = this._bm25Engine.search( query );
451+
452+
return results.slice(0, topK);
453+
}
454+
388455
}
389456

390457
interface IndexData {

src/internals/Colorize.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ export class Colorize {
1616
}
1717
}
1818

19-
public static output(output: object | string, quote: string = '', units: string = ''): string {
19+
public static output(output: object | string, isBm25: boolean = false, quote: string = '', units: string = ''): string {
2020
if (typeof output === 'string') {
21-
return `\x1b[32m${quote}${output}${quote}\x1b[0m`;
21+
return isBm25 ? `\x1b[34m${quote}${output}${quote}\x1b[0m` : `\x1b[32m${quote}${output}${quote}\x1b[0m`;
2222
} else if (typeof output === 'object' && output !== null) {
2323
return colorizer(output, {
2424
pretty: true,
@@ -54,7 +54,7 @@ export class Colorize {
5454
}
5555

5656
public static value(field: string, value: any, units: string = ''): string {
57-
return `${field}: ${Colorize.output(value, '"', units)}`;
57+
return `${field}: ${Colorize.output(value, false, '"', units)}`;
5858
}
5959

6060
public static warning(warning: string): string {
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
declare module 'wink-bm25-text-search' {
2+
const bm25: any;
3+
export default bm25;
4+
}

src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,5 @@ export interface DocumentTextSection {
172172
text: string;
173173
tokenCount: number;
174174
score: number;
175+
isBm25: boolean;
175176
}

0 commit comments

Comments
 (0)