
Commit cdebc3f

Update get neighbours
1 parent 9102ee9 commit cdebc3f

File tree

1 file changed (+30 −117 lines)


delphi/semantic_index/index.py

Lines changed: 30 additions & 117 deletions
@@ -2,28 +2,25 @@
 from pathlib import Path
 
 import faiss
+import numpy as np
 from datasets import Dataset
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from sentence_transformers import SentenceTransformer
 
 from delphi.config import CacheConfig
 from delphi.logger import logger
 
 
-def get_neighbors_by_id(index: faiss.IndexIDMap, vector_id: int, k: int = 10):
-    # First reconstruct the vector for the given ID
-    vector = index.reconstruct(vector_id)
-
-    # Reshape to match FAISS expectations (needs 2D array)
-    vector = vector.reshape(1, -1)
-
-    # Search for nearest neighbors
-    distances, neighbor_ids = index.search(
-        vector, k + 1
-    )  # k+1 since it will find itself
+def get_neighbors(model, index, query: str, k: int = 1000):
+    q_embedding = model.encode([query])
+    result = index.search(q_embedding, k=k)
+    # result: tuple of (L2 distances, top match indices). Both arrays
+    # are 2D (one row per query), so the top match requires two indices:
+    result[1][0][0]
+    # text_data[first_result]
 
     # Remove the first result (which will be the query vector itself)
-    return distances[0][1:], neighbor_ids[0][1:]
+    # return distances[0][1:], neighbor_ids[0][1:]
 
 
 def get_index_path(base_path: Path, cfg: CacheConfig):
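
As committed, `get_neighbors` leaves its return statement commented out, so it encodes the query and discards the search result. A minimal sketch of the completed lookup, assuming the index was built over the same corpus that `text_data` indexes into (this completion is an assumption, not part of the commit):

```python
import faiss
from sentence_transformers import SentenceTransformer


def get_neighbors(
    model: SentenceTransformer, index: faiss.Index, query: str, k: int = 1000
):
    # Encode the query into a (1, d) float32 matrix; FAISS expects 2D input.
    q_embedding = model.encode([query], convert_to_numpy=True)
    # search returns (distances, ids), each of shape (n_queries, k).
    distances, neighbor_ids = index.search(q_embedding, k)
    # Single query, so take row 0. Unlike the removed get_neighbors_by_id,
    # a free-text query is not itself stored in the index, so there is no
    # self-match to drop and no need for the k + 1 trick.
    return distances[0], neighbor_ids[0]
```
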
@@ -46,124 +43,40 @@ def save_index(index: faiss.IndexFlatL2, base_path: Path, cfg: CacheConfig):
         json.dump(
             {
                 "index_path": str(index_path),
-                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+                "embedding_model": cfg.faiss_embedding_model,
             },
             f,
         )
 
 
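Writing `cfg.faiss_embedding_model` into the metadata, instead of a hard-coded MiniLM name, lets the load path rebuild the exact encoder that produced the vectors. A minimal sketch of that load path, assuming the JSON layout written above (`load_index` and its location are illustrative, not part of this diff):

```python
import json
from pathlib import Path

import faiss
from sentence_transformers import SentenceTransformer


def load_index(meta_path: Path):
    # Hypothetical loader for the metadata written by save_index.
    with open(meta_path) as f:
        meta = json.load(f)
    # Rebuild the FAISS index and the matching embedding model.
    index = faiss.read_index(meta["index_path"])
    model = SentenceTransformer(meta["embedding_model"])
    return model, index
```
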
The same hunk continues: both `split_text` variants and the commented-out token-level indexing experiments are removed, and `build_semantic_index` is rewritten to batch-encode `data["text"]` directly:

-# def split_text(text: str, cfg: CacheConfig):
-#     splitter = RecursiveCharacterTextSplitter(
-#         chunk_size=cfg.ctx_len, chunk_overlap=cfg.ctx_len // 4
-#     )
-#     return splitter.split_text(text)
-
-
-def split_text(text: str, cfg: CacheConfig):
-    splitter = RecursiveCharacterTextSplitter(
-        chunk_size=cfg.ctx_len,
-        chunk_overlap=cfg.ctx_len // 4,
-        length_function=lambda x: 1 + len(x) // 4,
-    )
-    return splitter.split_text(text)
-
-
-def build_semantic_index(data: Dataset, cfg: CacheConfig):
+def build_semantic_index(data: Dataset, cfg: CacheConfig, batch_size: int = 1024):
     """
-    Build a semantic index of the token sequences.
+    Build a semantic index, assuming data['text'] is of appropriate length.
     """
 
     model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
-    d = next(model.parameters()).dtype
+    d = model[1].word_embedding_dimension
 
     index = faiss.IndexHNSWFlat(d, cfg.faiss_hnsw_config["M"])
     index.hnsw.efConstruction = cfg.faiss_hnsw_config["efConstruction"]
     index.hnsw.efSearch = cfg.faiss_hnsw_config["efSearch"]
 
-    data["text"]
-    breakpoint()
-
-    # index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
-    # index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda")
-
-    # index_tokens = chunk_and_tokenize(data, index_tokenizer, max_seq_len=cfg.ctx_len,
-    #                                   text_key=cfg.dataset_row)
-    # index_tokens = index_tokens["input_ids"]
-    # index_tokens = assert_type(Tensor, index_tokens)
-
-    # token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state
-
-    # base_index = faiss.IndexFlatL2(token_embeddings.shape[-1])
-    # index = faiss.IndexIDMap(base_index)
-
-    # batch_size = 512
-    # dataloader = DataLoader(index_tokens, batch_size=batch_size)  # type: ignore
-
-    # from tqdm import tqdm
-    # with torch.no_grad():
-    #     for batch_idx, batch in enumerate(tqdm(dataloader)):
-    #         batch = batch.to("cuda")
-    #         token_embeddings = index_model(batch).last_hidden_state
-    #         sentence_embeddings = token_embeddings.mean(dim=1)
-    #         sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32)
-
-    #         ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch))
-    #         index.add_with_ids(sentence_embeddings, ids)
-
-    return None
-    # """
-    # Build a semantic index of the token sequences.
-    # """
-
-    # model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
-    # d = next(model.parameters()).dtype
-
-    # text = data['text']
-    # chunks = []
-    # for t in text:
-    #     chunks.extend(split_text(t, cfg))
-
-    # breakpoint()
-    # index = faiss.IndexHNSWFlat(d, cfg.faiss_hnsw_config["M"])
-    # index.metric_type = faiss.METRIC_L2
-    # index.hnsw.efConstruction = cfg.faiss_hnsw_config["efConstruction"]
-    # index.hnsw.efSearch = cfg.faiss_hnsw_config["efSearch"]
-
-    # index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
-    # index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda")
-
-    # index_tokens = chunk_and_tokenize(data, index_tokenizer,
-    #                                   max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row)
-    # index_tokens = index_tokens["input_ids"]
-    # index_tokens = assert_type(Tensor, index_tokens)
-
-    # token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state
-
-    # base_index = faiss.IndexFlatL2(token_embeddings.shape[-1])
-    # index = faiss.IndexIDMap(base_index)
-
-    # batch_size = 512
-    # dataloader = DataLoader(index_tokens, batch_size=batch_size)  # type: ignore
-
-    # from tqdm import tqdm
-    # with torch.no_grad():
-    #     for batch_idx, batch in enumerate(tqdm(dataloader)):
-    #         batch = batch.to("cuda")
-    #         token_embeddings = index_model(batch).last_hidden_state
-    #         sentence_embeddings = token_embeddings.mean(dim=1)
-    #         sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32)
-
-    #         ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch))
-    #         index.add_with_ids(sentence_embeddings, ids)
-
-    # return None
+    text_data = data["text"]
+
+    embeddings = []
+    for i in range(0, len(text_data), batch_size):
+        print(f"Processing batch {i} of {len(text_data)}")
+        batch = text_data[i : i + batch_size]
+        batch_embeddings = model.encode(
+            batch, batch_size=batch_size, device="cuda", convert_to_numpy=True
+        )
+        embeddings.append(batch_embeddings)
+
+    embeddings = np.vstack(embeddings)
+
+    index.add(embeddings)  # type: ignore
+
+    return index
 
 
 def build_or_load_index(data: Dataset, base_path: Path, cfg: CacheConfig):
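
End to end, the rewritten module builds the index once and queries it with free text. A usage sketch under assumptions: `CacheConfig` accepts the fields the diff reads (`faiss_embedding_model`, `faiss_hnsw_config`), a CUDA device is available (the code hard-codes `device="cuda"`), and the dataset and HNSW values are illustrative:

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer

from delphi.config import CacheConfig
from delphi.semantic_index.index import build_semantic_index

# Illustrative HNSW settings; M, efConstruction, and efSearch trade
# index size and build/query time against recall.
cfg = CacheConfig(
    faiss_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    faiss_hnsw_config={"M": 32, "efConstruction": 200, "efSearch": 128},
)
data = Dataset.from_dict(
    {"text": ["the cat sat on the mat", "dogs chase cats", "stocks fell"]}
)

index = build_semantic_index(data, cfg)

# Query as get_neighbors does: encode, then search the HNSW graph.
model = SentenceTransformer(cfg.faiss_embedding_model, device="cuda")
distances, ids = index.search(model.encode(["feline"]), 2)
print(data["text"][int(ids[0][0])])  # nearest stored chunk
```
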
