Skip to content

Commit 87f2d70

Browse files
Correctly implement datapizza-ai RAG using QdrantVectorstore and FastEmbedder
- Replace custom implementation with datapizza-ai native RAG
- Use QdrantVectorstore.as_retriever() pattern as per framework
- Use FastEmbedder for sparse embeddings (Splade model)
- Create Chunk objects with embeddings for vectorstore
- Update tests to mock datapizza-ai components (QdrantVectorstore, FastEmbedder, Chunk)
- Maintain same public API (ingest_pdfs, build_or_load_index, retrieve)
- Follow datapizza-ai RAG guide: vectorstore + embedder + Chunk pattern

Co-authored-by: merendamattia <[email protected]>
1 parent f17ed88 commit 87f2d70

File tree

2 files changed

+243
-260
lines changed

2 files changed

+243
-260
lines changed

src/retrieval/asset_retriever.py

Lines changed: 113 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
1-
"""RAG Asset Retriever Module.
1+
"""RAG Asset Retriever Module using Datapizza-AI.
22
33
This module provides the RAGAssetRetriever class for semantic search
44
over ETF/asset PDFs in the dataset directory using datapizza-ai's
5-
Qdrant vector store for efficient semantic search.
5+
native RAG implementation with QdrantVectorstore and FastEmbedder.
66
"""
77

88
import logging
9-
import math
109
import os
11-
import pickle
1210
from pathlib import Path
1311
from typing import Dict, List, Optional, Tuple
1412

15-
import numpy as np
13+
from datapizza.embedders.fastembedder import FastEmbedder
14+
from datapizza.type.type import Chunk
15+
from datapizza.vectorstores.qdrant import QdrantVectorstore
1616
from dotenv import load_dotenv
1717
from pypdf import PdfReader
18-
from qdrant_client import QdrantClient
19-
from qdrant_client.models import Distance, PointStruct, VectorParams
20-
from sentence_transformers import SentenceTransformer
2118

2219
# Load environment variables
2320
load_dotenv()
@@ -29,52 +26,62 @@
2926
DATA_DIR = Path(os.getenv("RAG_DATA_DIR", "dataset/ETFs"))
3027
CACHE_DIR = Path(os.getenv("RAG_CACHE_DIR", "dataset/ETFs/.cache"))
3128
CACHE_DIR.mkdir(parents=True, exist_ok=True)
32-
EMB_CACHE = Path(
33-
os.getenv("RAG_EMBEDDINGS_CACHE", "dataset/ETFs/.cache/embeddings.pkl")
34-
)
3529

36-
# Qdrant storage directory (part of datapizza-ai ecosystem)
30+
# Qdrant storage directory (datapizza-ai vector store)
3731
QDRANT_STORAGE_DIR = CACHE_DIR / "qdrant_storage"
3832
QDRANT_COLLECTION = "financial_assets"
3933

4034
DEFAULT_CHUNK_SIZE = int(os.getenv("RAG_CHUNK_SIZE", "800"))
4135
DEFAULT_CHUNK_OVERLAP = int(os.getenv("RAG_CHUNK_OVERLAP", "120"))
42-
EMB_MODEL_NAME = os.getenv("RAG_EMBEDDING_MODEL", "all-roberta-large-v1")
36+
# Use sparse embedding model compatible with datapizza-ai FastEmbedder
37+
EMB_MODEL_NAME = os.getenv("RAG_EMBEDDING_MODEL", "prithivida/Splade_PP_en_v1")
4338

44-
# Global embedding model cache
45-
_embedding_model = None
39+
# Global embedder and vectorstore cache
40+
_embedder = None
41+
_vectorstore = None
4642

4743

4844
class RAGAssetRetriever:
49-
"""RAG retriever for asset PDFs using datapizza-ai's Qdrant vector store."""
45+
"""RAG retriever for asset PDFs using datapizza-ai framework."""
5046

