Commit 5b0f981

Adds conversational RAG example to contrib
This shows how you can incorporate chat history into a RAG setting. The primary purpose is to use the prior conversation to appropriately contextualize what the new question is asking.
1 parent da01dc8 commit 5b0f981

7 files changed: +329 -0 lines changed

@@ -0,0 +1,86 @@
# Purpose of this module

This module shows a conversational retrieval augmented generation (RAG) example using
Hamilton. It shows you how you might structure your code with Hamilton to
create a RAG pipeline that takes the conversation history into account.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store and the OpenAI LLM provider.
The implementation of the FAISS vector store uses the LangChain wrapper around it.
That's because this was the simplest way to get this example up without requiring
someone to host and manage a proper vector store.

The "smarts" in this pipeline are that it takes a conversation and a question,
and then rewrites the question based on the conversation to be "standalone". For example,
given a chat history that mentions Stefan, the follow-up "where did he work?" would be
rewritten to something like "Where did Stefan work?". That way the standalone question
can be used for the vector store query, as well as being a more specific question for
the LLM given the found context.

## Example Usage

```python
# import the module that defines the Hamilton functions; as in this module's
# __main__ block, it is assumed here to be importable as `conversational_rag`.
import __init__ as conversational_rag

from hamilton import driver
from hamilton import lifecycle

dr = (
    driver.Builder()
    .with_modules(conversational_rag)
    .with_config({})
    # this prints the inputs and outputs of each step.
    .with_adapters(lifecycle.PrintLn(verbosity=2))
    .build()
)
# no chat history -- nothing to rewrite
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
        "chat_history": [],
    },
)
print(result)

# this will now reword the question to then be
# used to query the vector store and the final LLM call.
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did he work?",
        "chat_history": [
            "Human: Who wrote this example?",
            "AI: Stefan",
        ],
    },
)
print(result)
```

# How to extend this module
What you'd most likely want to do is:

1. Change the vector store (and how embeddings are generated).
2. Change the LLM provider.
3. Change the context and prompt.

With (1) you can import any vector store/library that you want. You should draw out
the process you would like, and that should then map to Hamilton functions.
With (2) you can import any LLM provider that you want; just use `@config.when` if you
want to switch between multiple providers (see the sketch below).
With (3) you can add more functions that create parts of the prompt.
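
Below is a minimal sketch of using `@config.when` to switch the `llm_client` implementation.
The `llm_provider` config key, the `local` option, and the base URL are illustrative
assumptions, not part of this module:

```python
import openai

from hamilton.function_modifiers import config


@config.when(llm_provider="openai")
def llm_client__openai() -> openai.OpenAI:
    """Used when the driver is built with .with_config({"llm_provider": "openai"})."""
    return openai.OpenAI()


@config.when(llm_provider="local")
def llm_client__local() -> openai.OpenAI:
    """Example alternative: an OpenAI-compatible local server (URL is a placeholder)."""
    return openai.OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
```

Hamilton resolves the `__openai`/`__local` suffixes to a single `llm_client` node based on
the value passed to `.with_config(...)`, so downstream functions don't need to change.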

# Configuration Options
There is no configuration needed for this module.

# Limitations

You need to have the `OPENAI_API_KEY` set in your environment.
It should be accessible from your code via `os.environ["OPENAI_API_KEY"]`.

The code does not check context length, so the LLM call may fail if the context passed
is too long for the model you send it to.
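
For illustration, here is a minimal guard you could run before executing the driver. The
12,000-character cutoff is an arbitrary assumption (a rough proxy, not a real token count);
use a tokenizer such as `tiktoken` if you need accuracy:

```python
import os

# fail fast if the API key is missing
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Set OPENAI_API_KEY before running this example.")


def roughly_fits(prompt: str, max_chars: int = 12_000) -> bool:
    """Very rough context-length check based on character count."""
    return len(prompt) <= max_chars
```
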
@@ -0,0 +1,161 @@
import logging

logger = logging.getLogger(__name__)

from hamilton import contrib

with contrib.catch_import_errors(__name__, __file__, logger):
    import openai

    # use langchain implementation of vector store
    from langchain_community.vectorstores import FAISS
    from langchain_core.vectorstores import VectorStoreRetriever

    # use langchain embedding wrapper with vector store
    from langchain_openai import OpenAIEmbeddings


def standalone_question_prompt(chat_history: list[str], question: str) -> str:
    """Prompt for getting a standalone question given the chat history.

    This is then used to query the vector store with.

    :param chat_history: the history of the conversation.
    :param question: the current user question.
    :return: prompt to use.
    """
    chat_history_str = "\n".join(chat_history)
    return (
        "Given the following conversation and a follow up question, "
        "rephrase the follow up question to be a standalone question, "
        "in its original language.\n\n"
        "Chat History:\n"
        "{chat_history}\n"
        "Follow Up Input: {question}\n"
        "Standalone question:"
    ).format(chat_history=chat_history_str, question=question)


def standalone_question(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
    """Asks the LLM to create a standalone question from the prompt.

    :param standalone_question_prompt: the prompt with context.
    :param llm_client: the LLM client to use.
    :return: the standalone question.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": standalone_question_prompt}],
    )
    return response.choices[0].message.content


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """A vector store. This function populates and creates one for querying.

    This is a cute function encapsulating the creation of a vector store. In real life
    you could replace this with a more complex function, or one that returns a
    client to an existing vector store.

    :param input_texts: the input "text", i.e. documents, to be stored.
    :return: a vector store that can be queried against.
    """
    vectorstore = FAISS.from_texts(input_texts, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    return retriever


def context(standalone_question: str, vector_store: VectorStoreRetriever, top_k: int = 5) -> str:
    """This function returns the string context to put into a prompt for the RAG model.

    It queries the provided vector store for information.

    :param standalone_question: the question to use to search the vector store against.
    :param vector_store: the vector store to search against.
    :param top_k: the number of results to return.
    :return: a string with all the context.
    """
    # note: depending on the langchain version, extra kwargs passed to `.invoke()` may not
    # take effect; setting `search_kwargs={"k": top_k}` on `.as_retriever()` is the more
    # reliable way to control how many documents are returned.
    _results = vector_store.invoke(standalone_question, search_kwargs={"k": top_k})
    return "\n\n".join(map(lambda d: d.page_content, _results))


def answer_prompt(context: str, standalone_question: str) -> str:
    """Creates a prompt that includes the question and context for the LLM to make sense of.

    :param context: the information context to use.
    :param standalone_question: the user question the LLM should answer.
    :return: the full prompt.
    """
    template = (
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )

    return template.format(context=context, question=standalone_question)


def llm_client() -> openai.OpenAI:
    """The LLM client to use for the RAG model."""
    return openai.OpenAI()


def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -> str:
    """Creates the RAG response from the LLM model for the given prompt.

    :param answer_prompt: the prompt to send to the LLM.
    :param llm_client: the LLM client to use.
    :return: the response from the LLM.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": answer_prompt}],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    import __init__ as conversational_rag

    from hamilton import driver, lifecycle

    dr = (
        driver.Builder()
        .with_modules(conversational_rag)
        .with_config({})
        # this prints the inputs and outputs of each step.
        .with_adapters(lifecycle.PrintLn(verbosity=2))
        .build()
    )
    dr.display_all_functions("dag.png")

    # no chat history -- the question is not reworded
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did stefan work?",
                "chat_history": [],
            },
        )
    )

    # this will now reword the question to then be
    # used to query the vector store.
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did he work?",
                "chat_history": ["Human: Who wrote this example?", "AI: Stefan"],
            },
        )
    )

@@ -0,0 +1,70 @@
from operator import itemgetter

from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

vectorstore = FAISS.from_texts(["harrison worked at kensho"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()

_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


_inputs = RunnableParallel(
    standalone_question=RunnablePassthrough.assign(
        chat_history=lambda x: get_buffer_string(x["chat_history"])
    )
    | CONDENSE_QUESTION_PROMPT
    | ChatOpenAI(temperature=0)
    | StrOutputParser(),
)
_context = {
    "context": itemgetter("standalone_question") | retriever | _combine_documents,
    "question": lambda x: x["standalone_question"],
}
conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()

print(
    conversational_qa_chain.invoke(
        {
            "question": "where did harrison work?",
            "chat_history": [],
        }
    )
)
print(
    conversational_qa_chain.invoke(
        {
            "question": "where did he work?",
            "chat_history": [
                HumanMessage(content="Who wrote this notebook?"),
                AIMessage(content="Harrison"),
            ],
        }
    )
)

@@ -0,0 +1,4 @@
faiss-cpu
langchain
langchain-community
langchain-openai

@@ -0,0 +1,7 @@
{
    "schema": "1.0",
    "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"],
    "secondary_tags": {
        "language": "English"
    }
}

@@ -0,0 +1 @@
{"description": "Default", "name": "default", "config": {}}
