Skip to content

Commit dddac63

Browse files
committed
feat(rag): add LLM-based pointwise reranker to RAG pipeline
1 parent 2dfc1f1 commit dddac63

File tree

13 files changed

+1221
-10
lines changed

13 files changed

+1221
-10
lines changed

examples/08_rag_reranker_demo.py

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
# pyright: reportMissingImports=false
2+
# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
3+
"""
4+
Demo: RAG with LLM-based Reranking
5+
6+
Demonstrates why reranking helps: compare vector-only retrieval vs LLM reranking on an
7+
adversarial doc set (high keyword overlap but low true relevance).
8+
"""
9+
10+
import asyncio
11+
import sys
12+
from pathlib import Path
13+
from typing import Final
14+
from uuid import uuid4
15+
16+
from dotenv import load_dotenv
17+
18+
# pyright: ignore[reportMissingImports]
19+
from qdrant_client import QdrantClient
20+
21+
from flare_ai_kit.agent import AgentSettings
22+
from flare_ai_kit.common import SemanticSearchResult
23+
from flare_ai_kit.rag.vector.api import (
24+
FixedSizeChunker,
25+
GeminiEmbedding,
26+
GeminiPointwiseReranker,
27+
LocalFileIndexer,
28+
QdrantRetriever,
29+
ingest_and_embed,
30+
upsert_to_qdrant,
31+
)
32+
from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings
33+
34+
35+
def _print_results(
36+
title: str,
37+
results: list[SemanticSearchResult],
38+
*,
39+
show_reranker_scores: bool = False,
40+
max_text_chars: int = 96,
41+
) -> None:
42+
print(f"\n{title} (top {len(results)}):\n")
43+
for i, res in enumerate(results, 1):
44+
source = res.metadata.get("file_name", "unknown")
45+
if show_reranker_scores:
46+
original_score = res.metadata.get("original_score", "N/A")
47+
reranker_score = res.metadata.get("reranker_score", "N/A")
48+
scores = f"orig={original_score}, rerank={reranker_score}/10"
49+
else:
50+
scores = f"vector={res.score:.3f}"
51+
preview = res.text.replace("\n", " ")[:max_text_chars]
52+
print(f" {i:>2}. [{source}] ({scores})")
53+
print(f" {preview}...")
54+
55+
56+
def _hit_rate(results: list[SemanticSearchResult], relevant_sources: set[str]) -> int:
57+
return sum(1 for r in results if r.metadata.get("file_name") in relevant_sources)
58+
59+
60+
async def main() -> None: # noqa: PLR0915
61+
"""Run the RAG + reranking demo."""
62+
load_dotenv()
63+
agent = AgentSettings() # pyright: ignore[reportCallIssue]
64+
vector_db = VectorDbSettings(qdrant_batch_size=8)
65+
reranker_settings = RerankerSettings(enabled=True, score_threshold=5.0)
66+
67+
if agent.gemini_api_key is None:
68+
print("GEMINI_API_KEY environment variable not set.")
69+
print("Please set the GEMINI_API_KEY environment variable to run this demo.")
70+
print("Example: export GEMINI_API_KEY='your_api_key_here'")
71+
sys.exit(1)
72+
73+
demo_dir = Path("demo_rerank_data")
74+
demo_dir.mkdir(parents=True, exist_ok=True)
75+
76+
# Intentionally include "keyword trap" docs that mention the right terms but are not helpful.
77+
documents: Final[list[tuple[str, str]]] = [
78+
(
79+
"how_to_reset_password.txt",
80+
"Password reset steps: On the login page, click 'Forgot Password'. "
81+
"Enter your email. Check your inbox for the reset link (valid for 24 hours). "
82+
"If the link expired, request a new one. If you still can't reset, contact support.",
83+
),
84+
(
85+
"reset_link_expired.txt",
86+
"If your password reset link has expired: return to the login page, click 'Forgot Password', "
87+
"and request a new link. For security, links expire after 24 hours and can only be used once.",
88+
),
89+
(
90+
"account_lockout_policy.txt",
91+
"Account lockout: after 5 failed login attempts, your account is locked for 15 minutes. "
92+
"Password reset does not unlock the account immediately. Wait 15 minutes or contact support.",
93+
),
94+
(
95+
"billing_address_change.txt",
96+
"To change your billing address: go to Settings → Billing → Address, update fields, and Save. "
97+
"This does not change your login email or password.",
98+
),
99+
(
100+
"two_factor_setup.txt",
101+
"Enable two-factor authentication (2FA): Settings → Security → Two-Factor. "
102+
"Scan the QR code in your authenticator app and enter the 6-digit code to confirm.",
103+
),
104+
(
105+
"keyword_trap_marketing_blog.txt",
106+
"Password reset, reset password, forgot password, reset link, password reset email — "
107+
"we love talking about password reset in our marketing. This post resets expectations. "
108+
"It does not contain support steps, but it repeats password reset keywords many times. "
109+
"Password reset. Reset password. Forgot password. Reset link.",
110+
),
111+
(
112+
"keyword_trap_shipping_reset.txt",
113+
"If your package tracking looks wrong, reset the tracking view in the shipping portal. "
114+
"This is unrelated to accounts, login, or password reset. It mentions reset and link and email.",
115+
),
116+
(
117+
"password_requirements.txt",
118+
"Password requirements: 8+ chars, one uppercase, one lowercase, one number. "
119+
"If your password reset fails, ensure the new password meets requirements.",
120+
),
121+
(
122+
"sso_login_note.txt",
123+
"If you signed up with Google/SSO, you may not have a password to reset. "
124+
"Use 'Continue with Google' or contact support to convert your account to email/password login.",
125+
),
126+
]
127+
128+
try:
129+
for filename, content in documents:
130+
(demo_dir / filename).write_text(content, encoding="utf-8")
131+
132+
print(f"Created {len(documents)} sample documents in {demo_dir}/\n")
133+
134+
# Use larger chunks so each doc stays mostly intact for clearer ranking diffs.
135+
chunker = FixedSizeChunker(chunk_size=256)
136+
indexer = LocalFileIndexer(
137+
root_dir=str(demo_dir), chunker=chunker, allowed_extensions={".txt"}
138+
)
139+
140+
embedding_model = GeminiEmbedding(
141+
api_key=agent.gemini_api_key.get_secret_value(),
142+
model=vector_db.embeddings_model,
143+
output_dimensionality=vector_db.embeddings_output_dimensionality,
144+
)
145+
146+
data = ingest_and_embed(indexer, embedding_model, batch_size=8)
147+
print(f"Ingested and embedded {len(data)} chunks.\n")
148+
149+
collection_name = f"reranker-demo-{uuid4().hex}"
150+
vector_size = 768
151+
upsert_to_qdrant(
152+
data,
153+
str(vector_db.qdrant_url),
154+
collection_name,
155+
vector_size,
156+
batch_size=vector_db.qdrant_batch_size,
157+
)
158+
print(
159+
f"Upserted {len(data)} vectors to Qdrant collection '{collection_name}'.\n"
160+
)
161+
162+
client = QdrantClient(str(vector_db.qdrant_url))
163+
retriever = QdrantRetriever(client, embedding_model, vector_db)
164+
165+
reranker = GeminiPointwiseReranker(
166+
api_key=agent.gemini_api_key.get_secret_value(),
167+
model=reranker_settings.model,
168+
num_batches=reranker_settings.num_batches,
169+
timeout_seconds=reranker_settings.timeout_seconds,
170+
score_threshold=reranker_settings.score_threshold,
171+
)
172+
173+
scenarios: Final[list[tuple[str, str, set[str]]]] = [
174+
(
175+
"Reset password (needs actionable steps; keyword traps present)",
176+
"My password reset link expired. How do I get a new one?",
177+
{"how_to_reset_password.txt", "reset_link_expired.txt"},
178+
),
179+
(
180+
"SSO edge case (vector overlap is misleading)",
181+
"I signed up with Google SSO. Why doesn't 'Forgot Password' work?",
182+
{"sso_login_note.txt"},
183+
),
184+
(
185+
"Security settings (should prefer the exact setup steps)",
186+
"How do I enable two-factor authentication?",
187+
{"two_factor_setup.txt"},
188+
),
189+
]
190+
191+
top_k_retrieve = 8
192+
top_k_show = 5
193+
for title, query, relevant_sources in scenarios:
194+
print("\n" + "=" * 80)
195+
print(f"{title}\nQuery: {query}")
196+
197+
candidates = retriever.semantic_search(
198+
query, collection_name, top_k=top_k_retrieve
199+
)
200+
baseline = candidates[:top_k_show]
201+
202+
_print_results("Vector-only results", baseline, show_reranker_scores=False)
203+
204+
print("\nReranking with LLM...\n")
205+
reranked = await reranker.rerank(query, candidates, top_k=top_k_show)
206+
_print_results(
207+
"Reranked results",
208+
reranked,
209+
show_reranker_scores=True,
210+
)
211+
212+
base_hits = _hit_rate(baseline, relevant_sources)
213+
rerank_hits = _hit_rate(reranked, relevant_sources)
214+
print("\nSummary:")
215+
print(f" - Relevant sources: {sorted(relevant_sources)}")
216+
print(
217+
f" - Hit@{top_k_show}: vector-only={base_hits}/{top_k_show}, reranked={rerank_hits}/{top_k_show}"
218+
)
219+
220+
print("\n" + "=" * 80)
221+
print("Demo complete.")
222+
print(
223+
f"Notes: reranker score threshold={reranker_settings.score_threshold}, "
224+
f"batches={reranker_settings.num_batches}, timeout={reranker_settings.timeout_seconds}s"
225+
)
226+
finally:
227+
for filename, _ in documents:
228+
(demo_dir / filename).unlink(missing_ok=True)
229+
demo_dir.rmdir()
230+
231+
232+
if __name__ == "__main__":
233+
asyncio.run(main())

