 from pathlib import Path
 
 import faiss
+import numpy as np
 from datasets import Dataset
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from sentence_transformers import SentenceTransformer
 
 from delphi.config import CacheConfig
 from delphi.logger import logger
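 
+# Helpers to build, save, and query a FAISS semantic index over a dataset's
+# "text" column, embedded with sentence-transformers.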
 
 
-def get_neighbors_by_id(index: faiss.IndexIDMap, vector_id: int, k: int = 10):
-    # First reconstruct the vector for the given ID
-    vector = index.reconstruct(vector_id)
-
-    # Reshape to match FAISS expectations (needs a 2D array)
-    vector = vector.reshape(1, -1)
-
-    # Search for nearest neighbors (k + 1 since it will find itself)
-    distances, neighbor_ids = index.search(vector, k + 1)
-
-    # Remove the first result (which will be the query vector itself)
-    return distances[0][1:], neighbor_ids[0][1:]
+def get_neighbors(model, index, query: str, k: int = 1000):
+    """Return the k nearest neighbors of a text query as (distances, ids)."""
+    # Embed the query with the same model that was used to build the index.
+    q_embedding = model.encode([query])
+    # index.search returns a tuple of (L2 distances, match indices), each of
+    # shape (n_queries, k), so a single query's results live in row 0. A text
+    # query is not itself stored in the index, so there is no self-match to
+    # drop.
+    distances, neighbor_ids = index.search(q_embedding, k=k)
+    return distances[0], neighbor_ids[0]
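 
+# Example (sketch; `dataset` and `cfg` are assumed to exist, with `index`
+# built by build_semantic_index over dataset["text"]):
+#   model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
+#   distances, ids = get_neighbors(model, index, "solar power", k=5)
+#   matches = [dataset["text"][int(j)] for j in ids]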
2724
2825
2926def get_index_path (base_path : Path , cfg : CacheConfig ):
@@ -46,124 +43,40 @@ def save_index(index: faiss.IndexFlatL2, base_path: Path, cfg: CacheConfig):
     json.dump(
         {
             "index_path": str(index_path),
-            "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+            "embedding_model": cfg.faiss_embedding_model,
         },
         f,
     )
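 
+# A minimal sketch of the matching load path (assumed; it is not shown in
+# this hunk), where `meta_path` is a hypothetical path to the JSON written
+# above:
+#   with open(meta_path) as f:
+#       meta = json.load(f)
+#   if meta["embedding_model"] == cfg.faiss_embedding_model:
+#       index = faiss.read_index(meta["index_path"])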
 
 
-# def split_text(text: str, cfg: CacheConfig):
-#     splitter = RecursiveCharacterTextSplitter(
-#         chunk_size=cfg.ctx_len, chunk_overlap=cfg.ctx_len // 4
-#     )
-#     return splitter.split_text(text)
-
-
-def split_text(text: str, cfg: CacheConfig):
-    splitter = RecursiveCharacterTextSplitter(
-        chunk_size=cfg.ctx_len,
-        chunk_overlap=cfg.ctx_len // 4,
-        length_function=lambda x: 1 + len(x) // 4,
-    )
-    return splitter.split_text(text)
-
-
-def build_semantic_index(data: Dataset, cfg: CacheConfig):
+def build_semantic_index(data: Dataset, cfg: CacheConfig, batch_size: int = 1024):
     """
-    Build a semantic index of the token sequences.
+    Build a semantic index, assuming the rows of data["text"] are already of
+    appropriate length.
     """
 
     model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
-    d = next(model.parameters()).dtype
+    # Embedding width of the sentence-transformer, read from its pooling module.
+    d = model[1].word_embedding_dimension
 
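+    # HNSW parameters: M caps each node's out-degree in the graph, while
+    # efConstruction and efSearch set how many candidates are explored at
+    # build time and query time (higher means better recall but slower).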
     index = faiss.IndexHNSWFlat(d, cfg.faiss_hnsw_config["M"])
     index.hnsw.efConstruction = cfg.faiss_hnsw_config["efConstruction"]
     index.hnsw.efSearch = cfg.faiss_hnsw_config["efSearch"]
 
-    data["text"]
-    breakpoint()
-
-    # index_tokenizer = AutoTokenizer.from_pretrained(
-    #     'sentence-transformers/all-MiniLM-L6-v2')
-    # index_model = AutoModel.from_pretrained(
-    #     'sentence-transformers/all-MiniLM-L6-v2').to("cuda")
-
-    # index_tokens = chunk_and_tokenize(data, index_tokenizer,
-    #                                   max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row)
-    # index_tokens = index_tokens["input_ids"]
-    # index_tokens = assert_type(Tensor, index_tokens)
-
-    # token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state
-
-    # base_index = faiss.IndexFlatL2(token_embeddings.shape[-1])
-    # index = faiss.IndexIDMap(base_index)
-
-    # batch_size = 512
-    # dataloader = DataLoader(index_tokens, batch_size=batch_size)  # type: ignore
-
-    # from tqdm import tqdm
-    # with torch.no_grad():
-    #     for batch_idx, batch in enumerate(tqdm(dataloader)):
-    #         batch = batch.to("cuda")
-    #         token_embeddings = index_model(batch).last_hidden_state
-    #         sentence_embeddings = token_embeddings.mean(dim=1)
-    #         sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32)
-
-    #         ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch))
-    #         index.add_with_ids(sentence_embeddings, ids)
-
-    return None
117- # """
118- # Build a semantic index of the token sequences.
119- # """
120-
121- # model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
122- # d = next(model.parameters()).dtype
123-
124- # text = data['text']
125- # chunks = []
126- # for t in text:
127- # chunks.extend(split_text(t, cfg))
128-
129- # breakpoint()
130- # index = faiss.IndexHNSWFlat(d, cfg.faiss_hnsw_config["M"])
131- # index.metric_type = faiss.METRIC_L2
132- # index.hnsw.efConstruction = cfg.faiss_hnsw_config["efConstruction"]
133- # index.hnsw.efSearch = cfg.faiss_hnsw_config["efSearch"]
134-
135- # index_tokenizer = AutoTokenizer.from_pretraine
136- # d('sentence-transformers/all-MiniLM-L6-v2')
137- # index_model = AutoModel.from_pretrained('sentence-transform
138- # ers/all-MiniLM-L6-v2').to("cuda")
139-
140- # index_tokens = chunk_and_tokenize(data, index_tokenizer,
141- # max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row)
142- # index_tokens = index_tokens["input_ids"]
143- # index_tokens = assert_type(Tensor, index_tokens)
144-
145- # token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state
146-
147- # base_index = faiss.IndexFlatL2(token_embeddings.shape[-1])
148- # index = faiss.IndexIDMap(base_index)
149-
150- # batch_size = 512
151- # dataloader = DataLoader(index_tokens, batch_size=batch_size) # type: ignore
152-
153- # from tqdm import tqdm
154- # with torch.no_grad():
155- # for batch_idx, batch in enumerate(tqdm(dataloader)):
156- # batch = batch.to("cuda")
157- # token_embeddings = index_model(batch).last_hidden_state
158- # sentence_embeddings = token_embeddings.mean(dim=1)
159- # sentence_embeddings = sentence_embeddings.cpu().numpy()
160- # .astype(np.float32)
161-
162- # ids = np.arange(batch_idx * batch_size, batch_idx * batch_size
163- # + len(batch))
164- # index.add_with_ids(sentence_embeddings, ids)
165-
166- # return None
+    text_data = data["text"]
+
+    embeddings = []
+    for i in range(0, len(text_data), batch_size):
+        logger.info(f"Embedding rows {i}:{i + batch_size} of {len(text_data)}")
+        batch = text_data[i : i + batch_size]
+        batch_embeddings = model.encode(
+            batch, batch_size=batch_size, device="cuda", convert_to_numpy=True
+        )
+        embeddings.append(batch_embeddings)
+
+    # Stack the per-batch (batch, d) arrays into a single (n_rows, d) matrix.
+    embeddings = np.vstack(embeddings)
+
+    index.add(embeddings)  # type: ignore
+
+    return index
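 
+# Putting it together (sketch; assumes `dataset` is a Dataset whose "text"
+# rows are pre-chunked and `base_path` is the cache directory):
+#   index = build_semantic_index(dataset, cfg)
+#   save_index(index, base_path, cfg)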
 
 
 def build_or_load_index(data: Dataset, base_path: Path, cfg: CacheConfig):