-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhybrid.py
More file actions
240 lines (204 loc) · 9.69 KB
/
hybrid.py
File metadata and controls
240 lines (204 loc) · 9.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
src/hybrid.py
-------------
Hybrid retriever combining BM25 keyword search and FAISS semantic search,
fused with Reciprocal Rank Fusion (RRF).

Designed to plug into the existing run_rag() pipeline in rag_pipeline.py
as a drop-in replacement for the semantic retriever:

    hybrid_retriever = load_hybrid_retriever(
        bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl",
        faiss_store_path="data/processed/embeddings",
        k=5,
    )
    answer = run_rag(hybrid_retriever, "Best coffee beans for espresso")

The HybridRetriever class extends LangChain's BaseRetriever, so it is fully
compatible with the | (pipe) operator used in rag_pipeline.py:

    rag_chain = (
        {
            "context": hybrid_retriever | RunnableLambda(build_context),
            "question": RunnablePassthrough(),
        }
        | prompt_template
        | llm
        | StrOutputParser()
    )
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
# Module-level logger; HybridRetriever reports fetch counts, failures and
# fusion statistics through it (debug / warning / info levels).
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# HybridRetriever
# ---------------------------------------------------------------------------
class HybridRetriever(BaseRetriever):
    """
    Combines BM25 keyword retrieval and FAISS semantic retrieval using
    Reciprocal Rank Fusion (RRF) to produce a unified ranked document list.

    RRF score for document d across retriever r (rank is 1-based):

        score(d) = weight_r * (1 / (rrf_c + rank(d, r)))

    Documents appearing in both retrievers accumulate scores from both,
    naturally promoting results that are relevant by both keyword and meaning.

    Deduplication key: the ``parent_asin`` metadata value when present,
    otherwise the document's exact ``page_content``. When both retrievers
    return the same key, the BM25 Document is kept (its metadata is richer).
    Returned Documents are copies carrying two extra metadata entries,
    ``hybrid_score`` and ``retrieval_source``; the retrievers' own Document
    objects are never modified.

    Parameters
    ----------
    bm25_retriever   : Fitted LangChain BM25Retriever (from bm25.load())
    semantic_store   : Loaded FAISS vectorstore (from semantic.load_vector_store())
    k                : Number of final documents to return
    rrf_c            : RRF constant — dampens the impact of rank differences.
                       Standard value is 60; lower = top ranks matter more.
    bm25_weight      : RRF weight for BM25 results (keyword signal)
    semantic_weight  : RRF weight for semantic results (meaning signal)
    fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
                       More candidates = better fusion quality. Default: 3.
    """

    bm25_retriever: Any = Field(...)
    semantic_store: Any = Field(...)
    k: int = Field(default=5)
    rrf_c: int = Field(default=60)
    bm25_weight: float = Field(default=0.5)
    semantic_weight: float = Field(default=0.5)
    fetch_multiplier: int = Field(default=3)

    @staticmethod
    def _dedup_key(doc: Document) -> str:
        """Return the fusion key: ``parent_asin`` if set, else a content-based key."""
        asin = doc.metadata.get("parent_asin")
        if asin:
            return asin
        # Prefixed so a page_content string can never collide with a real ASIN.
        return f"content::{doc.page_content}"

    def _fetch_bm25(self, query: str, fetch_k: int) -> list[Document]:
        """Fetch up to ``fetch_k`` BM25 candidates; log and return [] on failure."""
        # The candidate count is controlled through the retriever's ``k``
        # attribute, so this mutates the shared BM25 retriever instance.
        self.bm25_retriever.k = fetch_k
        try:
            docs: list[Document] = self.bm25_retriever.invoke(query)
        except Exception as exc:  # best-effort: one failing retriever must not sink the query
            logger.warning("BM25 retrieval failed: %s — using empty list.", exc)
            return []
        logger.debug("BM25 returned %d docs for query: %r", len(docs), query)
        return docs

    def _fetch_semantic(self, query: str, fetch_k: int) -> list[Document]:
        """Fetch up to ``fetch_k`` FAISS candidates; log and return [] on failure."""
        # similarity_search returns list[Document] without scores — rank is enough for RRF.
        try:
            docs: list[Document] = self.semantic_store.similarity_search(
                query, k=fetch_k
            )
        except Exception as exc:  # best-effort, mirrors the BM25 path
            logger.warning("Semantic retrieval failed: %s — using empty list.", exc)
            return []
        logger.debug("Semantic returned %d docs for query: %r", len(docs), query)
        return docs

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """
        Core retrieval logic called by LangChain when the retriever is invoked.

        Steps
        -----
        1. Fetch candidates from BM25 and FAISS independently
        2. Assign RRF scores weighted by retriever weight
        3. Deduplicate by parent_asin (or page content), accumulating scores
           for shared hits
        4. Sort by fused RRF score and return copies of the top-k Documents
           annotated with ``hybrid_score`` and ``retrieval_source``
        """
        fetch_k = self.k * self.fetch_multiplier
        bm25_docs = self._fetch_bm25(query, fetch_k)
        semantic_docs = self._fetch_semantic(query, fetch_k)

        # ── RRF fusion ───────────────────────────────────────────────────
        rrf_scores: dict[str, float] = {}
        doc_map: dict[str, Document] = {}
        sources: dict[str, set[str]] = {}
        # BM25 is processed first so that, via setdefault, its Document (with
        # the richer metadata, e.g. top_reviews) wins for keys both return.
        # setdefault also keeps the best-ranked Document within one retriever.
        for source, weight, docs in (
            ("bm25", self.bm25_weight, bm25_docs),
            ("semantic", self.semantic_weight, semantic_docs),
        ):
            for rank, doc in enumerate(docs, start=1):
                key = self._dedup_key(doc)
                rrf_scores[key] = rrf_scores.get(key, 0.0) + weight / (self.rrf_c + rank)
                doc_map.setdefault(key, doc)
                sources.setdefault(key, set()).add(source)

        # ── Sort and truncate (stable sort: ties keep BM25-first order) ─
        ranked_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)
        top_docs: list[Document] = []
        for key in ranked_keys[: self.k]:
            hit = sources[key]
            if "bm25" in hit and "semantic" in hit:
                retrieval_source = "hybrid"
            elif "bm25" in hit:
                retrieval_source = "bm25"
            else:
                retrieval_source = "semantic"
            original = doc_map[key]
            metadata = {
                **original.metadata,
                "hybrid_score": round(rrf_scores[key], 6),
                "retrieval_source": retrieval_source,
            }
            # Copy instead of mutating: retrievers may return the very Document
            # objects held in their index, so in-place edits would leak across queries.
            top_docs.append(original.model_copy(update={"metadata": metadata}))

        logger.info(
            "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
            len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
        )
        return top_docs
