Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions src/flare_ai_rag/retriever/qdrant_collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import google.api_core.exceptions
import pandas as pd
import structlog
from qdrant_client import QdrantClient
Expand All @@ -14,7 +15,6 @@ def _create_collection(
) -> None:
"""
Creates a Qdrant collection with the given parameters.

:param collection_name: Name of the collection.
:param vector_size: Dimension of the vectors.
"""
Expand All @@ -31,21 +31,19 @@ def generate_collection(
embedding_client: GeminiEmbedding,
) -> None:
"""Routine for generating a Qdrant collection for a specific CSV file type."""
# Create the collection.
_create_collection(
qdrant_client, retriever_config.collection_name, retriever_config.vector_size
)
logger.info(
"Created the collection.", collection_name=retriever_config.collection_name
)

# For each document in the CSV, compute its embedding and prepare a Qdrant point.
points = []
for i, row in df_docs.iterrows():
doc_id = i
for idx, (_, row) in enumerate(
df_docs.iterrows(), start=1
): # Using _ for unused variable
content = row["Contents"]

# Check if content is missing or not a string.
if not isinstance(content, str):
logger.warning(
"Skipping document due to missing or invalid content.",
Expand All @@ -54,34 +52,52 @@ def generate_collection(
continue

try:
# Compute the embedding for the document content.
embedding = embedding_client.embed_content(
embedding_model=retriever_config.embedding_model,
task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
contents=content,
title=str(row["Filename"]),
)
except Exception as e:
except google.api_core.exceptions.InvalidArgument as e:
# Check if it's the known "Request payload size exceeds the limit" error
# If so, downgrade it to a warning
if "400 Request payload size exceeds the limit" in str(e):
logger.warning(
"Skipping document due to size limit.",
filename=row["Filename"],
)
continue
# Log the full traceback for other InvalidArgument errors
logger.exception(
"Error encoding document (InvalidArgument).",
filename=row["Filename"],
)
continue
except Exception:
# Log the full traceback for any other errors
logger.exception(
"Error encoding document.", filename=row["Filename"], error=str(e)
"Error encoding document (general).",
filename=row["Filename"],
)
continue

# Prepare the payload.
payload = {
"filename": row["Filename"],
"metadata": row["Metadata"],
"text": content,
}

# Create a Qdrant point.
point = PointStruct(id=doc_id, vector=embedding, payload=payload) # pyright: ignore [reportArgumentType]
point = PointStruct(
id=idx, # Using integer ID starting from 1
vector=embedding,
payload=payload,
)
points.append(point)

if points:
# Upload the points into the Qdrant collection.
qdrant_client.upsert(
collection_name=retriever_config.collection_name, points=points
collection_name=retriever_config.collection_name,
points=points,
)
logger.info(
"Collection generated and documents inserted into Qdrant successfully.",
Expand Down