diff --git a/examples/08_rag_reranker_demo.py b/examples/08_rag_reranker_demo.py
new file mode 100644
index 00000000..42a9831c
--- /dev/null
+++ b/examples/08_rag_reranker_demo.py
@@ -0,0 +1,233 @@
+# pyright: reportMissingImports=false
+# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
+"""
+Demo: RAG with LLM-based Reranking
+
+Demonstrates why reranking helps: compare vector-only retrieval vs LLM reranking on an
+adversarial doc set (high keyword overlap but low true relevance).
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+from typing import Final
+from uuid import uuid4
+
+from dotenv import load_dotenv
+
+from qdrant_client import QdrantClient
+
+from flare_ai_kit.agent import AgentSettings
+from flare_ai_kit.common import SemanticSearchResult
+from flare_ai_kit.rag.vector.api import (
+ FixedSizeChunker,
+ GeminiEmbedding,
+ GeminiPointwiseReranker,
+ LocalFileIndexer,
+ QdrantRetriever,
+ ingest_and_embed,
+ upsert_to_qdrant,
+)
+from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings
+
+
+def _print_results(
+ title: str,
+ results: list[SemanticSearchResult],
+ *,
+ show_reranker_scores: bool = False,
+ max_text_chars: int = 96,
+) -> None:
+ print(f"\n{title} (top {len(results)}):\n")
+ for i, res in enumerate(results, 1):
+ source = res.metadata.get("file_name", "unknown")
+ if show_reranker_scores:
+ original_score = res.metadata.get("original_score", "N/A")
+ reranker_score = res.metadata.get("reranker_score", "N/A")
+ scores = f"orig={original_score}, rerank={reranker_score}/10"
+ else:
+ scores = f"vector={res.score:.3f}"
+ preview = res.text.replace("\n", " ")[:max_text_chars]
+ print(f" {i:>2}. [{source}] ({scores})")
+ print(f" {preview}...")
+
+
+def _hit_rate(results: list[SemanticSearchResult], relevant_sources: set[str]) -> int:
+    """Count how many results come from the known-relevant source files."""
+    return sum(1 for r in results if r.metadata.get("file_name") in relevant_sources)
+
+
+async def main() -> None: # noqa: PLR0915
+ """Run the RAG + reranking demo."""
+ load_dotenv()
+ agent = AgentSettings() # pyright: ignore[reportCallIssue]
+ vector_db = VectorDbSettings(qdrant_batch_size=8)
+ reranker_settings = RerankerSettings(enabled=True, score_threshold=5.0)
+
+ if agent.gemini_api_key is None:
+ print("GEMINI_API_KEY environment variable not set.")
+ print("Please set the GEMINI_API_KEY environment variable to run this demo.")
+ print("Example: export GEMINI_API_KEY='your_api_key_here'")
+ sys.exit(1)
+
+ demo_dir = Path("demo_rerank_data")
+ demo_dir.mkdir(parents=True, exist_ok=True)
+
+    # Intentionally include "keyword trap" docs that mention the right terms
+    # but are not helpful.
+ documents: Final[list[tuple[str, str]]] = [
+ (
+ "how_to_reset_password.txt",
+ "Password reset steps: On the login page, click 'Forgot Password'. "
+ "Enter your email. Check your inbox for the reset link (valid for 24 hours). "
+ "If the link expired, request a new one. If you still can't reset, contact support.",
+ ),
+ (
+ "reset_link_expired.txt",
+ "If your password reset link has expired: return to the login page, click 'Forgot Password', "
+ "and request a new link. For security, links expire after 24 hours and can only be used once.",
+ ),
+ (
+ "account_lockout_policy.txt",
+ "Account lockout: after 5 failed login attempts, your account is locked for 15 minutes. "
+ "Password reset does not unlock the account immediately. Wait 15 minutes or contact support.",
+ ),
+ (
+ "billing_address_change.txt",
+ "To change your billing address: go to Settings → Billing → Address, update fields, and Save. "
+ "This does not change your login email or password.",
+ ),
+ (
+ "two_factor_setup.txt",
+ "Enable two-factor authentication (2FA): Settings → Security → Two-Factor. "
+ "Scan the QR code in your authenticator app and enter the 6-digit code to confirm.",
+ ),
+ (
+ "keyword_trap_marketing_blog.txt",
+ "Password reset, reset password, forgot password, reset link, password reset email — "
+ "we love talking about password reset in our marketing. This post resets expectations. "
+ "It does not contain support steps, but it repeats password reset keywords many times. "
+ "Password reset. Reset password. Forgot password. Reset link.",
+ ),
+ (
+ "keyword_trap_shipping_reset.txt",
+ "If your package tracking looks wrong, reset the tracking view in the shipping portal. "
+ "This is unrelated to accounts, login, or password reset. It mentions reset and link and email.",
+ ),
+ (
+ "password_requirements.txt",
+ "Password requirements: 8+ chars, one uppercase, one lowercase, one number. "
+ "If your password reset fails, ensure the new password meets requirements.",
+ ),
+ (
+ "sso_login_note.txt",
+ "If you signed up with Google/SSO, you may not have a password to reset. "
+ "Use 'Continue with Google' or contact support to convert your account to email/password login.",
+ ),
+ ]
+
+ try:
+ for filename, content in documents:
+ (demo_dir / filename).write_text(content, encoding="utf-8")
+
+ print(f"Created {len(documents)} sample documents in {demo_dir}/\n")
+
+ # Use larger chunks so each doc stays mostly intact for clearer ranking diffs.
+ chunker = FixedSizeChunker(chunk_size=256)
+ indexer = LocalFileIndexer(
+ root_dir=str(demo_dir), chunker=chunker, allowed_extensions={".txt"}
+ )
+
+ embedding_model = GeminiEmbedding(
+ api_key=agent.gemini_api_key.get_secret_value(),
+ model=vector_db.embeddings_model,
+ output_dimensionality=vector_db.embeddings_output_dimensionality,
+ )
+
+ data = ingest_and_embed(indexer, embedding_model, batch_size=8)
+ print(f"Ingested and embedded {len(data)} chunks.\n")
+
+ collection_name = f"reranker-demo-{uuid4().hex}"
+        # Must match the embedding output dimensionality configured above.
+        vector_size = vector_db.embeddings_output_dimensionality or 768
+ upsert_to_qdrant(
+ data,
+ str(vector_db.qdrant_url),
+ collection_name,
+ vector_size,
+ batch_size=vector_db.qdrant_batch_size,
+ )
+ print(
+ f"Upserted {len(data)} vectors to Qdrant collection '{collection_name}'.\n"
+ )
+
+ client = QdrantClient(str(vector_db.qdrant_url))
+ retriever = QdrantRetriever(client, embedding_model, vector_db)
+
+ reranker = GeminiPointwiseReranker(
+ api_key=agent.gemini_api_key.get_secret_value(),
+ model=reranker_settings.model,
+ num_batches=reranker_settings.num_batches,
+ timeout_seconds=reranker_settings.timeout_seconds,
+ score_threshold=reranker_settings.score_threshold,
+ )
+
+ scenarios: Final[list[tuple[str, str, set[str]]]] = [
+ (
+ "Reset password (needs actionable steps; keyword traps present)",
+ "My password reset link expired. How do I get a new one?",
+ {"how_to_reset_password.txt", "reset_link_expired.txt"},
+ ),
+ (
+ "SSO edge case (vector overlap is misleading)",
+ "I signed up with Google SSO. Why doesn't 'Forgot Password' work?",
+ {"sso_login_note.txt"},
+ ),
+ (
+ "Security settings (should prefer the exact setup steps)",
+ "How do I enable two-factor authentication?",
+ {"two_factor_setup.txt"},
+ ),
+ ]
+
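+        # Retrieve a wider candidate pool, then rerank and keep a smaller top-k.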
+ top_k_retrieve = 8
+ top_k_show = 5
+ for title, query, relevant_sources in scenarios:
+ print("\n" + "=" * 80)
+ print(f"{title}\nQuery: {query}")
+
+ candidates = retriever.semantic_search(
+ query, collection_name, top_k=top_k_retrieve
+ )
+ baseline = candidates[:top_k_show]
+
+ _print_results("Vector-only results", baseline, show_reranker_scores=False)
+
+ print("\nReranking with LLM...\n")
+ reranked = await reranker.rerank(query, candidates, top_k=top_k_show)
+ _print_results(
+ "Reranked results",
+ reranked,
+ show_reranker_scores=True,
+ )
+
+ base_hits = _hit_rate(baseline, relevant_sources)
+ rerank_hits = _hit_rate(reranked, relevant_sources)
+ print("\nSummary:")
+ print(f" - Relevant sources: {sorted(relevant_sources)}")
+ print(
+ f" - Hit@{top_k_show}: vector-only={base_hits}/{top_k_show}, reranked={rerank_hits}/{top_k_show}"
+ )
+
+ print("\n" + "=" * 80)
+ print("Demo complete.")
+ print(
+ f"Notes: reranker score threshold={reranker_settings.score_threshold}, "
+ f"batches={reranker_settings.num_batches}, timeout={reranker_settings.timeout_seconds}s"
+ )
+ finally:
+ for filename, _ in documents:
+ (demo_dir / filename).unlink(missing_ok=True)
+ demo_dir.rmdir()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/flare_ai_kit/common/exceptions.py b/src/flare_ai_kit/common/exceptions.py
index de8fdd5a..88e6bb78 100644
--- a/src/flare_ai_kit/common/exceptions.py
+++ b/src/flare_ai_kit/common/exceptions.py
@@ -138,3 +138,12 @@ class A2AClientError(FlareAIKitError):
# --- PDF Processing Errors ---
class PdfPostingError(FlareAIKitError):
"""Error class concerned with onchain PDF data posting errors."""
+
+
+# --- Reranker Errors ---
+class RerankerError(FlareAIKitError):
+ """Base exception for reranker errors."""
+
+
+class RerankerParseError(RerankerError):
+ """Raised when LLM response cannot be parsed."""
diff --git a/src/flare_ai_kit/rag/vector/api.py b/src/flare_ai_kit/rag/vector/api.py
index 2ea38bb2..c2b647b4 100644
--- a/src/flare_ai_kit/rag/vector/api.py
+++ b/src/flare_ai_kit/rag/vector/api.py
@@ -10,6 +10,7 @@
ingest_and_embed,
upsert_to_qdrant,
)
+from .reranker import BaseReranker, GeminiPointwiseReranker
from .responder import BaseResponder
from .retriever import BaseRetriever, QdrantRetriever
@@ -17,10 +18,12 @@
"BaseChunker",
"BaseEmbedding",
"BaseIndexer",
+ "BaseReranker",
"BaseResponder",
"BaseRetriever",
"FixedSizeChunker",
"GeminiEmbedding",
+ "GeminiPointwiseReranker",
"LocalFileIndexer",
"QdrantRetriever",
"VectorRAGPipeline",
diff --git a/src/flare_ai_kit/rag/vector/factory.py b/src/flare_ai_kit/rag/vector/factory.py
index 1673e728..b9a20cc6 100644
--- a/src/flare_ai_kit/rag/vector/factory.py
+++ b/src/flare_ai_kit/rag/vector/factory.py
@@ -1,20 +1,21 @@
"""Factory functions for creating RAG (Retrieval-Augmented Generation) pipelines."""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import structlog
from qdrant_client import QdrantClient
from flare_ai_kit.agent.settings import AgentSettings
-from flare_ai_kit.common import FlareAIKitError
+from flare_ai_kit.common import FlareAIKitError, SemanticSearchResult
from flare_ai_kit.rag.vector.embedding import GeminiEmbedding
+from flare_ai_kit.rag.vector.reranker import BaseReranker, GeminiPointwiseReranker
from flare_ai_kit.rag.vector.retriever import QdrantRetriever
-from flare_ai_kit.rag.vector.settings import VectorDbSettings
+from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings
logger = structlog.get_logger(__name__)
-@dataclass(frozen=True)
+@dataclass
class VectorRAGPipeline:
"""
A container for the components of a vector-based RAG pipeline.
@@ -25,31 +26,82 @@ class VectorRAGPipeline:
Attributes:
retriever: A retriever instance (e.g., QdrantRetriever) used to
embed chunks, store them, and perform semantic search.
+ reranker: Optional reranker instance for improving retrieval quality.
"""
retriever: QdrantRetriever
+ reranker: BaseReranker | None = field(default=None)
+
+ async def retrieve_and_rerank(
+ self,
+ query: str,
+ collection_name: str,
+ top_k_retrieve: int = 20,
+ top_k_rerank: int = 5,
+ score_threshold: float | None = None,
+ ) -> list[SemanticSearchResult]:
+ """
+ Retrieve candidates and optionally rerank them.
+
+ This is the main entry point for RAG retrieval with reranking.
+ First retrieves top_k_retrieve candidates, then reranks them
+ (if reranker is configured) and returns top_k_rerank results.
+
+ Args:
+ query: The search query string.
+ collection_name: Name of the Qdrant collection to search.
+ top_k_retrieve: Number of candidates to retrieve initially.
+ top_k_rerank: Number of results to return after reranking.
+ score_threshold: Optional minimum similarity score for retrieval.
+
+ Returns:
+ List of SemanticSearchResult, reranked if reranker is configured.
+
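+        Example (illustrative; assumes an ingested Qdrant collection):
+            results = await pipeline.retrieve_and_rerank(
+                query="How do I reset my password?",
+                collection_name="support-docs",
+            )
+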
+ """
+ # Initial retrieval
+ candidates = self.retriever.semantic_search(
+ query=query,
+ collection_name=collection_name,
+ top_k=top_k_retrieve,
+ score_threshold=score_threshold,
+ )
+
+ if not candidates:
+ return []
+
+ # Rerank if configured
+ if self.reranker:
+ return await self.reranker.rerank(query, candidates, top_k=top_k_rerank)
+
+ # No reranker - just return top_k results
+ return candidates[:top_k_rerank]
def create_vector_rag_pipeline(
- vector_db_settings: VectorDbSettings, agent_settings: AgentSettings
+ vector_db_settings: VectorDbSettings,
+ agent_settings: AgentSettings,
+ reranker_settings: RerankerSettings | None = None,
) -> VectorRAGPipeline:
"""
Builds and configures a complete vector RAG pipeline.
This factory function initializes and wires together all the necessary
components for a vector-based RAG system, including the embedding model,
- the vector database client, the data indexer, and the retriever.
+ the vector database client, the data indexer, the retriever, and optionally
+ a reranker for improved retrieval quality.
Args:
vector_db_settings: Configuration specific to the vector database and
data chunking/indexing process.
agent_settings: Configuration for the AI agent, which includes the
necessary API keys for the embedding model.
+ reranker_settings: Optional configuration for LLM-based reranking.
+ If None or disabled, no reranker is initialized.
Returns:
- A `VectorRAGPipeline` object containing the fully configured indexer
- and retriever.
+        A `VectorRAGPipeline` object containing the fully configured
+        retriever and, optionally, a reranker.
Raises:
FlareAIKitError: If essential configuration like the Qdrant URL or
@@ -90,12 +142,32 @@ def create_vector_rag_pipeline(
)
logger.debug("QdrantRetriever initialized.")
+ # 4. Optional Reranker
+ reranker: BaseReranker | None = None
+ if reranker_settings and reranker_settings.enabled:
+ reranker = GeminiPointwiseReranker(
+ api_key=agent_settings.gemini_api_key.get_secret_value(),
+ model=reranker_settings.model,
+ num_batches=reranker_settings.num_batches,
+ timeout_seconds=reranker_settings.timeout_seconds,
+ score_threshold=reranker_settings.score_threshold,
+ few_shot_examples=reranker_settings.few_shot_examples,
+ )
+ logger.debug(
+ "GeminiPointwiseReranker initialized.",
+ model=reranker_settings.model,
+ num_batches=reranker_settings.num_batches,
+ )
+
except Exception as e:
logger.exception("Failed to initialize a component for the RAG pipeline.")
msg = "Could not create vector RAG pipeline."
raise FlareAIKitError(msg) from e
- pipeline = VectorRAGPipeline(retriever=retriever)
- logger.info("Vector RAG pipeline created successfully.")
+ pipeline = VectorRAGPipeline(retriever=retriever, reranker=reranker)
+ logger.info(
+ "Vector RAG pipeline created successfully.",
+ has_reranker=reranker is not None,
+ )
return pipeline
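+
+
+# Illustrative usage sketch (assumes GEMINI_API_KEY and the Qdrant URL are
+# configured via the environment, and the collection is already ingested):
+#
+#     pipeline = create_vector_rag_pipeline(
+#         vector_db_settings=VectorDbSettings(),
+#         agent_settings=AgentSettings(),
+#         reranker_settings=RerankerSettings(enabled=True),
+#     )
+#     results = await pipeline.retrieve_and_rerank(
+#         query="How do I reset my password?",
+#         collection_name="support-docs",
+#     )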
diff --git a/src/flare_ai_kit/rag/vector/reranker/__init__.py b/src/flare_ai_kit/rag/vector/reranker/__init__.py
new file mode 100644
index 00000000..219ef698
--- /dev/null
+++ b/src/flare_ai_kit/rag/vector/reranker/__init__.py
@@ -0,0 +1,9 @@
+"""Reranker module for improving retrieval quality."""
+
+from .base import BaseReranker
+from .gemini_reranker import GeminiPointwiseReranker
+
+__all__ = [
+ "BaseReranker",
+ "GeminiPointwiseReranker",
+]
diff --git a/src/flare_ai_kit/rag/vector/reranker/base.py b/src/flare_ai_kit/rag/vector/reranker/base.py
new file mode 100644
index 00000000..dd4de65c
--- /dev/null
+++ b/src/flare_ai_kit/rag/vector/reranker/base.py
@@ -0,0 +1,37 @@
+"""Base class for reranking retrieved passages."""
+
+from abc import ABC, abstractmethod
+
+from flare_ai_kit.common import SemanticSearchResult
+
+
+class BaseReranker(ABC):
+ """
+ Abstract base class for reranking modules.
+
+ Rerankers take a query and a list of candidate passages from initial
+ retrieval and re-score them based on relevance to the query.
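+
+    Example (illustrative sketch of a minimal subclass):
+        class ScoreSortReranker(BaseReranker):
+            async def rerank(self, query, candidates, top_k=None):
+                ranked = sorted(candidates, key=lambda c: c.score, reverse=True)
+                return ranked[:top_k] if top_k is not None else ranked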
+ """
+
+ @abstractmethod
+ async def rerank(
+ self,
+ query: str,
+ candidates: list[SemanticSearchResult],
+ top_k: int | None = None,
+ ) -> list[SemanticSearchResult]:
+ """
+ Rerank candidate passages based on relevance to the query.
+
+ Args:
+ query: The user query to score passages against.
+ candidates: List of SemanticSearchResult from initial retrieval.
+ top_k: Optional limit on number of results to return.
+ If None, returns all candidates that pass the threshold.
+
+ Returns:
+            Reranked list of SemanticSearchResult sorted by relevance score
+            (highest first). The score field contains the reranker score
+            normalized to the 0-1 range.
+
+ """
diff --git a/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py b/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py
new file mode 100644
index 00000000..c2cf6906
--- /dev/null
+++ b/src/flare_ai_kit/rag/vector/reranker/gemini_reranker.py
@@ -0,0 +1,338 @@
+"""LLM-based pointwise reranker using Gemini API."""
+
+import asyncio
+import json
+import re
+from typing import Any, Final, cast, override
+
+import structlog
+from google import genai # pyright: ignore[reportMissingTypeStubs]
+from google.genai import types # pyright: ignore[reportMissingTypeStubs]
+
+from flare_ai_kit.common import RerankerError, RerankerParseError, SemanticSearchResult
+
+from .base import BaseReranker
+from .prompts import build_system_prompt, build_user_prompt
+
+logger = structlog.get_logger(__name__)
+
+MIN_SCORE: Final[float] = 0.0
+MAX_SCORE: Final[float] = 10.0
+
+
+class GeminiPointwiseReranker(BaseReranker):
+ """
+ LLM-based pointwise reranker using Google's Gemini API.
+
+    This reranker scores each passage on a 0-10 scale based on relevance
+    to the query. It uses parallel batch processing with round-robin
+    distribution for improved latency and accuracy.
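+
+    Example (illustrative; assumes a valid Gemini API key):
+        reranker = GeminiPointwiseReranker(api_key="...")
+        ranked = await reranker.rerank(query, candidates, top_k=5)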
+ """
+
+ def __init__( # noqa: PLR0913
+ self,
+ api_key: str,
+ model: str = "gemini-3-flash-preview",
+ num_batches: int = 4,
+ timeout_seconds: float = 5.0,
+ score_threshold: float = 5.0,
+ few_shot_examples: str | None = None,
+ ) -> None:
+ """
+ Initialize the Gemini pointwise reranker.
+
+ Args:
+ api_key: Google API key for accessing Gemini models.
+            model: Gemini model identifier (e.g., "gemini-3-flash-preview").
+ num_batches: Number of parallel batches for scoring.
+ timeout_seconds: Timeout per batch in seconds.
+ score_threshold: Minimum score (0-10) to include a passage.
+ few_shot_examples: Optional custom few-shot examples for calibration.
+
+ """
+ self.model = model
+ self.num_batches = num_batches
+ self.timeout_seconds = timeout_seconds
+ self.score_threshold = score_threshold
+ self.few_shot_examples = few_shot_examples
+ self.client = genai.Client(api_key=api_key)
+ self._system_prompt = build_system_prompt(few_shot_examples)
+
+ @override
+ async def rerank(
+ self,
+ query: str,
+ candidates: list[SemanticSearchResult],
+ top_k: int | None = None,
+ ) -> list[SemanticSearchResult]:
+ """
+ Rerank candidate passages using parallel batched LLM scoring.
+
+ Args:
+ query: The user query to score passages against.
+ candidates: List of SemanticSearchResult from initial retrieval.
+ top_k: Optional limit on number of results to return.
+
+ Returns:
+ Reranked list of SemanticSearchResult sorted by relevance score.
+
+ """
+ if not candidates:
+ return []
+
+ logger.debug(
+ "Starting rerank",
+ num_candidates=len(candidates),
+ num_batches=self.num_batches,
+ top_k=top_k,
+ )
+
+ batches = self._distribute_to_batches(candidates, self.num_batches)
+
+ tasks = [self._score_batch(query, batch) for batch in batches if batch]
+
+ try:
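+            # Overall cap: roughly double the per-batch timeout so slow batches
+            # can still finish before the gather is abandoned.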
+ results = await asyncio.wait_for(
+ asyncio.gather(*tasks, return_exceptions=True),
+ timeout=self.timeout_seconds * 2,
+ )
+ except TimeoutError:
+ logger.warning(
+ "Reranker timeout, using fallback",
+ timeout_seconds=self.timeout_seconds,
+ )
+ return self._fallback_rerank(candidates, top_k)
+
+ all_scores: dict[int, float] = {}
+ for result in results:
+ if isinstance(result, dict):
+ all_scores.update(result)
+ elif isinstance(result, Exception):
+ logger.warning("Batch failed", error=str(result))
+
+ if not all_scores:
+ logger.warning("No scores obtained from LLM, using fallback")
+ return self._fallback_rerank(candidates, top_k)
+
+ reranked: list[SemanticSearchResult] = []
+ for idx, candidate in enumerate(candidates):
+ score = all_scores.get(idx)
+ if score is not None and score >= self.score_threshold:
+ updated_metadata = dict(candidate.metadata)
+ updated_metadata["original_score"] = str(candidate.score)
+ updated_metadata["reranker_score"] = str(score)
+
+ reranked.append(
+ SemanticSearchResult(
+ text=candidate.text,
+ score=score / 10.0,
+ metadata=updated_metadata,
+ )
+ )
+
+ reranked.sort(key=lambda x: x.score, reverse=True)
+
+ logger.debug(
+ "Rerank complete",
+ input_count=len(candidates),
+ output_count=len(reranked),
+ filtered_count=len(candidates) - len(reranked),
+ )
+
+ return reranked[:top_k] if top_k else reranked
+
+ def _distribute_to_batches(
+ self,
+ candidates: list[SemanticSearchResult],
+ num_batches: int,
+ ) -> list[list[tuple[int, SemanticSearchResult]]]:
+ """
+ Distribute candidates to batches using round-robin assignment.
+
+ This ensures each batch sees a mix of high/medium/low similarity
+ passages, reducing positional bias from vector search ordering.
+
+ Args:
+ candidates: List of candidate passages.
+ num_batches: Number of batches to distribute to.
+
+ Returns:
+ List of batches, each containing (original_index, candidate) tuples.
+
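+        Example:
+            10 candidates across 4 batches yields index groups
+            [0, 4, 8], [1, 5, 9], [2, 6], and [3, 7].
+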
+ """
+ batches: list[list[tuple[int, SemanticSearchResult]]] = [
+ [] for _ in range(num_batches)
+ ]
+ for idx, candidate in enumerate(candidates):
+ batch_idx = idx % num_batches
+ batches[batch_idx].append((idx, candidate))
+ return batches
+
+ async def _score_batch(
+ self,
+ query: str,
+ batch: list[tuple[int, SemanticSearchResult]],
+ ) -> dict[int, float]:
+ """
+ Score a single batch of passages via LLM.
+
+ Args:
+ query: The user query.
+ batch: List of (original_index, candidate) tuples.
+
+ Returns:
+ Dictionary mapping original indices to scores.
+
+ Raises:
+ RerankerError: If the LLM call fails.
+
+ """
+ # Prompt IDs are batch-local (id0..idN); map back to original indices below.
+ passages = [
+ (local_idx, candidate.text)
+ for local_idx, (_, candidate) in enumerate(batch)
+ ]
+ user_prompt = build_user_prompt(query, passages)
+
+ local_to_original = {i: orig_idx for i, (orig_idx, _) in enumerate(batch)}
+
+ try:
+ response = await asyncio.wait_for(
+ asyncio.to_thread(
+ self._call_gemini,
+ user_prompt,
+ ),
+ timeout=self.timeout_seconds,
+ )
+
+ local_scores = self._parse_scores(response)
+ return {
+ local_to_original[local_idx]: score
+ for local_idx, score in local_scores.items()
+ }
+
+ except TimeoutError:
+ logger.warning("Batch scoring timeout", batch_size=len(batch))
+ raise
+ except RerankerParseError:
+ raise
+ except Exception as e:
+ logger.warning("Batch scoring failed", error=str(e))
+ msg = f"Failed to score batch: {e}"
+ raise RerankerError(msg) from e
+
+ def _call_gemini(self, user_prompt: str) -> str:
+ """
+ Make synchronous Gemini API call.
+
+ Args:
+ user_prompt: The user prompt with query and passages.
+
+ Returns:
+ The LLM response text.
+
+ """
+ response = self.client.models.generate_content( # pyright: ignore[reportUnknownMemberType]
+ model=self.model,
+ contents=[
+ types.Content(
+ role="user",
+ parts=[types.Part(text=user_prompt)],
+ ),
+ ],
+ config=types.GenerateContentConfig(
+ system_instruction=self._system_prompt,
+ temperature=0.0, # Deterministic scoring
+ max_output_tokens=1024,
+ ),
+ )
+
+ if response.text:
+ return response.text
+ msg = "Gemini API returned empty response"
+ raise RerankerError(msg)
+
+ def _parse_scores(self, response: str) -> dict[int, float]:
+ """
+ Parse JSON scores from LLM response.
+
+ Expected format: {"id0":8,"id1":10,"id3":7}
+
+ Args:
+ response: Raw LLM response text.
+
+ Returns:
+ Dictionary mapping local indices to scores.
+
+ Raises:
+ RerankerParseError: If response cannot be parsed.
+
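+        Example:
+            '{"id0":8,"id1":10}' parses to {0: 8.0, 1: 10.0}.
+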
+ """
+ response = response.strip()
+
+ json_match = re.search(r"\{[^{}]*\}", response)
+ if json_match:
+ response = json_match.group()
+
+ try:
+ raw_data: Any = json.loads(response)
+ except json.JSONDecodeError as e:
+ logger.warning("Failed to parse JSON response", response=response[:200])
+ msg = f"Invalid JSON in response: {e}"
+ raise RerankerParseError(msg) from e
+
+ if not isinstance(raw_data, dict):
+ msg = f"Expected dict, got {type(raw_data).__name__}"
+ raise RerankerParseError(msg)
+
+ # Parse id keys and validate scores
+ data = cast("dict[str, Any]", raw_data)
+ scores: dict[int, float] = {}
+ for key, value in data.items():
+ # Extract index from "idN" format
+ if not key.startswith("id"):
+ logger.warning("Invalid key format", key=key)
+ continue
+
+ try:
+ idx = int(key[2:])
+ except ValueError:
+ logger.warning("Cannot parse index from key", key=key)
+ continue
+
+ # Validate score
+ if not isinstance(value, int | float):
+ logger.warning("Invalid score type", key=key, value=value)
+ continue
+
+ score = float(value)
+ if not MIN_SCORE <= score <= MAX_SCORE:
+ logger.warning("Score out of range", key=key, score=score)
+ continue
+
+ scores[idx] = score
+
+ return scores
+
+ def _fallback_rerank(
+ self,
+ candidates: list[SemanticSearchResult],
+ top_k: int | None = None,
+ ) -> list[SemanticSearchResult]:
+ """
+ Fallback: return candidates sorted by original embedding score.
+
+ Args:
+ candidates: List of candidate passages.
+ top_k: Optional limit on number of results.
+
+ Returns:
+ Candidates sorted by original score (highest first).
+
+ """
+ logger.info(
+ "Using fallback reranking (original scores)",
+ num_candidates=len(candidates),
+ )
+ sorted_candidates = sorted(candidates, key=lambda x: x.score, reverse=True)
+ return sorted_candidates[:top_k] if top_k else sorted_candidates
diff --git a/src/flare_ai_kit/rag/vector/reranker/prompts.py b/src/flare_ai_kit/rag/vector/reranker/prompts.py
new file mode 100644
index 00000000..440635a8
--- /dev/null
+++ b/src/flare_ai_kit/rag/vector/reranker/prompts.py
@@ -0,0 +1,139 @@
+"""Prompt templates for LLM-based pointwise reranking."""
+# ruff: noqa: E501
+# Long lines are intentional in prompts for readability
+
+POINTWISE_SYSTEM_PROMPT = """You are a customer support answer service. Your task is to evaluate help center passages and score their relevance to a given customer query for a retrieval augmented generation (RAG) system.
+
+Evaluation Process:
+1. Analyze the customer's query to identify both explicit needs and implicit context including underlying user goals
+2. Assess each passage's ability to directly resolve the query or provide substantive supporting information with actionable guidance
+3. Score based on how effectively the passage addresses the query's core intent while considering potential interpretations
+
+Grading Criteria:
+
+10: EXCEPTIONAL match - Contains exact step-by-step instructions that perfectly match the query's specific scenario. Must include all required parameters/context and resolve the issue completely without any ambiguity. Reserved for definitive solutions that exactly mirror the user's described situation and require no interpretation.
+
+9: NEAR-PERFECT solution - Contains all critical steps for resolution but may lack one minor non-essential detail. Addresses the precise query parameters with specialized information. Solution must be directly applicable without requiring adaptation or assumptions.
+
+8: STRONG MATCH - Provides complete technical resolution through specific instructions, but may require simple logical inferences for full application. Covers all essential components but might need minor contextualization.
+
+7: GOOD MATCH - Contains substantial relevant details that address core aspects of the query, but lacks one important element for complete resolution. Provides concrete guidance requiring some user interpretation.
+
+6: PARTIAL match - General guidance on the right topic but lacks the specifics for direct application. May only resolve a subset of the request.
+
+5: LIMITED relevance - Related context or approach, but indirect. Requires substantial effort to adapt to the user's exact need.
+
+4: TANGENTIAL - Mentions related concepts/keywords with little practical connection to the request. Minimal actionable value.
+
+3: VAGUE domain info - Talks about the general area but not the query's specifics. No concrete, actionable steps.
+
+2: TOKEN overlap - Shares isolated terms without context or intent aligned to the request. Similarity is coincidental.
+
+1: IRRELEVANT - Uses query terms in a completely unrelated way. No meaningful link to the user's goal.
+
+0: UNRELATED - No thematic or contextual connection to the query at all.
+
+
+Input Format:
+
+<query>
+// The customer's question or request
+</query>
+
+<passages>
+<passage id=0>...</passage>
+<passage id=1>...</passage>
+<passage id=N>...</passage>
+</passages>
+
+Output Format:
+
+Return your response as valid JSON with no spaces:
+{{"id0":score0,"id1":score1,...}}
+
+Strict guidelines:
+- Return ONLY a well-formed valid JSON with passage IDs as keys
+- Each key must be a passage id in the format "idN"
+- Each score must be an integer between 5 and 10. EXCLUDE passages that score below 5 (i.e. 0, 1, 2, 3, or 4)
+- Integer values only, no decimals
+- Omit spaces in the JSON
+- No additional text or formatting
+- Maintain original passage ID order
+- Note: If NO passages score 5 or above, return an empty JSON object
+
+
+{few_shot_section}"""
+
+POINTWISE_USER_TEMPLATE = """<query>
+{query}
+</query>
+
+<passages>
+{passages}
+</passages>
+"""
+
+DEFAULT_FEW_SHOT_EXAMPLES = """
+Example 1:
+<query>
+How do I reset my password?
+</query>
+
+<passages>
+<passage id=0>To reset your password, click 'Forgot Password' on the login page, enter your email, and follow the link sent to your inbox.</passage>
+<passage id=1>Our company was founded in 2010 and has offices worldwide.</passage>
+<passage id=2>Password requirements include at least 8 characters with one uppercase letter.</passage>
+</passages>
+
+Output: {"id0":10,"id2":6}
+
+Example 2:
+<query>
+What are your business hours?
+</query>
+
+<passages>
+<passage id=0>Contact our sales team for enterprise pricing options.</passage>
+<passage id=1>We are open Monday to Friday, 9 AM to 5 PM EST. Weekend support is available via email only.</passage>
+</passages>
+
+Output: {"id1":10}
+"""
+
+
+def format_passages(passages: list[tuple[int, str]]) -> str:
+ """
+ Format passages for inclusion in the reranker prompt.
+
+ Args:
+ passages: List of tuples containing (passage_id, passage_text).
+
+ Returns:
+ Formatted string with passages in XML-like tags.
+
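+    Example:
+        format_passages([(0, "Open 9-5.")]) returns
+        '<passage id=0>Open 9-5.</passage>'.
+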
+ """
+ return "\n".join(
+ f"{text}" for idx, text in passages
+ )
+
+
+def build_system_prompt(few_shot_examples: str | None = None) -> str:
+ """
+ Build the system prompt with optional custom few-shot examples.
+
+ Args:
+ few_shot_examples: Optional custom few-shot examples. If None,
+ uses DEFAULT_FEW_SHOT_EXAMPLES.
+
+ Returns:
+ Complete system prompt with few-shot examples.
+
+ """
+ examples = few_shot_examples or DEFAULT_FEW_SHOT_EXAMPLES
+ few_shot_section = f"\n{examples}" if examples else ""
+ return POINTWISE_SYSTEM_PROMPT.format(few_shot_section=few_shot_section)
+
+
+def build_user_prompt(query: str, passages: list[tuple[int, str]]) -> str:
+ """
+ Build the user prompt with query and passages.
+
+ Args:
+ query: The user query.
+ passages: List of tuples containing (passage_id, passage_text).
+
+ Returns:
+ Formatted user prompt.
+
+ """
+ return POINTWISE_USER_TEMPLATE.format(
+ query=query, passages=format_passages(passages)
+ )
diff --git a/src/flare_ai_kit/rag/vector/settings.py b/src/flare_ai_kit/rag/vector/settings.py
index 25a26763..8d2ba6b4 100644
--- a/src/flare_ai_kit/rag/vector/settings.py
+++ b/src/flare_ai_kit/rag/vector/settings.py
@@ -95,3 +95,41 @@ class VectorDbSettings(BaseSettings):
postgres_dsn: PostgresDsn | None = Field(
default=None, description="DSN for PostgreSQL connection string."
)
+
+
+class RerankerSettings(BaseSettings):
+ """Configuration for LLM-based reranker."""
+
+ model_config = SettingsConfigDict(
+ env_prefix="RERANKER__",
+ env_file=".env",
+ extra="ignore",
+ )
+
+ enabled: bool = Field(
+ default=False,
+ description="Enable LLM-based reranking after retrieval.",
+ )
+ model: str = Field(
+ default="gemini-3-flash-preview",
+ description="Gemini model for reranking.",
+ )
+ num_batches: PositiveInt = Field(
+ default=4,
+ description="Number of parallel batches for scoring passages.",
+ )
+ timeout_seconds: float = Field(
+ default=5.0,
+ gt=0.0,
+ description="Timeout per batch in seconds.",
+ )
+ score_threshold: float = Field(
+ default=5.0,
+ ge=0.0,
+ le=10.0,
+ description="Minimum score (0-10) to include a passage in results.",
+ )
+ few_shot_examples: str | None = Field(
+ default=None,
+ description="Custom few-shot examples for calibration. If None, uses defaults.",
+ )
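+
+
+# Example .env entries (illustrative values):
+#   RERANKER__ENABLED=true
+#   RERANKER__SCORE_THRESHOLD=7.0
+#   RERANKER__NUM_BATCHES=8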
diff --git a/tests/unit/rag/__init__.py b/tests/unit/rag/__init__.py
new file mode 100644
index 00000000..0bd17669
--- /dev/null
+++ b/tests/unit/rag/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for RAG module."""
diff --git a/tests/unit/rag/vector/__init__.py b/tests/unit/rag/vector/__init__.py
new file mode 100644
index 00000000..3ac46559
--- /dev/null
+++ b/tests/unit/rag/vector/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for vector RAG module."""
diff --git a/tests/unit/rag/vector/reranker/__init__.py b/tests/unit/rag/vector/reranker/__init__.py
new file mode 100644
index 00000000..4a1b78bd
--- /dev/null
+++ b/tests/unit/rag/vector/reranker/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for reranker module."""
diff --git a/tests/unit/rag/vector/reranker/test_gemini_reranker.py b/tests/unit/rag/vector/reranker/test_gemini_reranker.py
new file mode 100644
index 00000000..0ae44130
--- /dev/null
+++ b/tests/unit/rag/vector/reranker/test_gemini_reranker.py
@@ -0,0 +1,330 @@
+"""Unit tests for GeminiPointwiseReranker."""
+
+from unittest.mock import patch
+
+import pytest
+
+from flare_ai_kit.common import RerankerParseError, SemanticSearchResult
+from flare_ai_kit.rag.vector.reranker import GeminiPointwiseReranker
+
+
+@pytest.fixture
+def mock_candidates() -> list[SemanticSearchResult]:
+ """Create mock search candidates."""
+ return [
+ SemanticSearchResult(
+ text=f"Passage {i} about topic",
+ score=0.9 - i * 0.1,
+ metadata={"file_name": f"doc{i}.txt"},
+ )
+ for i in range(10)
+ ]
+
+
+@pytest.fixture
+def reranker() -> GeminiPointwiseReranker:
+ """Create reranker with mocked client."""
+ with patch("flare_ai_kit.rag.vector.reranker.gemini_reranker.genai"):
+ return GeminiPointwiseReranker(
+ api_key="test-api-key",
+ model="gemini-2.0-flash",
+ num_batches=4,
+ timeout_seconds=5.0,
+ score_threshold=5.0,
+ )
+
+
+class TestDistributeToBatches:
+ """Tests for _distribute_to_batches method."""
+
+ def test_round_robin_distribution(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that candidates are distributed round-robin across batches."""
+ batches = reranker._distribute_to_batches(mock_candidates, 4)
+
+ assert len(batches) == 4
+
+ # Batch 0 should have indices 0, 4, 8
+ assert [idx for idx, _ in batches[0]] == [0, 4, 8]
+ # Batch 1 should have indices 1, 5, 9
+ assert [idx for idx, _ in batches[1]] == [1, 5, 9]
+ # Batch 2 should have indices 2, 6
+ assert [idx for idx, _ in batches[2]] == [2, 6]
+ # Batch 3 should have indices 3, 7
+ assert [idx for idx, _ in batches[3]] == [3, 7]
+
+ def test_single_batch(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test distribution with single batch."""
+ batches = reranker._distribute_to_batches(mock_candidates, 1)
+
+ assert len(batches) == 1
+ assert len(batches[0]) == 10
+ assert [idx for idx, _ in batches[0]] == list(range(10))
+
+ def test_more_batches_than_candidates(
+ self, reranker: GeminiPointwiseReranker
+ ) -> None:
+ """Test when there are more batches than candidates."""
+ candidates = [
+ SemanticSearchResult(text="p1", score=0.9, metadata={}),
+ SemanticSearchResult(text="p2", score=0.8, metadata={}),
+ ]
+ batches = reranker._distribute_to_batches(candidates, 4)
+
+ assert len(batches) == 4
+ # Only first 2 batches should have candidates
+ assert len(batches[0]) == 1
+ assert len(batches[1]) == 1
+ assert len(batches[2]) == 0
+ assert len(batches[3]) == 0
+
+
+class TestParseScores:
+ """Tests for _parse_scores method."""
+
+ def test_valid_json(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test parsing valid JSON response."""
+ response = '{"id0":8,"id1":10,"id2":3}'
+ scores = reranker._parse_scores(response)
+
+ assert scores == {0: 8.0, 1: 10.0, 2: 3.0}
+
+ def test_json_with_whitespace(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test parsing JSON with whitespace."""
+ response = ' { "id0": 8, "id1": 10 } '
+ scores = reranker._parse_scores(response)
+
+ assert scores == {0: 8.0, 1: 10.0}
+
+ def test_json_wrapped_in_text(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test extracting JSON from text response."""
+ response = 'Here are the scores: {"id0":7,"id1":9} based on my analysis.'
+ scores = reranker._parse_scores(response)
+
+ assert scores == {0: 7.0, 1: 9.0}
+
+ def test_empty_json(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test parsing empty JSON (no passages scored above threshold)."""
+ response = "{}"
+ scores = reranker._parse_scores(response)
+
+ assert scores == {}
+
+ def test_invalid_json(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test that invalid JSON raises RerankerParseError."""
+ response = "not valid json"
+ with pytest.raises(RerankerParseError):
+ reranker._parse_scores(response)
+
+ def test_invalid_key_format(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test that invalid keys are skipped."""
+ response = '{"id0":8,"invalid_key":10,"id1":7}'
+ scores = reranker._parse_scores(response)
+
+ # Only valid id keys should be included
+ assert scores == {0: 8.0, 1: 7.0}
+
+ def test_score_out_of_range(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test that out-of-range scores are skipped."""
+ response = '{"id0":8,"id1":15,"id2":-1}'
+ scores = reranker._parse_scores(response)
+
+ # Only valid scores (0-10) should be included
+ assert scores == {0: 8.0}
+
+
+class TestFallbackRerank:
+ """Tests for _fallback_rerank method."""
+
+ def test_returns_sorted_by_original_score(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that fallback returns candidates sorted by original score."""
+ # Shuffle the order
+ shuffled = mock_candidates[5:] + mock_candidates[:5]
+ result = reranker._fallback_rerank(shuffled, top_k=5)
+
+ assert len(result) == 5
+ # Should be sorted by score descending
+ for i in range(len(result) - 1):
+ assert result[i].score >= result[i + 1].score
+
+ def test_respects_top_k(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that fallback respects top_k limit."""
+ result = reranker._fallback_rerank(mock_candidates, top_k=3)
+ assert len(result) == 3
+
+ def test_returns_all_when_top_k_none(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that fallback returns all candidates when top_k is None."""
+ result = reranker._fallback_rerank(mock_candidates, top_k=None)
+ assert len(result) == len(mock_candidates)
+
+
+class TestRerank:
+ """Tests for the main rerank method."""
+
+ @pytest.mark.asyncio
+ async def test_empty_candidates(self, reranker: GeminiPointwiseReranker) -> None:
+ """Test rerank with empty candidates list."""
+ result = await reranker.rerank("query", [], top_k=5)
+ assert result == []
+
+ @pytest.mark.asyncio
+ async def test_rerank_filters_by_threshold(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that rerank filters results below score threshold."""
+
+ # Mock the _score_batch to return mixed scores
+        async def mock_score_batch(_query, batch):
+ return {
+ idx: 8.0 if idx < 5 else 3.0 # First 5 pass threshold, rest don't
+ for idx, _ in batch
+ }
+
+ with patch.object(reranker, "_score_batch", side_effect=mock_score_batch):
+ result = await reranker.rerank("test query", mock_candidates, top_k=10)
+
+ # Only candidates with score >= 5.0 should be included
+ assert len(result) == 5
+ for res in result:
+ assert float(res.metadata.get("reranker_score", 0)) >= 5.0
+
+ @pytest.mark.asyncio
+ async def test_rerank_preserves_metadata(
+ self, reranker: GeminiPointwiseReranker
+ ) -> None:
+ """Test that rerank preserves original metadata and adds new fields."""
+ candidates = [
+ SemanticSearchResult(
+ text="Test passage",
+ score=0.9,
+ metadata={"file_name": "test.txt", "chunk_id": "1"},
+ )
+ ]
+
+        async def mock_score_batch(_query, _batch):
+ return {0: 8.0}
+
+ with patch.object(reranker, "_score_batch", side_effect=mock_score_batch):
+ result = await reranker.rerank("query", candidates, top_k=1)
+
+ assert len(result) == 1
+ assert result[0].metadata["file_name"] == "test.txt"
+ assert result[0].metadata["chunk_id"] == "1"
+ assert result[0].metadata["original_score"] == "0.9"
+ assert result[0].metadata["reranker_score"] == "8.0"
+
+ @pytest.mark.asyncio
+ async def test_rerank_uses_fallback_on_timeout(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Test that rerank falls back to original scores on timeout."""
+
+ # Make _score_batch raise TimeoutError
+ async def mock_timeout(*args):
+ raise TimeoutError
+
+ with patch.object(reranker, "_score_batch", side_effect=mock_timeout):
+ result = await reranker.rerank("query", mock_candidates, top_k=5)
+
+ # Should return top-5 by original score
+ assert len(result) == 5
+ # First result should have highest original score
+ assert result[0].score == mock_candidates[0].score
+
+ @pytest.mark.asyncio
+ async def test_score_batch_maps_local_ids_to_original_indices(
+ self,
+ reranker: GeminiPointwiseReranker,
+ ) -> None:
+ """Ensure batch-local ids map back to the original candidate indices."""
+ batch = [
+ (2, SemanticSearchResult(text="p2", score=0.9, metadata={})),
+ (6, SemanticSearchResult(text="p6", score=0.8, metadata={})),
+ ]
+
+ with patch.object(reranker, "_call_gemini", return_value='{"id0":9,"id1":6}'):
+ scores = await reranker._score_batch("query", batch)
+
+ assert scores == {2: 9.0, 6: 6.0}
+
+ @pytest.mark.asyncio
+ async def test_rerank_runs_batches_in_parallel(
+ self,
+ reranker: GeminiPointwiseReranker,
+ mock_candidates: list[SemanticSearchResult],
+ ) -> None:
+ """Catch accidental serial execution by asserting on wall-clock time."""
+ reranker.num_batches = 4
+
+        async def slow_score_batch(_query, batch):
+            await asyncio.sleep(0.2)
+            return {idx: 8.0 for idx, _ in batch}
+
+ with patch.object(reranker, "_score_batch", side_effect=slow_score_batch):
+ start = time.perf_counter()
+ await reranker.rerank("query", mock_candidates, top_k=10)
+ elapsed = time.perf_counter() - start
+
+ assert elapsed < 0.6
+
+
+class TestSettings:
+ """Tests for RerankerSettings."""
+
+ def test_default_settings(self) -> None:
+ """Test default reranker settings."""
+ settings = RerankerSettings()
+
+ assert settings.enabled is False
+ assert settings.model == "gemini-3-flash-preview"
+ assert settings.num_batches == 4
+ assert settings.timeout_seconds == 5.0
+ assert settings.score_threshold == 5.0
+ assert settings.few_shot_examples is None
+
+ def test_custom_settings(self) -> None:
+ """Test custom reranker settings."""
+ settings = RerankerSettings(
+ enabled=True,
+ model="gemini-2.0-flash", # Override with older model
+ num_batches=8,
+ timeout_seconds=10.0,
+ score_threshold=7.0,
+ )
+
+ assert settings.enabled is True
+ assert settings.model == "gemini-2.0-flash"
+ assert settings.num_batches == 8
+ assert settings.timeout_seconds == 10.0
+ assert settings.score_threshold == 7.0