src/flare_ai_kit/common/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,12 @@ class A2AClientError(FlareAIKitError):
138138
# --- PDF Processing Errors ---
139139
class PdfPostingError(FlareAIKitError):
140140
"""Error class concerned with onchain PDF data posting errors."""
141+
142+
143+
# --- Reranker Errors ---
144+
class RerankerError(FlareAIKitError):
145+
"""Base exception for reranker errors."""
146+
147+
148+
class RerankerParseError(RerankerError):
149+
"""Raised when LLM response cannot be parsed."""

src/flare_ai_kit/rag/vector/api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,20 @@
1010
ingest_and_embed,
1111
upsert_to_qdrant,
1212
)
13+
from .reranker import BaseReranker, GeminiPointwiseReranker
1314
from .responder import BaseResponder
1415
from .retriever import BaseRetriever, QdrantRetriever
1516

1617
__all__ = [
1718
"BaseChunker",
1819
"BaseEmbedding",
1920
"BaseIndexer",
21+
"BaseReranker",
2022
"BaseResponder",
2123
"BaseRetriever",
2224
"FixedSizeChunker",
2325
"GeminiEmbedding",
26+
"GeminiPointwiseReranker",
2427
"LocalFileIndexer",
2528
"QdrantRetriever",
2629
"VectorRAGPipeline",

0 commit comments

Comments
 (0)