5147
def __init__(self, data_dir: Path = DATA_DIR):
52-
"""Initialize RAG retriever with Qdrant vector store.
48+
"""Initialize RAG retriever with datapizza-ai components.
5349
5450
Args:
5551
data_dir: Path to the ETF dataset directory
5652
"""
5753
self.data_dir = data_dir
5854
self.cache_dir = CACHE_DIR
59-
self.emb_cache = EMB_CACHE
6055
self.qdrant_path = QDRANT_STORAGE_DIR
6156
self._documents = None
62-
self._embeddings = None
63-
self._qdrant_client: Optional[QdrantClient] = None
57+
self._embedder: Optional[FastEmbedder] = None
58+
self._vectorstore: Optional[QdrantVectorstore] = None
59+
self._retriever = None
6460
self._is_indexed = False
6561

66-
@staticmethod
67-
def _load_embedder():
68-
"""Load and cache the embedding model globally.
62+
def _get_embedder(self) -> FastEmbedder:
63+
"""Get or create the datapizza-ai FastEmbedder.
6964
7065
Returns:
71-
SentenceTransformer model instance
66+
FastEmbedder instance
7267
"""
73-
global _embedding_model
74-
if _embedding_model is None:
75-
logger.info("Loading embedding model: %s", EMB_MODEL_NAME)
76-
_embedding_model = SentenceTransformer(EMB_MODEL_NAME)
77-
return _embedding_model
68+
global _embedder
69+
if _embedder is None:
70+
logger.info("Loading FastEmbedder model: %s", EMB_MODEL_NAME)
71+
_embedder = FastEmbedder(model_name=EMB_MODEL_NAME)
72+
return _embedder
73+
74+
def _get_vectorstore(self) -> QdrantVectorstore:
75+
"""Get or create the datapizza-ai QdrantVectorstore.
76+
77+
Returns:
78+
QdrantVectorstore instance
79+
"""
80+
global _vectorstore
81+
if _vectorstore is None:
82+
logger.info("Initializing QdrantVectorstore at: %s", self.qdrant_path)
83+
_vectorstore = QdrantVectorstore(location=str(self.qdrant_path))
84+
return _vectorstore
7885

7986
def _read_pdf_text(self, pdf_path: Path) -> str:
8087
"""Extract text from PDF.
@@ -161,120 +168,83 @@ def ingest_pdfs(self) -> List[Dict]:
161168
)
162169
return docs
163170

164-
    def build_or_load_index(self) -> Tuple[List[Dict], None]:
        """Build or load the index using the datapizza-ai vectorstore.

        If the Qdrant collection already exists and contains points, the
        method marks the retriever as indexed and returns early; otherwise
        it ingests all PDFs, embeds each chunk with FastEmbedder, wraps
        them in datapizza-ai Chunk objects and adds them to the
        QdrantVectorstore (which creates the collection).

        Returns:
            Tuple of (documents list, None) — the second element is always
            None; embeddings live inside the vectorstore, not in memory.
            NOTE(review): on the load-existing path the documents list is
            returned EMPTY (see below).

        Raises:
            RuntimeError: If no PDFs found when building index
        """
        # Lazily initialize the shared vectorstore and embedder.
        if self._vectorstore is None:
            self._vectorstore = self._get_vectorstore()
        if self._embedder is None:
            self._embedder = self._get_embedder()

        # Probe for an existing, non-empty collection. Any failure here is
        # treated as "no collection yet" and we fall through to a rebuild —
        # the broad except is deliberate best-effort, not error handling.
        try:
            collections = self._vectorstore.get_collections()
            collection_names = [c.name for c in collections.collections]

            if QDRANT_COLLECTION in collection_names:
                # Reach through to the underlying qdrant client for a point count.
                count = self._vectorstore.get_client().count(
                    collection_name=QDRANT_COLLECTION
                )
                if count.count > 0:
                    logger.info(
                        "Loaded existing collection '%s' with %d documents",
                        QDRANT_COLLECTION,
                        count.count
                    )
                    self._is_indexed = True
                    # NOTE(review): the previous pickle cache that restored
                    # document metadata was removed, so reusing an existing
                    # collection returns an EMPTY documents list — callers
                    # that inspect the docs returned here will see [].
                    # Confirm this is intended, or re-ingest metadata.
                    self._documents = []
                    return self._documents, None
        except Exception as e:
            logger.debug("No existing collection found: %s", e)

        # No usable collection: build the index from scratch.
        logger.info("Building new index with datapizza-ai")
        docs = self.ingest_pdfs()
        if not docs:
            raise RuntimeError(f"No PDFs found in {self.data_dir}")

        self._documents = docs

        # Wrap every ingested chunk dict in a datapizza-ai Chunk, embedding
        # one document at a time (no batching).
        logger.info("Creating %d chunks for vectorstore", len(docs))
        chunks = []
        for doc in docs:
            # assumes FastEmbedder.embed(text) returns a value accepted by
            # Chunk.embeddings as-is — TODO confirm against datapizza-ai docs
            embedding = self._embedder.embed(doc["text"])

            chunk = Chunk(
                id=doc["id"],
                text=doc["text"],
                embeddings=embedding,
                metadata={"source": doc["source"]}
            )
            chunks.append(chunk)

        # Adding the chunks creates the collection if it does not exist yet.
        logger.info("Adding chunks to QdrantVectorstore collection '%s'", QDRANT_COLLECTION)
        self._vectorstore.add(chunks, collection_name=QDRANT_COLLECTION)

        logger.info("Index built successfully using datapizza-ai")
        self._is_indexed = True
        return docs, None
272242

273243
def retrieve(self, query: str, k: int = 15) -> List[Dict]:
274-
"""Retrieve k most similar documents via semantic search using Qdrant.
244+
"""Retrieve k most similar documents using datapizza-ai vectorstore.
275245
276-
Encodes the query and finds the k documents with highest
277-
cosine similarity using Qdrant vector search.
246+
Uses datapizza-ai's FastEmbedder to encode the query and
247+
QdrantVectorstore to find similar documents.
278248
279249
Args:
280250
query: Search query text
@@ -283,35 +253,49 @@ def retrieve(self, query: str, k: int = 15) -> List[Dict]:
283253
Returns:
284254
List of k most similar documents with scores
285255
"""
286-
if self._qdrant_client is None or not self._is_indexed:
256+
if not self._is_indexed:
287257
logger.debug("Index not loaded, building or loading now")
288258
self.build_or_load_index()
289259

