# embeddings.py
"""
Embedding Engine
Converts text → dense vector representations for semantic search
Uses sentence-transformers (all-MiniLM-L6-v2, 384-dim, fast & accurate)
"""
import logging
import asyncio
import hashlib
import math
from typing import List, Dict

logger = logging.getLogger(__name__)


class EmbeddingEngine:
    """
    Wraps sentence-transformers to produce 384-dimensional embeddings.
    Includes an in-process cache (a plain dict keyed by text hash, never
    evicted) to avoid re-embedding identical text.
    """

    MODEL_NAME = "all-MiniLM-L6-v2"  # 384-dim | 22 MB | ~14k tokens/sec on CPU
    DIMENSION = 384

    def __init__(self):
        self._model = None
        self._cache: Dict[str, List[float]] = {}  # text_hash → vector
        self._cache_hits = 0
        self._cache_misses = 0

    def _load_model(self):
        """Lazy-load the model on first use (avoids startup overhead)."""
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
                logger.info(f"Loading embedding model: {self.MODEL_NAME}")
                self._model = SentenceTransformer(self.MODEL_NAME)
                logger.info("Embedding model loaded ✅")
            except ImportError:
                logger.warning(
                    "sentence-transformers not installed – using deterministic mock embeddings"
                )
                self._model = "MOCK"

    def _text_hash(self, text: str) -> str:
        return hashlib.md5(text.encode()).hexdigest()
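    # MD5 here is purely a cache key, not a security primitive, so its
    # cryptographic weaknesses are irrelevant; any stable digest
    # (e.g. hashlib.sha1) would serve equally well.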

    def _mock_embed(self, text: str) -> List[float]:
        """Deterministic mock embedding for CI / offline testing."""
        seed = int(self._text_hash(text), 16) % (2 ** 32)
        # Simple LCG (linear congruential generator), seeded by the text hash
        state = seed
        vec = []
        for _ in range(self.DIMENSION):
            state = (state * 1664525 + 1013904223) & 0xFFFFFFFF
            vec.append((state / 0xFFFFFFFF) * 2 - 1)
        # L2-normalise so mock vectors are unit-length, like the real ones
        norm = math.sqrt(sum(x ** 2 for x in vec)) or 1.0
        return [x / norm for x in vec]
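    # Because the LCG is seeded from the text's hash, mock embeddings are
    # deterministic across runs: the same input always yields the same unit
    # vector, so offline tests comparing similarities stay reproducible, e.g.
    #     eng = EmbeddingEngine()
    #     assert eng._mock_embed("hello") == eng._mock_embed("hello")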

    def embed_single(self, text: str) -> List[float]:
        """Embed a single text string. Uses cache."""
        self._load_model()
        # Cache key is normalised (strip + lower), so "Hello" and " hello "
        # share one entry; note that the *original* text is what gets embedded.
        key = self._text_hash(text.strip().lower())
        if key in self._cache:
            self._cache_hits += 1
            return self._cache[key]
        self._cache_misses += 1
        if self._model == "MOCK":
            vector = self._mock_embed(text)
        else:
            vector = self._model.encode(text, normalize_embeddings=True).tolist()
        self._cache[key] = vector
        return vector

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts efficiently.
        Texts already in cache are returned immediately; uncached texts are
        sent to the model in a single batch call.
        """
        self._load_model()
        # Partition: cached vs. uncached
        cached_map: Dict[int, List[float]] = {}
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []
        for i, text in enumerate(texts):
            key = self._text_hash(text.strip().lower())
            if key in self._cache:
                cached_map[i] = self._cache[key]
                self._cache_hits += 1
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)
                self._cache_misses += 1
        # Batch encode uncached texts
        if uncached_texts:
            if self._model == "MOCK":
                new_vectors = [self._mock_embed(t) for t in uncached_texts]
            else:
                new_vectors = self._model.encode(
                    uncached_texts,
                    normalize_embeddings=True,
                    batch_size=64,
                    show_progress_bar=len(uncached_texts) > 100
                ).tolist()
            for idx, vec, text in zip(uncached_indices, new_vectors, uncached_texts):
                key = self._text_hash(text.strip().lower())
                self._cache[key] = vec
                cached_map[idx] = vec
        return [cached_map[i] for i in range(len(texts))]
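    # Keying cached_map by input index means the final list is reassembled in
    # the caller's original order, however the cached/uncached split fell, so
    # zip(texts, engine.embed_batch(texts)) pairs up correctly.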

    async def async_embed(self, text: str) -> List[float]:
        """Async wrapper — runs embedding in a thread pool."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_single, text)

    async def async_embed_batch(self, texts: List[str]) -> List[List[float]]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_batch, texts)

    @property
    def cache_stats(self) -> Dict[str, float]:
        return {
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "cached_entries": len(self._cache),
            "hit_rate_pct": round(
                100 * self._cache_hits / max(1, self._cache_hits + self._cache_misses), 1
            )
        }
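

# ---------------------------------------------------------------------------
# Usage sketch: a minimal demo of the single / batch / async paths and the
# cache counters, using only the API defined above. When sentence-transformers
# is not installed, this exercises the deterministic mock embedder instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    engine = EmbeddingEngine()

    # Single embedding: a second call with equivalent text hits the cache,
    # since keys are normalised with strip() + lower().
    v1 = engine.embed_single("semantic search is neat")
    v2 = engine.embed_single("  Semantic search is NEAT ")
    assert len(v1) == EmbeddingEngine.DIMENSION and v1 == v2

    # Batch embedding preserves input order.
    vectors = engine.embed_batch(["alpha", "beta", "semantic search is neat"])
    assert len(vectors) == 3

    # Async wrappers offload the blocking encode to the default thread pool.
    async def _demo() -> None:
        vec = await engine.async_embed("hello from asyncio")
        print(f"async vector dim: {len(vec)}")

    asyncio.run(_demo())
    print(f"cache stats: {engine.cache_stats}")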