# ---------------------------------------------------------------------------
# Convenience loader
# ---------------------------------------------------------------------------
def load_hybrid_retriever(
    bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
    faiss_store_path: str = "data/processed/embeddings",
    k: int = 5,
    bm25_weight: float = 0.5,
    semantic_weight: float = 0.5,
    rrf_c: int = 60,
    fetch_multiplier: int = 3,
) -> HybridRetriever:
    """
    Build a HybridRetriever from the on-disk BM25 index and FAISS store.

    Intended to be called once (in a notebook or app.py); the returned object
    can then be handed to run_rag() or composed with the | operator.

    Parameters
    ----------
    bm25_index_path  : Pickled BM25Retriever (produced by bm25.build_and_save())
    faiss_store_path : Directory holding index.faiss + index.pkl
                       (produced by semantic.build_and_save_vector_store())
    k                : Documents returned per query
    bm25_weight      : RRF weight of the keyword signal. Default 0.5.
    semantic_weight  : RRF weight of the meaning signal. Default 0.5.
                       Only the relative scale of the two weights matters;
                       they need not sum to 1.
    rrf_c            : RRF rank-dampening constant. Default 60 (standard).
    fetch_multiplier : Each retriever fetches k * fetch_multiplier candidates.

    Returns
    -------
    HybridRetriever
        A LangChain-compatible retriever, pipeable with |.

    Example
    -------
    >>> from src.hybrid import load_hybrid_retriever
    >>> from src.rag_pipeline import run_rag
    >>>
    >>> hybrid = load_hybrid_retriever(k=5)
    >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
    >>> print(answer)
    """
    # Deferred imports — per the original note, this avoids circular imports
    # when this module is used from rag_pipeline.py.
    from src.bm25 import load as load_bm25
    from src.semantic import load_vector_store

    print(f"Loading BM25 index from: {bm25_index_path}")
    keyword_retriever: BM25Retriever = load_bm25(bm25_index_path)

    print(f"Loading FAISS store from: {faiss_store_path}")
    vector_store: FAISS = load_vector_store(faiss_store_path)

    hybrid = HybridRetriever(
        bm25_retriever=keyword_retriever,
        semantic_store=vector_store,
        k=k,
        rrf_c=rrf_c,
        bm25_weight=bm25_weight,
        semantic_weight=semantic_weight,
        fetch_multiplier=fetch_multiplier,
    )

    ready_message = (
        f"HybridRetriever ready — k={k}, "
        f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
    )
    print(ready_message)
    return hybrid