embeddings.py
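
"""Sentence-embedding helpers: an EmbeddingModel wrapper around Hugging Face
AutoTokenizer/AutoModel with mean pooling and L2 normalization, plus generators
for streaming precomputed embeddings back from numbered numpy files."""
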
import os
from typing import List, Optional, Generator, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def torch_device() -> str:
    """Return the best available torch device: Apple 'mps', then 'cuda', then 'cpu'."""
    if torch.backends.mps.is_available():
        return 'mps'
    elif torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'


class EmbeddingModel:
    """Wraps a Hugging Face tokenizer/model pair and produces normalized sentence embeddings."""

    def __init__(self, model_name: str, device: Optional[str] = None):
        # Resolve the device at construction time rather than at class-definition time.
        self.device = device or torch_device()
        print(f'Using {self.device}')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(self.device)
        print(f'Using {model_name}')

    def tokenize_and_embed(self, strings: List[str]) -> torch.Tensor:
        """Tokenize a batch of strings and return L2-normalized, mean-pooled embeddings."""
        encoded_input = self.tokenizer(strings, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, using the attention mask to ignore padding positions."""
        token_embeddings = model_output[0]  # first element is the last hidden state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def embeddings_loader(file_path_pattern: str, max_total: Optional[int] = None) -> Generator[np.ndarray, None, None]:
    """Yield up to max_total (or all) embeddings from numpy files matching file_path_pattern.
    The pattern must contain a {i} placeholder, which is replaced by a 0-based file index
    so that embeddings are consumed in a well-defined order."""
    total = 0
    file_idx = 0
    while True:
        file_path = file_path_pattern.replace('{i}', str(file_idx))
        if not os.path.isfile(file_path):
            break
        embeddings = np.load(file_path, mmap_mode='r')  # memory-map instead of loading the whole file
        for embedding in embeddings:
            if total == max_total:
                break
            yield embedding
            total += 1
        if total == max_total:
            break
        file_idx += 1


def batch_embeddings_loader(file_path_pattern: str, batch_size: int, max_total: Optional[int] = None) -> Generator[Tuple[int, torch.Tensor], None, None]:
    """Yield (batch_start_idx, batch) tuples of at most batch_size embeddings, loaded from
    numpy files matching file_path_pattern (with a 0-based {i} placeholder) and returned
    as tensors on the default device."""
    device = torch_device()
    total = 0
    batch_start_idx = 0
    file_idx = 0
    while True:
        file_path = file_path_pattern.replace('{i}', str(file_idx))
        if not os.path.isfile(file_path):
            break
        embeddings = np.load(file_path, mmap_mode='r')
        for i in range(0, embeddings.shape[0], batch_size):
            # Shrink the batch at a file boundary and when max_total would be exceeded.
            adjusted_batch_size = min(batch_size, len(embeddings) - i)
            if max_total is not None:
                adjusted_batch_size = min(adjusted_batch_size, max_total - total)
            if adjusted_batch_size == 0:
                break
            embeddings_slice = embeddings[i:i + adjusted_batch_size]
            total += adjusted_batch_size
            yield batch_start_idx, torch.tensor(embeddings_slice).to(device)
            batch_start_idx += adjusted_batch_size
        if total == max_total:
            break
        file_idx += 1
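

if __name__ == '__main__':
    # Minimal usage sketch. Assumptions not taken from this module: the model name
    # and the 'embeddings_{i}.npy' file pattern below are illustrative examples only.
    model = EmbeddingModel('sentence-transformers/all-MiniLM-L6-v2')
    sentences = ['first example sentence', 'second example sentence']
    embeddings = model.tokenize_and_embed(sentences)
    print(embeddings.shape)  # (2, hidden_size); rows are L2-normalized

    # Stream previously saved embeddings back in batches of 512, capped at 10,000 total.
    for batch_start_idx, batch in batch_embeddings_loader('embeddings_{i}.npy', batch_size=512, max_total=10_000):
        print(batch_start_idx, batch.shape)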