|
| 1 | +# pyright: reportMissingImports=false |
| 2 | +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false |
| 3 | +""" |
| 4 | +Demo: RAG with LLM-based Reranking |
| 5 | +
|
| 6 | +Demonstrates why reranking helps: compare vector-only retrieval vs LLM reranking on an |
| 7 | +adversarial doc set (high keyword overlap but low true relevance). |
| 8 | +""" |
| 9 | + |
| 10 | +import asyncio |
| 11 | +import sys |
| 12 | +from pathlib import Path |
| 13 | +from typing import Final |
| 14 | +from uuid import uuid4 |
| 15 | + |
| 16 | +from dotenv import load_dotenv |
| 17 | + |
| 18 | +# pyright: ignore[reportMissingImports] |
| 19 | +from qdrant_client import QdrantClient |
| 20 | + |
| 21 | +from flare_ai_kit.agent import AgentSettings |
| 22 | +from flare_ai_kit.common import SemanticSearchResult |
| 23 | +from flare_ai_kit.rag.vector.api import ( |
| 24 | + FixedSizeChunker, |
| 25 | + GeminiEmbedding, |
| 26 | + GeminiPointwiseReranker, |
| 27 | + LocalFileIndexer, |
| 28 | + QdrantRetriever, |
| 29 | + ingest_and_embed, |
| 30 | + upsert_to_qdrant, |
| 31 | +) |
| 32 | +from flare_ai_kit.rag.vector.settings import RerankerSettings, VectorDbSettings |
| 33 | + |
| 34 | + |
| 35 | +def _print_results( |
| 36 | + title: str, |
| 37 | + results: list[SemanticSearchResult], |
| 38 | + *, |
| 39 | + show_reranker_scores: bool = False, |
| 40 | + max_text_chars: int = 96, |
| 41 | +) -> None: |
| 42 | + print(f"\n{title} (top {len(results)}):\n") |
| 43 | + for i, res in enumerate(results, 1): |
| 44 | + source = res.metadata.get("file_name", "unknown") |
| 45 | + if show_reranker_scores: |
| 46 | + original_score = res.metadata.get("original_score", "N/A") |
| 47 | + reranker_score = res.metadata.get("reranker_score", "N/A") |
| 48 | + scores = f"orig={original_score}, rerank={reranker_score}/10" |
| 49 | + else: |
| 50 | + scores = f"vector={res.score:.3f}" |
| 51 | + preview = res.text.replace("\n", " ")[:max_text_chars] |
| 52 | + print(f" {i:>2}. [{source}] ({scores})") |
| 53 | + print(f" {preview}...") |
| 54 | + |
| 55 | + |
| 56 | +def _hit_rate(results: list[SemanticSearchResult], relevant_sources: set[str]) -> int: |
| 57 | + return sum(1 for r in results if r.metadata.get("file_name") in relevant_sources) |
| 58 | + |
| 59 | + |
| 60 | +async def main() -> None: # noqa: PLR0915 |
| 61 | + """Run the RAG + reranking demo.""" |
| 62 | + load_dotenv() |
| 63 | + agent = AgentSettings() # pyright: ignore[reportCallIssue] |
| 64 | + vector_db = VectorDbSettings(qdrant_batch_size=8) |
| 65 | + reranker_settings = RerankerSettings(enabled=True, score_threshold=5.0) |
| 66 | + |
| 67 | + if agent.gemini_api_key is None: |
| 68 | + print("GEMINI_API_KEY environment variable not set.") |
| 69 | + print("Please set the GEMINI_API_KEY environment variable to run this demo.") |
| 70 | + print("Example: export GEMINI_API_KEY='your_api_key_here'") |
| 71 | + sys.exit(1) |
| 72 | + |
| 73 | + demo_dir = Path("demo_rerank_data") |
| 74 | + demo_dir.mkdir(parents=True, exist_ok=True) |
| 75 | + |
| 76 | + # Intentionally include "keyword trap" docs that mention the right terms but are not helpful. |
| 77 | + documents: Final[list[tuple[str, str]]] = [ |
| 78 | + ( |
| 79 | + "how_to_reset_password.txt", |
| 80 | + "Password reset steps: On the login page, click 'Forgot Password'. " |
| 81 | + "Enter your email. Check your inbox for the reset link (valid for 24 hours). " |
| 82 | + "If the link expired, request a new one. If you still can't reset, contact support.", |
| 83 | + ), |
| 84 | + ( |
| 85 | + "reset_link_expired.txt", |
| 86 | + "If your password reset link has expired: return to the login page, click 'Forgot Password', " |
| 87 | + "and request a new link. For security, links expire after 24 hours and can only be used once.", |
| 88 | + ), |
| 89 | + ( |
| 90 | + "account_lockout_policy.txt", |
| 91 | + "Account lockout: after 5 failed login attempts, your account is locked for 15 minutes. " |
| 92 | + "Password reset does not unlock the account immediately. Wait 15 minutes or contact support.", |
| 93 | + ), |
| 94 | + ( |
| 95 | + "billing_address_change.txt", |
| 96 | + "To change your billing address: go to Settings → Billing → Address, update fields, and Save. " |
| 97 | + "This does not change your login email or password.", |
| 98 | + ), |
| 99 | + ( |
| 100 | + "two_factor_setup.txt", |
| 101 | + "Enable two-factor authentication (2FA): Settings → Security → Two-Factor. " |
| 102 | + "Scan the QR code in your authenticator app and enter the 6-digit code to confirm.", |
| 103 | + ), |
| 104 | + ( |
| 105 | + "keyword_trap_marketing_blog.txt", |
| 106 | + "Password reset, reset password, forgot password, reset link, password reset email — " |
| 107 | + "we love talking about password reset in our marketing. This post resets expectations. " |
| 108 | + "It does not contain support steps, but it repeats password reset keywords many times. " |
| 109 | + "Password reset. Reset password. Forgot password. Reset link.", |
| 110 | + ), |
| 111 | + ( |
| 112 | + "keyword_trap_shipping_reset.txt", |
| 113 | + "If your package tracking looks wrong, reset the tracking view in the shipping portal. " |
| 114 | + "This is unrelated to accounts, login, or password reset. It mentions reset and link and email.", |
| 115 | + ), |
| 116 | + ( |
| 117 | + "password_requirements.txt", |
| 118 | + "Password requirements: 8+ chars, one uppercase, one lowercase, one number. " |
| 119 | + "If your password reset fails, ensure the new password meets requirements.", |
| 120 | + ), |
| 121 | + ( |
| 122 | + "sso_login_note.txt", |
| 123 | + "If you signed up with Google/SSO, you may not have a password to reset. " |
| 124 | + "Use 'Continue with Google' or contact support to convert your account to email/password login.", |
| 125 | + ), |
| 126 | + ] |
| 127 | + |
| 128 | + try: |
| 129 | + for filename, content in documents: |
| 130 | + (demo_dir / filename).write_text(content, encoding="utf-8") |
| 131 | + |
| 132 | + print(f"Created {len(documents)} sample documents in {demo_dir}/\n") |
| 133 | + |
| 134 | + # Use larger chunks so each doc stays mostly intact for clearer ranking diffs. |
| 135 | + chunker = FixedSizeChunker(chunk_size=256) |
| 136 | + indexer = LocalFileIndexer( |
| 137 | + root_dir=str(demo_dir), chunker=chunker, allowed_extensions={".txt"} |
| 138 | + ) |
| 139 | + |
| 140 | + embedding_model = GeminiEmbedding( |
| 141 | + api_key=agent.gemini_api_key.get_secret_value(), |
| 142 | + model=vector_db.embeddings_model, |
| 143 | + output_dimensionality=vector_db.embeddings_output_dimensionality, |
| 144 | + ) |
| 145 | + |
| 146 | + data = ingest_and_embed(indexer, embedding_model, batch_size=8) |
| 147 | + print(f"Ingested and embedded {len(data)} chunks.\n") |
| 148 | + |
| 149 | + collection_name = f"reranker-demo-{uuid4().hex}" |
| 150 | + vector_size = 768 |
| 151 | + upsert_to_qdrant( |
| 152 | + data, |
| 153 | + str(vector_db.qdrant_url), |
| 154 | + collection_name, |
| 155 | + vector_size, |
| 156 | + batch_size=vector_db.qdrant_batch_size, |
| 157 | + ) |
| 158 | + print( |
| 159 | + f"Upserted {len(data)} vectors to Qdrant collection '{collection_name}'.\n" |
| 160 | + ) |
| 161 | + |
| 162 | + client = QdrantClient(str(vector_db.qdrant_url)) |
| 163 | + retriever = QdrantRetriever(client, embedding_model, vector_db) |
| 164 | + |
| 165 | + reranker = GeminiPointwiseReranker( |
| 166 | + api_key=agent.gemini_api_key.get_secret_value(), |
| 167 | + model=reranker_settings.model, |
| 168 | + num_batches=reranker_settings.num_batches, |
| 169 | + timeout_seconds=reranker_settings.timeout_seconds, |
| 170 | + score_threshold=reranker_settings.score_threshold, |
| 171 | + ) |
| 172 | + |
| 173 | + scenarios: Final[list[tuple[str, str, set[str]]]] = [ |
| 174 | + ( |
| 175 | + "Reset password (needs actionable steps; keyword traps present)", |
| 176 | + "My password reset link expired. How do I get a new one?", |
| 177 | + {"how_to_reset_password.txt", "reset_link_expired.txt"}, |
| 178 | + ), |
| 179 | + ( |
| 180 | + "SSO edge case (vector overlap is misleading)", |
| 181 | + "I signed up with Google SSO. Why doesn't 'Forgot Password' work?", |
| 182 | + {"sso_login_note.txt"}, |
| 183 | + ), |
| 184 | + ( |
| 185 | + "Security settings (should prefer the exact setup steps)", |
| 186 | + "How do I enable two-factor authentication?", |
| 187 | + {"two_factor_setup.txt"}, |
| 188 | + ), |
| 189 | + ] |
| 190 | + |
| 191 | + top_k_retrieve = 8 |
| 192 | + top_k_show = 5 |
| 193 | + for title, query, relevant_sources in scenarios: |
| 194 | + print("\n" + "=" * 80) |
| 195 | + print(f"{title}\nQuery: {query}") |
| 196 | + |
| 197 | + candidates = retriever.semantic_search( |
| 198 | + query, collection_name, top_k=top_k_retrieve |
| 199 | + ) |
| 200 | + baseline = candidates[:top_k_show] |
| 201 | + |
| 202 | + _print_results("Vector-only results", baseline, show_reranker_scores=False) |
| 203 | + |
| 204 | + print("\nReranking with LLM...\n") |
| 205 | + reranked = await reranker.rerank(query, candidates, top_k=top_k_show) |
| 206 | + _print_results( |
| 207 | + "Reranked results", |
| 208 | + reranked, |
| 209 | + show_reranker_scores=True, |
| 210 | + ) |
| 211 | + |
| 212 | + base_hits = _hit_rate(baseline, relevant_sources) |
| 213 | + rerank_hits = _hit_rate(reranked, relevant_sources) |
| 214 | + print("\nSummary:") |
| 215 | + print(f" - Relevant sources: {sorted(relevant_sources)}") |
| 216 | + print( |
| 217 | + f" - Hit@{top_k_show}: vector-only={base_hits}/{top_k_show}, reranked={rerank_hits}/{top_k_show}" |
| 218 | + ) |
| 219 | + |
| 220 | + print("\n" + "=" * 80) |
| 221 | + print("Demo complete.") |
| 222 | + print( |
| 223 | + f"Notes: reranker score threshold={reranker_settings.score_threshold}, " |
| 224 | + f"batches={reranker_settings.num_batches}, timeout={reranker_settings.timeout_seconds}s" |
| 225 | + ) |
| 226 | + finally: |
| 227 | + for filename, _ in documents: |
| 228 | + (demo_dir / filename).unlink(missing_ok=True) |
| 229 | + demo_dir.rmdir() |
| 230 | + |
| 231 | + |
| 232 | +if __name__ == "__main__": |
| 233 | + asyncio.run(main()) |
0 commit comments