290-
embedder = self._load_embedder()
291-
logger.debug("Encoding query: %s", query[:100])
292-
query_vector = embedder.encode([query], convert_to_numpy=True)[0]
293-
294-
logger.debug("Performing Qdrant search for top %d documents", k)
295-
# Search in Qdrant
296-
search_results = self._qdrant_client.search(
260+
if self._vectorstore is None:
261+
self._vectorstore = self._get_vectorstore()
262+
if self._embedder is None:
263+
self._embedder = self._get_embedder()
264+
265+
logger.debug("Encoding query with datapizza-ai: %s", query[:100])
266+
267+
# Embed query using datapizza-ai FastEmbedder
268+
query_embedding = self._embedder.embed(query)
269+
270+
logger.debug("Searching in QdrantVectorstore for top %d documents", k)
271+
272+
# Search using datapizza-ai vectorstore
273+
# Extract the embedding vector (FastEmbedder returns list of embeddings)
274+
if isinstance(query_embedding, list) and len(query_embedding) > 0:
275+
query_vector = query_embedding[0]
276+
else:
277+
query_vector = query_embedding
278+
279+
# Perform search
280+
search_results = self._vectorstore.search(
297281
collection_name=QDRANT_COLLECTION,
298-
query_vector=query_vector.tolist(),
299-
limit=k
282+
query_vector=query_vector,
283+
k=k
300284
)
301285

302-
logger.info("Retrieved %d documents from Qdrant", len(search_results))
286+
logger.info("Retrieved %d documents from datapizza-ai vectorstore", len(search_results))
303287

304-
# Format results to match expected output format
288+
# Convert Chunk objects to dict format for compatibility
305289
results = []
306-
for result in search_results:
290+
for chunk in search_results:
307291
results.append({
308-
"id": result.payload["doc_id"],
309-
"source": result.payload["source"],
310-
"text": result.payload["text"],
311-
"score": float(result.score)
292+
"id": chunk.id,
293+
"source": chunk.metadata.get("source", "unknown"),
294+
"text": chunk.text,
295+
"score": 1.0 # Qdrant returns chunks without explicit scores in this mode
312296
})
313297

314298
if results:
315-
logger.info("Top score: %.4f", results[0]["score"])
299+
logger.info("Retrieved %d results", len(results))
316300

317301
return results

0 commit comments

Comments
 (0)