Skip to content

Commit 1d548fc

Browse files
authored
Merge pull request #16 from ClipABit/pr/search-module-4
Pr/search module 4
2 parents 6f45f5a + 5d4f4d1 commit 1d548fc

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

backend/search/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
"""
Search module for semantic search using CLIP embeddings and Pinecone.
"""

# Re-export the package's public API so callers can write
# `from search import TextEmbedder, Searcher` directly.
from search.embedder import TextEmbedder
from search.searcher import Searcher

__all__ = ["TextEmbedder", "Searcher"]

backend/search/embedder.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
Text Embedding module using CLIP model.

Provides text-to-vector conversion using OpenAI's CLIP model
for semantic search capabilities.
"""
# Fix: this string previously appeared AFTER the imports, where it is a
# no-op expression statement rather than the module docstring. Placing it
# first makes it the real `__doc__` (visible to help() and tooling).

import logging
from typing import Union, List

import numpy as np
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

# NOTE(review): calling basicConfig at import time configures the root
# logger as a side effect; kept for behavioral compatibility, but library
# modules normally leave logging configuration to the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
16+
17+
18+
class TextEmbedder:
    """
    CLIP-based text embedder for semantic search.

    Turns text queries into 512-dimensional, L2-normalized embeddings
    using OpenAI's CLIP text encoder (ViT-B/32 by default).

    Only the projected text tower (CLIPTextModelWithProjection) is loaded,
    avoiding the cost of the full CLIP model with its vision encoder.

    Usage:
        embedder = TextEmbedder()
        vector = embedder.embed_text("woman on a train")
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the text embedder.

        Args:
            model_name: HuggingFace model identifier for CLIP.
                        Defaults to "openai/clip-vit-base-patch32".
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Deferred until the first embed_text() call (lazy loading).
        self.model = None
        self.tokenizer = None

        logger.info(f"TextEmbedder initialized (device: {self.device})")

    def _load_model(self):
        """Lazy load the CLIP text model on first use."""
        if self.model is not None:
            return  # already loaded
        logger.info(f"Loading CLIP text model: {self.model_name}")
        self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name)
        self.model = CLIPTextModelWithProjection.from_pretrained(self.model_name).to(self.device)
        self.model.eval()
        logger.info("CLIP text model loaded successfully")

    def embed_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """
        Generate embeddings for text input(s).

        Args:
            text: Single text string or list of text strings

        Returns:
            numpy array of embeddings (512-d, L2-normalized)
            Shape: (512,) for single text, (N, 512) for batch
        """
        self._load_model()

        # Normalize to a batch; a bare string becomes a one-element list.
        batch = [text] if isinstance(text, str) else text

        # Tokenize; 77 is CLIP's maximum sequence length.
        tokenized = self.tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        ).to(self.device)

        with torch.no_grad():
            # text_embeds is already projected into the joint CLIP space.
            features = self.model(**tokenized).text_embeds
            # L2 normalize so dot product equals cosine similarity.
            features = features / features.norm(p=2, dim=-1, keepdim=True)

        embeddings = features.cpu().numpy()

        # A one-row result collapses to a flat (512,) vector, matching the
        # single-input convention.
        return embeddings[0] if len(embeddings) == 1 else embeddings

backend/search/searcher.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Semantic Searcher using Pinecone vector database.
3+
4+
Coordinates text embedding and vector search to find semantically
5+
similar content.
6+
"""
7+
8+
import logging
9+
from typing import List, Dict, Any
10+
11+
from database.pinecone_connector import PineconeConnector
12+
from search.embedder import TextEmbedder
13+
14+
logging.basicConfig(level=logging.INFO)
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class Searcher:
    """
    High-level semantic search coordinator.

    Wires a TextEmbedder (query -> vector) to a PineconeConnector
    (vector -> nearest matches) behind a single search() call.

    Usage:
        searcher = Searcher(api_key="...", index_name="chunks-index")
        results = searcher.search("woman on a train", top_k=3)
    """

    def __init__(
        self,
        api_key: str,
        index_name: str,
        namespace: str = "__default__"
    ):
        """
        Initialize searcher with Pinecone connection.

        Args:
            api_key: Pinecone API key
            index_name: Name of Pinecone index to search
            namespace: Optional namespace for partitioning data
        """
        self.embedder = TextEmbedder()
        self.connector = PineconeConnector(api_key=api_key, index_name=index_name)
        self.namespace = namespace

        logger.info(
            f"Searcher initialized (index={index_name}, namespace='{namespace}')"
        )

    @property
    def device(self) -> str:
        """Get the device being used for embeddings (cpu/cuda)."""
        return self.embedder.device

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for semantically similar content.

        Args:
            query: Natural language search query
            top_k: Number of results to return (default: 5)

        Returns:
            List of matches with scores and metadata, sorted by similarity

        Example:
            results = searcher.search("cooking in kitchen", top_k=3)
            for result in results:
                print(result['score'], result['metadata'])
        """
        logger.info(f"Searching for: '{query}' (top_k={top_k})")

        # Embed the query text, then ask Pinecone for its nearest vectors.
        query_vector = self.embedder.embed_text(query)
        matches = self.connector.query_chunks(
            query_embedding=query_vector,
            namespace=self.namespace,
            top_k=top_k
        )

        # Flatten each raw match into a plain dict with safe defaults.
        results = [
            {
                'id': match.get('id'),
                'score': match.get('score', 0.0),
                'metadata': match.get('metadata', {})
            }
            for match in matches
        ]

        logger.info(f"Found {len(results)} results")
        return results

0 commit comments

Comments
 (0)