Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions sample-applications/chat-question-and-answer/app/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
import openlit

from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from sqlalchemy import text
from langchain_core.documents import Document
logging.basicConfig(level=logging.INFO)

# Check if OTLP endpoint is set in environment variables
Expand Down Expand Up @@ -58,6 +61,8 @@
EMBEDDING_ENDPOINT_URL = os.getenv("EMBEDDING_ENDPOINT_URL", "http://localhost:6006")
COLLECTION_NAME = os.getenv("INDEX_NAME")
FETCH_K = int(os.getenv("FETCH_K", "1"))
DENSE_WEIGHT = float(os.getenv("DENSE_WEIGHT", "0.5"))
SPARSE_WEIGHT = float(os.getenv("SPARSE_WEIGHT", "0.5"))

engine = create_async_engine(PG_CONNECTION_STRING)

Expand Down Expand Up @@ -86,6 +91,37 @@
search_kwargs={"k": FETCH_K, "fetch_k": FETCH_K * 3},
)

# Module-level caches for the sparse (BM25) and hybrid retrievers.
# Both start as None and are populated lazily by init_bm25() on first use;
# until then, callers fall back to the dense `retriever` alone.
bm25_retriever = None
ensemble_retriever = None

async def init_bm25():
    """Lazily build the hybrid (BM25 + dense) ensemble retriever.

    Loads every document in COLLECTION_NAME from the pgvector tables,
    builds an in-memory BM25Retriever over them, and combines it with the
    existing dense `retriever` into a weighted EnsembleRetriever stored in
    the module-level `ensemble_retriever` global.

    Idempotent: returns immediately once `ensemble_retriever` is set.
    Best-effort: on any failure both globals stay None so callers can fall
    back to the dense retriever alone.

    NOTE(review): there is no lock around the None check, so concurrent
    first calls may each run the query; presumably harmless (last writer
    wins with equivalent data) — confirm acceptable under load.
    """
    global bm25_retriever, ensemble_retriever
    if ensemble_retriever is not None:
        return
    try:
        async with engine.begin() as conn:
            # Pull raw text + metadata for this collection so BM25 indexes
            # the same corpus the dense retriever searches.
            query = text(
                "SELECT document, cmetadata FROM langchain_pg_embedding "
                "JOIN langchain_pg_collection ON langchain_pg_embedding.collection_id = langchain_pg_collection.uuid "
                "WHERE langchain_pg_collection.name = :name"
            )
            result = await conn.execute(query, {"name": COLLECTION_NAME})
            rows = result.all()

        if rows:
            docs = [Document(page_content=row[0], metadata=row[1] or {}) for row in rows]
            bm25_retriever = BM25Retriever.from_documents(docs)
            bm25_retriever.k = FETCH_K  # match the dense retriever's top-k
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, retriever],
                weights=[SPARSE_WEIGHT, DENSE_WEIGHT],
            )
            logging.info("Ensemble retriever with BM25 initialized successfully.")
        else:
            logging.warning("No documents found for BM25 initialization.")
    except Exception:
        # logging.exception records the full traceback, which the previous
        # logging.error(f"...{str(e)}") form discarded.
        logging.exception("Failed to initialize BM25 retriever")

# Define our prompt
template = """
Use the following pieces of context from retrieved
Expand Down Expand Up @@ -150,7 +186,13 @@ async def context_retriever_fn(chain_inputs: dict):
if not question:
return {} # to keep shape consistent

retrieved_docs = await retriever.aget_relevant_documents(question)
if ensemble_retriever is None:
await init_bm25()

if ensemble_retriever:
retrieved_docs = await ensemble_retriever.ainvoke(question)
else:
retrieved_docs = await retriever.aget_relevant_documents(question)
return retrieved_docs # context: list[Document]

# Format the context in a readable way
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ services:
- EMBEDDING_ENDPOINT_URL=${EMBEDDING_ENDPOINT_URL}
- INDEX_NAME=${INDEX_NAME}
- FETCH_K=${FETCH_K}
- DENSE_WEIGHT=${DENSE_WEIGHT}
- SPARSE_WEIGHT=${SPARSE_WEIGHT}
- EMBEDDING_MODEL=${EMBEDDING_MODEL_NAME}
- PG_CONNECTION_STRING=${PG_CONNECTION_STRING}
- HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
Expand Down
Loading