diff --git a/examples/08_rag_reranker_demo.py b/examples/08_rag_reranker_demo.py
new file mode 100644
index 00000000..42a9831c
--- /dev/null
+++ b/examples/08_rag_reranker_demo.py
@@ -0,0 +1,233 @@
+# pyright: reportMissingImports=false
+# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
+"""
+Demo: RAG with LLM-based Reranking
+
+Demonstrates why reranking helps: compares vector-only retrieval vs. LLM reranking
+on an adversarial doc set (high keyword overlap but low true relevance).
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+from typing import Final
+from uuid import uuid4
+
+from dotenv import load_dotenv
+from qdrant_client import QdrantClient
+
+from flare_ai_kit.agent import AgentSettings
+from flare_ai_kit.common import SemanticSearchResult
+from flare_ai_kit.rag.vector.api import (
+    FixedSizeChunker,
+    GeminiEmbedding,
+    GeminiPointwiseReranker,
+    LocalFileIndexer,
+    QdrantRetriever,
+    ingest_and_embed,
+    upsert_to_qdrant,
+)
+from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings
+
+
+def _print_results(
+    title: str,
+    results: list[SemanticSearchResult],
+    *,
+    show_reranker_scores: bool = False,
+    max_text_chars: int = 96,
+) -> None:
+    print(f"\n{title} (top {len(results)}):\n")
+    for i, res in enumerate(results, 1):
+        source = res.metadata.get("file_name", "unknown")
+        if show_reranker_scores:
+            original_score = res.metadata.get("original_score", "N/A")
+            reranker_score = res.metadata.get("reranker_score", "N/A")
+            scores = f"orig={original_score}, rerank={reranker_score}/10"
+        else:
+            scores = f"vector={res.score:.3f}"
+        preview = res.text.replace("\n", " ")[:max_text_chars]
+        print(f"  {i:>2}. [{source}] ({scores})")
+        print(f"      {preview}...")
+
+
+def _hit_rate(results: list[SemanticSearchResult], relevant_sources: set[str]) -> int:
+    return sum(1 for r in results if r.metadata.get("file_name") in relevant_sources)
+
+
+async def main() -> None:  # noqa: PLR0915
+    """Run the RAG + reranking demo."""
+    load_dotenv()
+    agent = AgentSettings()  # pyright: ignore[reportCallIssue]
+    vector_db = VectorDbSettings(qdrant_batch_size=8)
+    reranker_settings = RerankerSettings(enabled=True, score_threshold=5.0)
+
+    if agent.gemini_api_key is None:
+        print("GEMINI_API_KEY environment variable not set.")
+        print("Please set the GEMINI_API_KEY environment variable to run this demo.")
+        print("Example: export GEMINI_API_KEY='your_api_key_here'")
+        sys.exit(1)
+
+    demo_dir = Path("demo_rerank_data")
+    demo_dir.mkdir(parents=True, exist_ok=True)
+
+    # Intentionally include "keyword trap" docs that mention the right terms
+    # but are not helpful.
+    documents: Final[list[tuple[str, str]]] = [
+        (
+            "how_to_reset_password.txt",
+            "Password reset steps: On the login page, click 'Forgot Password'. "
+            "Enter your email. Check your inbox for the reset link (valid for 24 hours). "
+            "If the link expired, request a new one. If you still can't reset, contact support.",
+        ),
+        (
+            "reset_link_expired.txt",
+            "If your password reset link has expired: return to the login page, click 'Forgot Password', "
+            "and request a new link. For security, links expire after 24 hours and can only be used once.",
+        ),
+        (
+            "account_lockout_policy.txt",
+            "Account lockout: after 5 failed login attempts, your account is locked for 15 minutes. "
+            "Password reset does not unlock the account immediately. "
+            "Wait 15 minutes or contact support.",
+        ),
+        (
+            "billing_address_change.txt",
+            "To change your billing address: go to Settings → Billing → Address, update fields, and Save. "
+            "This does not change your login email or password.",
+        ),
+        (
+            "two_factor_setup.txt",
+            "Enable two-factor authentication (2FA): Settings → Security → Two-Factor. "
+            "Scan the QR code in your authenticator app and enter the 6-digit code to confirm.",
+        ),
+        (
+            "keyword_trap_marketing_blog.txt",
+            "Password reset, reset password, forgot password, reset link, password reset email — "
+            "we love talking about password reset in our marketing. This post resets expectations. "
+            "It does not contain support steps, but it repeats password reset keywords many times. "
+            "Password reset. Reset password. Forgot password. Reset link.",
+        ),
+        (
+            "keyword_trap_shipping_reset.txt",
+            "If your package tracking looks wrong, reset the tracking view in the shipping portal. "
+            "This is unrelated to accounts, login, or password reset. It mentions reset and link and email.",
+        ),
+        (
+            "password_requirements.txt",
+            "Password requirements: 8+ chars, one uppercase, one lowercase, one number. "
+            "If your password reset fails, ensure the new password meets requirements.",
+        ),
+        (
+            "sso_login_note.txt",
+            "If you signed up with Google/SSO, you may not have a password to reset. "
+            "Use 'Continue with Google' or contact support to convert your account to email/password login.",
+        ),
+    ]
+
+    try:
+        for filename, content in documents:
+            (demo_dir / filename).write_text(content, encoding="utf-8")
+
+        print(f"Created {len(documents)} sample documents in {demo_dir}/\n")
+
+        # Use larger chunks so each doc stays mostly intact for clearer ranking differences.
+        chunker = FixedSizeChunker(chunk_size=256)
+        indexer = LocalFileIndexer(
+            root_dir=str(demo_dir), chunker=chunker, allowed_extensions={".txt"}
+        )
+
+        embedding_model = GeminiEmbedding(
+            api_key=agent.gemini_api_key.get_secret_value(),
+            model=vector_db.embeddings_model,
+            output_dimensionality=vector_db.embeddings_output_dimensionality,
+        )
+
+        data = ingest_and_embed(indexer, embedding_model, batch_size=8)
+        print(f"Ingested and embedded {len(data)} chunks.\n")
+
+        collection_name = f"reranker-demo-{uuid4().hex}"
+        # Must match the embedding output dimensionality configured above.
+        vector_size = 768
+        upsert_to_qdrant(
+            data,
+            str(vector_db.qdrant_url),
+            collection_name,
+            vector_size,
+            batch_size=vector_db.qdrant_batch_size,
+        )
+        print(
+            f"Upserted {len(data)} vectors to Qdrant collection '{collection_name}'.\n"
+        )
+
+        client = QdrantClient(str(vector_db.qdrant_url))
+        retriever = QdrantRetriever(client, embedding_model, vector_db)
+
+        reranker = GeminiPointwiseReranker(
+            api_key=agent.gemini_api_key.get_secret_value(),
+            model=reranker_settings.model,
+            num_batches=reranker_settings.num_batches,
+            timeout_seconds=reranker_settings.timeout_seconds,
+            score_threshold=reranker_settings.score_threshold,
+        )
+
+        scenarios: Final[list[tuple[str, str, set[str]]]] = [
+            (
+                "Reset password (needs actionable steps; keyword traps present)",
+                "My password reset link expired. How do I get a new one?",
+                {"how_to_reset_password.txt", "reset_link_expired.txt"},
+            ),
+            (
+                "SSO edge case (vector overlap is misleading)",
+                "I signed up with Google SSO. Why doesn't 'Forgot Password' work?",
+                {"sso_login_note.txt"},
+            ),
+            (
+                "Security settings (should prefer the exact setup steps)",
+                "How do I enable two-factor authentication?",
+                {"two_factor_setup.txt"},
+            ),
+        ]
+
+        top_k_retrieve = 8
+        top_k_show = 5
+        for title, query, relevant_sources in scenarios:
+            print("\n" + "=" * 80)
+            print(f"{title}\nQuery: {query}")
+
+            candidates = retriever.semantic_search(
+                query, collection_name, top_k=top_k_retrieve
+            )
+            baseline = candidates[:top_k_show]
+
+            _print_results("Vector-only results", baseline, show_reranker_scores=False)
+
+            print("\nReranking with LLM...\n")
+            reranked = await reranker.rerank(query, candidates, top_k=top_k_show)
+            _print_results(
+                "Reranked results",
+                reranked,
+                show_reranker_scores=True,
+            )
+
+            base_hits = _hit_rate(baseline, relevant_sources)
+            rerank_hits = _hit_rate(reranked, relevant_sources)
+            print("\nSummary:")
+            print(f"  - Relevant sources: {sorted(relevant_sources)}")
+            print(
+                f"  - Hit@{top_k_show}: vector-only={base_hits}/{top_k_show}, "
+                f"reranked={rerank_hits}/{top_k_show}"
+            )
+
+        print("\n" + "=" * 80)
+        print("Demo complete.")
+        print(
+            f"Notes: reranker score threshold={reranker_settings.score_threshold}, "
+            f"batches={reranker_settings.num_batches}, timeout={reranker_settings.timeout_seconds}s"
+        )
+    finally:
+        for filename, _ in documents:
+            (demo_dir / filename).unlink(missing_ok=True)
+        demo_dir.rmdir()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
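As a counterpart to the manual wiring in the demo above, the factory added later in this PR can assemble the same stack in one call. A minimal sketch, assuming a running Qdrant instance, GEMINI_API_KEY in the environment, and a hypothetical pre-populated collection named "my-collection":

import asyncio

from flare_ai_kit.agent.settings import AgentSettings
from flare_ai_kit.rag.vector.factory import create_vector_rag_pipeline
from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings


async def demo() -> None:
    pipeline = create_vector_rag_pipeline(
        vector_db_settings=VectorDbSettings(),
        agent_settings=AgentSettings(),
        reranker_settings=RerankerSettings(enabled=True),
    )
    # Retrieve 20 candidates, rerank them, keep the best 5.
    results = await pipeline.retrieve_and_rerank(
        query="My password reset link expired. How do I get a new one?",
        collection_name="my-collection",  # hypothetical, pre-populated collection
        top_k_retrieve=20,
        top_k_rerank=5,
    )
    for res in results:
        print(res.score, res.metadata.get("file_name"))


asyncio.run(demo())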
+ + """ diff --git a/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py b/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py new file mode 100644 index 00000000..c2cf6906 --- /dev/null +++ b/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py @@ -0,0 +1,338 @@ +"""LLM-based pointwise reranker using Gemini API.""" + +import asyncio +import json +import re +from typing import Any, Final, cast, override + +import structlog +from google import genai # pyright: ignore[reportMissingTypeStubs] +from google.genai import types # pyright: ignore[reportMissingTypeStubs] + +from flare_ai_kit.common import RerankerError, RerankerParseError, SemanticSearchResult + +from .base import BaseReranker +from .prompts import build_system_prompt, build_user_prompt + +logger = structlog.get_logger(__name__) + +MIN_SCORE: Final[float] = 0.0 +MAX_SCORE: Final[float] = 10.0 + + +class GeminiPointwiseReranker(BaseReranker): + """ + LLM-based pointwise reranker using Google's Gemini API. + + This reranker scores each passage on a 1-10 scale based on relevance + to the query. It uses parallel batch processing with round-robin + distribution for improved latency and accuracy. + """ + + def __init__( # noqa: PLR0913 + self, + api_key: str, + model: str = "gemini-3-flash-preview", + num_batches: int = 4, + timeout_seconds: float = 5.0, + score_threshold: float = 5.0, + few_shot_examples: str | None = None, + ) -> None: + """ + Initialize the Gemini pointwise reranker. + + Args: + api_key: Google API key for accessing Gemini models. + model: Gemini model identifier (e.g., "gemini-2.0-flash"). + num_batches: Number of parallel batches for scoring. + timeout_seconds: Timeout per batch in seconds. + score_threshold: Minimum score (0-10) to include a passage. + few_shot_examples: Optional custom few-shot examples for calibration. + + """ + self.model = model + self.num_batches = num_batches + self.timeout_seconds = timeout_seconds + self.score_threshold = score_threshold + self.few_shot_examples = few_shot_examples + self.client = genai.Client(api_key=api_key) + self._system_prompt = build_system_prompt(few_shot_examples) + + @override + async def rerank( + self, + query: str, + candidates: list[SemanticSearchResult], + top_k: int | None = None, + ) -> list[SemanticSearchResult]: + """ + Rerank candidate passages using parallel batched LLM scoring. + + Args: + query: The user query to score passages against. + candidates: List of SemanticSearchResult from initial retrieval. + top_k: Optional limit on number of results to return. + + Returns: + Reranked list of SemanticSearchResult sorted by relevance score. 
+ + """ + if not candidates: + return [] + + logger.debug( + "Starting rerank", + num_candidates=len(candidates), + num_batches=self.num_batches, + top_k=top_k, + ) + + batches = self._distribute_to_batches(candidates, self.num_batches) + + tasks = [self._score_batch(query, batch) for batch in batches if batch] + + try: + results = await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), + timeout=self.timeout_seconds * 2, + ) + except TimeoutError: + logger.warning( + "Reranker timeout, using fallback", + timeout_seconds=self.timeout_seconds, + ) + return self._fallback_rerank(candidates, top_k) + + all_scores: dict[int, float] = {} + for result in results: + if isinstance(result, dict): + all_scores.update(result) + elif isinstance(result, Exception): + logger.warning("Batch failed", error=str(result)) + + if not all_scores: + logger.warning("No scores obtained from LLM, using fallback") + return self._fallback_rerank(candidates, top_k) + + reranked: list[SemanticSearchResult] = [] + for idx, candidate in enumerate(candidates): + score = all_scores.get(idx) + if score is not None and score >= self.score_threshold: + updated_metadata = dict(candidate.metadata) + updated_metadata["original_score"] = str(candidate.score) + updated_metadata["reranker_score"] = str(score) + + reranked.append( + SemanticSearchResult( + text=candidate.text, + score=score / 10.0, + metadata=updated_metadata, + ) + ) + + reranked.sort(key=lambda x: x.score, reverse=True) + + logger.debug( + "Rerank complete", + input_count=len(candidates), + output_count=len(reranked), + filtered_count=len(candidates) - len(reranked), + ) + + return reranked[:top_k] if top_k else reranked + + def _distribute_to_batches( + self, + candidates: list[SemanticSearchResult], + num_batches: int, + ) -> list[list[tuple[int, SemanticSearchResult]]]: + """ + Distribute candidates to batches using round-robin assignment. + + This ensures each batch sees a mix of high/medium/low similarity + passages, reducing positional bias from vector search ordering. + + Args: + candidates: List of candidate passages. + num_batches: Number of batches to distribute to. + + Returns: + List of batches, each containing (original_index, candidate) tuples. + + """ + batches: list[list[tuple[int, SemanticSearchResult]]] = [ + [] for _ in range(num_batches) + ] + for idx, candidate in enumerate(candidates): + batch_idx = idx % num_batches + batches[batch_idx].append((idx, candidate)) + return batches + + async def _score_batch( + self, + query: str, + batch: list[tuple[int, SemanticSearchResult]], + ) -> dict[int, float]: + """ + Score a single batch of passages via LLM. + + Args: + query: The user query. + batch: List of (original_index, candidate) tuples. + + Returns: + Dictionary mapping original indices to scores. + + Raises: + RerankerError: If the LLM call fails. + + """ + # Prompt IDs are batch-local (id0..idN); map back to original indices below. 
diff --git a/src/flare_ai_kit/rag/vector/reranker/prompts.py b/src/flare_ai_kit/rag/vector/reranker/prompts.py
new file mode 100644
index 00000000..440635a8
--- /dev/null
+++ b/src/flare_ai_kit/rag/vector/reranker/prompts.py
@@ -0,0 +1,139 @@
+"""Prompt templates for LLM-based pointwise reranking."""
+# ruff: noqa: E501
+# Long lines are intentional in prompts for readability
+
+POINTWISE_SYSTEM_PROMPT = """You are a customer support answer service. Your task is to evaluate help center passages and score their relevance to a given customer query for a retrieval augmented generation (RAG) system.
+
+Evaluation Process:
+1. Analyze the customer's query to identify both explicit needs and implicit context including underlying user goals
+2. Assess each passage's ability to directly resolve the query or provide substantive supporting information with actionable guidance
+3. Score based on how effectively the passage addresses the query's core intent while considering potential interpretations
+
+Grading Criteria:
+
+10: EXCEPTIONAL match - Contains exact step-by-step instructions that perfectly match the query's specific scenario. Must include all required parameters/context and resolve the issue completely without any ambiguity. Reserved for definitive solutions that exactly mirror the user's described situation and require no interpretation.
+
+9: NEAR-PERFECT solution - Contains all critical steps for resolution but may lack one minor non-essential detail. Addresses the precise query parameters with specialized information. Solution must be directly applicable without requiring adaptation or assumptions.
+
+8: STRONG MATCH - Provides complete technical resolution through specific instructions, but may require simple logical inferences for full application. Covers all essential components but might need minor contextualization.
+
+7: GOOD MATCH - Contains substantial relevant details that address core aspects of the query, but lacks one important element for complete resolution. Provides concrete guidance requiring some user interpretation.
+
+6: PARTIAL match - General guidance on the right topic but lacks the specifics for direct application. May only resolve a subset of the request.
+
+5: LIMITED relevance - Related context or approach, but indirect. Requires substantial effort to adapt to the user's exact need.
+
+4: TANGENTIAL - Mentions related concepts/keywords with little practical connection to the request. Minimal actionable value.
+
+3: VAGUE domain info - Talks about the general area but not the query's specifics. No concrete, actionable steps.
+
+2: TOKEN overlap - Shares isolated terms without context or intent aligned to the request. Similarity is coincidental.
+
+1: IRRELEVANT - Uses query terms in a completely unrelated way. No meaningful link to the user's goal.
+
+0: UNRELATED - No thematic or contextual connection to the query at all.
+
+
+Input Format:
+
+<query>
+// The customer's question or request
+</query>
+
+<passages>
+<passage id="id0">...</passage>
+<passage id="id1">...</passage>
+<passage id="id2">...</passage>
+</passages>
+
+
+Output Format:
+
+Return your response as valid JSON (no spaces):
+{{"id0":score0,"id1":score1,...}}
+
+Strict guidelines:
+- Return ONLY a well-formed valid JSON with passage IDs as keys
+- Each key must be a passage id in the format "idN"
+- Each score must be an integer between 5 and 10. EXCLUDE passages that score below 5 (i.e. 0, 1, 2, 3 or 4)
+- Integer values only, no decimals
+- Omit spaces in the JSON
+- No additional text or formatting
+- Maintain original passage ID order
+- Note: If NO passages score 5+, return an empty JSON object
+
+
+{few_shot_section}"""
+
+POINTWISE_USER_TEMPLATE = """<query>{query}</query>
+
+<passages>
+{passages}
+</passages>
+"""
+
+DEFAULT_FEW_SHOT_EXAMPLES = """
+Example 1:
+<query>How do I reset my password?</query>
+<passages>
+<passage id="id0">To reset your password, click 'Forgot Password' on the login page, enter your email, and follow the link sent to your inbox.</passage>
+<passage id="id1">Our company was founded in 2010 and has offices worldwide.</passage>
+<passage id="id2">Password requirements include at least 8 characters with one uppercase letter.</passage>
+</passages>
+Output: {"id0":10,"id2":6}
+
+Example 2:
+<query>What are your business hours?</query>
+<passages>
+<passage id="id0">Contact our sales team for enterprise pricing options.</passage>
+<passage id="id1">We are open Monday to Friday, 9 AM to 5 PM EST. Weekend support is available via email only.</passage>
+</passages>
+Output: {"id1":10}
+"""
+
+
+def format_passages(passages: list[tuple[int, str]]) -> str:
+    """
+    Format passages for inclusion in the reranker prompt.
+
+    Args:
+        passages: List of tuples containing (passage_id, passage_text).
+
+    Returns:
+        Formatted string with passages in XML-like tags.
+
+    """
+    return "\n".join(
+        f'<passage id="id{idx}">{text}</passage>' for idx, text in passages
+    )
+
+
+def build_system_prompt(few_shot_examples: str | None = None) -> str:
+    """
+    Build the system prompt with optional custom few-shot examples.
+
+    Args:
+        few_shot_examples: Optional custom few-shot examples. If None,
+            uses DEFAULT_FEW_SHOT_EXAMPLES.
+
+    Returns:
+        Complete system prompt with few-shot examples.
+
+    """
+    examples = few_shot_examples or DEFAULT_FEW_SHOT_EXAMPLES
+    few_shot_section = f"\n{examples}" if examples else ""
+    return POINTWISE_SYSTEM_PROMPT.format(few_shot_section=few_shot_section)
+
+
+def build_user_prompt(query: str, passages: list[tuple[int, str]]) -> str:
+    """
+    Build the user prompt with query and passages.
+
+    Args:
+        query: The user query.
+        passages: List of tuples containing (passage_id, passage_text).
+
+    Returns:
+        Formatted user prompt.
+
+    """
+    return POINTWISE_USER_TEMPLATE.format(
+        query=query, passages=format_passages(passages)
+    )
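To make the wire format concrete, this is what build_user_prompt renders for a two-passage call, given the tag layout defined by the templates above (passage texts are made up):

from flare_ai_kit.rag.vector.reranker.prompts import build_user_prompt

prompt = build_user_prompt(
    "How do I reset my password?",
    [
        (0, "Click 'Forgot Password' on the login page."),
        (1, "We were founded in 2010."),
    ],
)
print(prompt)
# <query>How do I reset my password?</query>
#
# <passages>
# <passage id="id0">Click 'Forgot Password' on the login page.</passage>
# <passage id="id1">We were founded in 2010.</passage>
# </passages>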
+ + """ + logger.info( + "Using fallback reranking (original scores)", + num_candidates=len(candidates), + ) + sorted_candidates = sorted(candidates, key=lambda x: x.score, reverse=True) + return sorted_candidates[:top_k] if top_k else sorted_candidates diff --git a/src/flare_ai_kit/rag/vector/reranker/prompts.py b/src/flare_ai_kit/rag/vector/reranker/prompts.py new file mode 100644 index 00000000..440635a8 --- /dev/null +++ b/src/flare_ai_kit/rag/vector/reranker/prompts.py @@ -0,0 +1,139 @@ +"""Prompt templates for LLM-based pointwise reranking.""" +# ruff: noqa: E501 +# Long lines are intentional in prompts for readability + +POINTWISE_SYSTEM_PROMPT = """You are a customer support answer service. Your task is to evaluate help center passages and score their relevance to a given customer query for a retrieval augmented generation (RAG) system. + +Evaluation Process: +1. Analyze the customer's query to identify both explicit needs and implicit context including underlying user goals +2. Assess each passage's ability to directly resolve the query or provide substantive supporting information with actionable guidance +3. Score based on how effectively the passage addresses the query's core intent while considering potential interpretations + +Grading Criteria: + +10: EXCEPTIONAL match - Contains exact step-by-step instructions that perfectly match the query's specific scenario. Must include all required parameters/context and resolve the issue completely without any ambiguity. Reserved for definitive solutions that exactly mirror the user's described situation and require no interpretation. + +9: NEAR-PERFECT solution - Contains all critical steps for resolution but may lack one minor non-essential detail. Addresses the precise query parameters with specialized information. Solution must be directly applicable without requiring adaptation or assumptions. + +8: STRONG MATCH - Provides complete technical resolution through specific instructions, but may require simple logical inferences for full application. Covers all essential components but might need minor contextualization. + +7: GOOD MATCH - Contains substantial relevant details that address core aspects of the query, but lacks one important element for complete resolution. Provides concrete guidance requiring some user interpretation. + +6: PARTIAL match - General guidance on the right topic but lacks the specifics for direct application. May only resolve a subset of the request. + +5: LIMITED relevance - Related context or approach, but indirect. Requires substantial effort to adapt to the user's exact need. + +4: TANGENTIAL - Mentions related concepts/keywords with little practical connection to the request. Minimal actionable value. + +3: VAGUE domain info - Talks about the general area but not the query's specifics. No concrete, actionable steps. + +2: TOKEN overlap - Shares isolated terms without context or intent aligned to the request. Similarity is coincidental. + +1: IRRELEVANT - Uses query terms in a completely unrelated way. No meaningful link to the user's goal. + +0: UNRELATED - No thematic or contextual connection to the query at all. + + +Input Format: + + +// The customer's question or request + + +... +... +... + + + +Output Format: + +Return your response in a valid JSON (skip spaces): +{{"id0":score0,"id1":score1,...}} + +Strict guidelines: +- Return ONLY a well-formed valid JSON with passage IDs as keys +- Each key must be a passage id in the format "idN" +- Each score must be an integer between 5 to 10. 
EXCLUDE passages that score below 5 (i.e. 0, 1, 2, 3 or 4) +- Integer values only, no decimals +- Skip spaces in the JSON +- No additional text or formatting +- Maintain original passage ID order +- Note: If NO passages score 5+, return empty JSON object + + +{few_shot_section}""" + +POINTWISE_USER_TEMPLATE = """{query} + +{passages} +""" + +DEFAULT_FEW_SHOT_EXAMPLES = """ +Example 1: +How do I reset my password? + +To reset your password, click 'Forgot Password' on the login page, enter your email, and follow the link sent to your inbox. +Our company was founded in 2010 and has offices worldwide. +Password requirements include at least 8 characters with one uppercase letter. + +Output: {"id0":10,"id2":6} + +Example 2: +What are your business hours? + +Contact our sales team for enterprise pricing options. +We are open Monday to Friday, 9 AM to 5 PM EST. Weekend support is available via email only. + +Output: {"id1":10} +""" + + +def format_passages(passages: list[tuple[int, str]]) -> str: + """ + Format passages for inclusion in the reranker prompt. + + Args: + passages: List of tuples containing (passage_id, passage_text). + + Returns: + Formatted string with passages in XML-like tags. + + """ + return "\n".join( + f"{text}" for idx, text in passages + ) + + +def build_system_prompt(few_shot_examples: str | None = None) -> str: + """ + Build the system prompt with optional custom few-shot examples. + + Args: + few_shot_examples: Optional custom few-shot examples. If None, + uses DEFAULT_FEW_SHOT_EXAMPLES. + + Returns: + Complete system prompt with few-shot examples. + + """ + examples = few_shot_examples or DEFAULT_FEW_SHOT_EXAMPLES + few_shot_section = f"\n{examples}" if examples else "" + return POINTWISE_SYSTEM_PROMPT.format(few_shot_section=few_shot_section) + + +def build_user_prompt(query: str, passages: list[tuple[int, str]]) -> str: + """ + Build the user prompt with query and passages. + + Args: + query: The user query. + passages: List of tuples containing (passage_id, passage_text). + + Returns: + Formatted user prompt. + + """ + return POINTWISE_USER_TEMPLATE.format( + query=query, passages=format_passages(passages) + ) diff --git a/src/flare_ai_kit/rag/vector/settings.py b/src/flare_ai_kit/rag/vector/settings.py index 25a26763..8d2ba6b4 100644 --- a/src/flare_ai_kit/rag/vector/settings.py +++ b/src/flare_ai_kit/rag/vector/settings.py @@ -95,3 +95,41 @@ class VectorDbSettings(BaseSettings): postgres_dsn: PostgresDsn | None = Field( default=None, description="DSN for PostgreSQL connection string." ) + + +class RerankerSettings(BaseSettings): + """Configuration for LLM-based reranker.""" + + model_config = SettingsConfigDict( + env_prefix="RERANKER__", + env_file=".env", + extra="ignore", + ) + + enabled: bool = Field( + default=False, + description="Enable LLM-based reranking after retrieval.", + ) + model: str = Field( + default="gemini-3-flash-preview", + description="Gemini model for reranking.", + ) + num_batches: PositiveInt = Field( + default=4, + description="Number of parallel batches for scoring passages.", + ) + timeout_seconds: float = Field( + default=5.0, + gt=0.0, + description="Timeout per batch in seconds.", + ) + score_threshold: float = Field( + default=5.0, + ge=0.0, + le=10.0, + description="Minimum score (0-10) to include a passage in results.", + ) + few_shot_examples: str | None = Field( + default=None, + description="Custom few-shot examples for calibration. 
If None, uses defaults.", + ) diff --git a/tests/unit/rag/__init__.py b/tests/unit/rag/__init__.py new file mode 100644 index 00000000..0bd17669 --- /dev/null +++ b/tests/unit/rag/__init__.py @@ -0,0 +1 @@ +"""Unit tests for RAG module.""" diff --git a/tests/unit/rag/vector/__init__.py b/tests/unit/rag/vector/__init__.py new file mode 100644 index 00000000..3ac46559 --- /dev/null +++ b/tests/unit/rag/vector/__init__.py @@ -0,0 +1 @@ +"""Unit tests for vector RAG module.""" diff --git a/tests/unit/rag/vector/reranker/__init__.py b/tests/unit/rag/vector/reranker/__init__.py new file mode 100644 index 00000000..4a1b78bd --- /dev/null +++ b/tests/unit/rag/vector/reranker/__init__.py @@ -0,0 +1 @@ +"""Unit tests for reranker module.""" diff --git a/tests/unit/rag/vector/reranker/test_gemini_reranker.py b/tests/unit/rag/vector/reranker/test_gemini_reranker.py new file mode 100644 index 00000000..0ae44130 --- /dev/null +++ b/tests/unit/rag/vector/reranker/test_gemini_reranker.py @@ -0,0 +1,330 @@ +"""Unit tests for GeminiPointwiseReranker.""" + +from unittest.mock import patch + +import pytest + +from flare_ai_kit.common import RerankerParseError, SemanticSearchResult +from flare_ai_kit.rag.vector.reranker import GeminiPointwiseReranker + + +@pytest.fixture +def mock_candidates() -> list[SemanticSearchResult]: + """Create mock search candidates.""" + return [ + SemanticSearchResult( + text=f"Passage {i} about topic", + score=0.9 - i * 0.1, + metadata={"file_name": f"doc{i}.txt"}, + ) + for i in range(10) + ] + + +@pytest.fixture +def reranker() -> GeminiPointwiseReranker: + """Create reranker with mocked client.""" + with patch("flare_ai_kit.rag.vector.reranker.gemini_reranker.genai"): + return GeminiPointwiseReranker( + api_key="test-api-key", + model="gemini-2.0-flash", + num_batches=4, + timeout_seconds=5.0, + score_threshold=5.0, + ) + + +class TestDistributeToBatches: + """Tests for _distribute_to_batches method.""" + + def test_round_robin_distribution( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test that candidates are distributed round-robin across batches.""" + batches = reranker._distribute_to_batches(mock_candidates, 4) + + assert len(batches) == 4 + + # Batch 0 should have indices 0, 4, 8 + assert [idx for idx, _ in batches[0]] == [0, 4, 8] + # Batch 1 should have indices 1, 5, 9 + assert [idx for idx, _ in batches[1]] == [1, 5, 9] + # Batch 2 should have indices 2, 6 + assert [idx for idx, _ in batches[2]] == [2, 6] + # Batch 3 should have indices 3, 7 + assert [idx for idx, _ in batches[3]] == [3, 7] + + def test_single_batch( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test distribution with single batch.""" + batches = reranker._distribute_to_batches(mock_candidates, 1) + + assert len(batches) == 1 + assert len(batches[0]) == 10 + assert [idx for idx, _ in batches[0]] == list(range(10)) + + def test_more_batches_than_candidates( + self, reranker: GeminiPointwiseReranker + ) -> None: + """Test when there are more batches than candidates.""" + candidates = [ + SemanticSearchResult(text="p1", score=0.9, metadata={}), + SemanticSearchResult(text="p2", score=0.8, metadata={}), + ] + batches = reranker._distribute_to_batches(candidates, 4) + + assert len(batches) == 4 + # Only first 2 batches should have candidates + assert len(batches[0]) == 1 + assert len(batches[1]) == 1 + assert len(batches[2]) == 0 + assert len(batches[3]) == 0 + + +class 
TestParseScores: + """Tests for _parse_scores method.""" + + def test_valid_json(self, reranker: GeminiPointwiseReranker) -> None: + """Test parsing valid JSON response.""" + response = '{"id0":8,"id1":10,"id2":3}' + scores = reranker._parse_scores(response) + + assert scores == {0: 8.0, 1: 10.0, 2: 3.0} + + def test_json_with_whitespace(self, reranker: GeminiPointwiseReranker) -> None: + """Test parsing JSON with whitespace.""" + response = ' { "id0": 8, "id1": 10 } ' + scores = reranker._parse_scores(response) + + assert scores == {0: 8.0, 1: 10.0} + + def test_json_wrapped_in_text(self, reranker: GeminiPointwiseReranker) -> None: + """Test extracting JSON from text response.""" + response = 'Here are the scores: {"id0":7,"id1":9} based on my analysis.' + scores = reranker._parse_scores(response) + + assert scores == {0: 7.0, 1: 9.0} + + def test_empty_json(self, reranker: GeminiPointwiseReranker) -> None: + """Test parsing empty JSON (no passages scored above threshold).""" + response = "{}" + scores = reranker._parse_scores(response) + + assert scores == {} + + def test_invalid_json(self, reranker: GeminiPointwiseReranker) -> None: + """Test that invalid JSON raises RerankerParseError.""" + response = "not valid json" + with pytest.raises(RerankerParseError): + reranker._parse_scores(response) + + def test_invalid_key_format(self, reranker: GeminiPointwiseReranker) -> None: + """Test that invalid keys are skipped.""" + response = '{"id0":8,"invalid_key":10,"id1":7}' + scores = reranker._parse_scores(response) + + # Only valid id keys should be included + assert scores == {0: 8.0, 1: 7.0} + + def test_score_out_of_range(self, reranker: GeminiPointwiseReranker) -> None: + """Test that out-of-range scores are skipped.""" + response = '{"id0":8,"id1":15,"id2":-1}' + scores = reranker._parse_scores(response) + + # Only valid scores (0-10) should be included + assert scores == {0: 8.0} + + +class TestFallbackRerank: + """Tests for _fallback_rerank method.""" + + def test_returns_sorted_by_original_score( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test that fallback returns candidates sorted by original score.""" + # Shuffle the order + shuffled = mock_candidates[5:] + mock_candidates[:5] + result = reranker._fallback_rerank(shuffled, top_k=5) + + assert len(result) == 5 + # Should be sorted by score descending + for i in range(len(result) - 1): + assert result[i].score >= result[i + 1].score + + def test_respects_top_k( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test that fallback respects top_k limit.""" + result = reranker._fallback_rerank(mock_candidates, top_k=3) + assert len(result) == 3 + + def test_returns_all_when_top_k_none( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test that fallback returns all candidates when top_k is None.""" + result = reranker._fallback_rerank(mock_candidates, top_k=None) + assert len(result) == len(mock_candidates) + + +class TestRerank: + """Tests for the main rerank method.""" + + @pytest.mark.asyncio + async def test_empty_candidates(self, reranker: GeminiPointwiseReranker) -> None: + """Test rerank with empty candidates list.""" + result = await reranker.rerank("query", [], top_k=5) + assert result == [] + + @pytest.mark.asyncio + async def test_rerank_filters_by_threshold( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: 
list[SemanticSearchResult], + ) -> None: + """Test that rerank filters results below score threshold.""" + + # Mock the _score_batch to return mixed scores + async def mock_score_batch(query, batch): + return { + idx: 8.0 if idx < 5 else 3.0 # First 5 pass threshold, rest don't + for idx, _ in batch + } + + with patch.object(reranker, "_score_batch", side_effect=mock_score_batch): + result = await reranker.rerank("test query", mock_candidates, top_k=10) + + # Only candidates with score >= 5.0 should be included + assert len(result) == 5 + for res in result: + assert float(res.metadata.get("reranker_score", 0)) >= 5.0 + + @pytest.mark.asyncio + async def test_rerank_preserves_metadata( + self, reranker: GeminiPointwiseReranker + ) -> None: + """Test that rerank preserves original metadata and adds new fields.""" + candidates = [ + SemanticSearchResult( + text="Test passage", + score=0.9, + metadata={"file_name": "test.txt", "chunk_id": "1"}, + ) + ] + + async def mock_score_batch(query, batch): + return {0: 8.0} + + with patch.object(reranker, "_score_batch", side_effect=mock_score_batch): + result = await reranker.rerank("query", candidates, top_k=1) + + assert len(result) == 1 + assert result[0].metadata["file_name"] == "test.txt" + assert result[0].metadata["chunk_id"] == "1" + assert result[0].metadata["original_score"] == "0.9" + assert result[0].metadata["reranker_score"] == "8.0" + + @pytest.mark.asyncio + async def test_rerank_uses_fallback_on_timeout( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Test that rerank falls back to original scores on timeout.""" + + # Make _score_batch raise TimeoutError + async def mock_timeout(*args): + raise TimeoutError + + with patch.object(reranker, "_score_batch", side_effect=mock_timeout): + result = await reranker.rerank("query", mock_candidates, top_k=5) + + # Should return top-5 by original score + assert len(result) == 5 + # First result should have highest original score + assert result[0].score == mock_candidates[0].score + + @pytest.mark.asyncio + async def test_score_batch_maps_local_ids_to_original_indices( + self, + reranker: GeminiPointwiseReranker, + ) -> None: + """Ensure batch-local ids map back to the original candidate indices.""" + batch = [ + (2, SemanticSearchResult(text="p2", score=0.9, metadata={})), + (6, SemanticSearchResult(text="p6", score=0.8, metadata={})), + ] + + with patch.object(reranker, "_call_gemini", return_value='{"id0":9,"id1":6}'): + scores = await reranker._score_batch("query", batch) + + assert scores == {2: 9.0, 6: 6.0} + + @pytest.mark.asyncio + async def test_rerank_runs_batches_in_parallel( + self, + reranker: GeminiPointwiseReranker, + mock_candidates: list[SemanticSearchResult], + ) -> None: + """Catch accidental serial execution by asserting on wall-clock time.""" + reranker.num_batches = 4 + + async def slow_score_batch(_query, batch): + import asyncio + + await asyncio.sleep(0.2) + return {idx: 8.0 for idx, _ in batch} + + import time + + with patch.object(reranker, "_score_batch", side_effect=slow_score_batch): + start = time.perf_counter() + await reranker.rerank("query", mock_candidates, top_k=10) + elapsed = time.perf_counter() - start + + assert elapsed < 0.6 + + +class TestSettings: + """Tests for RerankerSettings.""" + + def test_default_settings(self) -> None: + """Test default reranker settings.""" + from flare_ai_kit.rag.vector.settings import RerankerSettings + + settings = RerankerSettings() + + assert settings.enabled is 
False + assert settings.model == "gemini-3-flash-preview" + assert settings.num_batches == 4 + assert settings.timeout_seconds == 5.0 + assert settings.score_threshold == 5.0 + assert settings.few_shot_examples is None + + def test_custom_settings(self) -> None: + """Test custom reranker settings.""" + from flare_ai_kit.rag.vector.settings import RerankerSettings + + settings = RerankerSettings( + enabled=True, + model="gemini-2.0-flash", # Override with older model + num_batches=8, + timeout_seconds=10.0, + score_threshold=7.0, + ) + + assert settings.enabled is True + assert settings.model == "gemini-2.0-flash" + assert settings.num_batches == 8 + assert settings.timeout_seconds == 10.0 + assert settings.score_threshold == 7.0
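Since RerankerSettings reads the RERANKER__ environment prefix (standard pydantic-settings behavior via env_prefix), the reranker can also be enabled without code changes; a sketch of the expected variables:

# .env (picked up through env_file=".env"):
#   RERANKER__ENABLED=true
#   RERANKER__MODEL=gemini-3-flash-preview
#   RERANKER__NUM_BATCHES=4
#   RERANKER__TIMEOUT_SECONDS=5.0
#   RERANKER__SCORE_THRESHOLD=5.0

from flare_ai_kit.rag.vector.settings import RerankerSettings

settings = RerankerSettings()  # reads the RERANKER__* variables above
assert settings.enabled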