Commit 232176a

Contrib/conversational rag example (#674)
Adds a conversational RAG example to contrib. This shows how you can incorporate chat history into a RAG setting. The primary purpose is to help appropriately contextualize what the new question is asking.
1 parent 33c9e36 commit 232176a

File tree

7 files changed: +351 -0 lines changed

@@ -0,0 +1,108 @@
# Purpose of this module

This module shows a conversational retrieval augmented generation (RAG) example using
Hamilton. It shows you how you might structure your code with Hamilton to
create a RAG pipeline that takes the conversation history into account.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store, along with the OpenAI LLM provider.
The implementation of the FAISS vector store uses the LangChain wrapper around it.
That's because this was the simplest way to get this example up without requiring
someone to host and manage a proper vector store.

The "smarts" in this pipeline are that it takes a conversation and a question,
and rewrites the question based on the conversation to be "standalone". That way
the standalone question can be used for the vector store query, as well as being a more
specific question for the LLM given the retrieved context.

## Example Usage

### Inputs
These are the defined inputs you can provide.

- *input_texts*: A list of strings. Each string will be encoded into a vector and stored in the vector store.
- *question*: A string. This is the question you want to ask the LLM; it is also used to query the vector store, which provides context.
- *chat_history*: A list of strings. Each string is a line of conversation, prefixed with "Human" or "AI" to indicate who said it. They should alternate.
- *top_k*: An integer. This is the number of vectors to retrieve from the vector store. Defaults to 5.

### Overrides
With Hamilton you can easily override a function and provide a value for it. For example, if you're
iterating, you might just want to override these values before modifying the functions (see the sketch after this list):

- *context*: if you want to skip going to the vector store and provide the context directly, you can do so by providing this override.
- *standalone_question*: if you want to skip the rewording of the question, you can provide the standalone question directly.
- *answer_prompt*: if you want to provide the prompt to pass to the LLM yourself, pass it in as an override.
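
For instance, here is a minimal sketch of passing overrides. The values are illustrative, and `dr` is the driver built in the Execution section below:

```python
# Sketch: skip the vector store and the question rewording by overriding
# `context` and `standalone_question` directly. The values are illustrative.
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": ["stefan worked at Stitch Fix"],
        "question": "where did he work?",
        "chat_history": ["Human: Who wrote this example?", "AI: Stefan"],
    },
    overrides={
        "context": "stefan worked at Stitch Fix",
        "standalone_question": "where did stefan work?",
    },
)
```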

### Execution
You can get back the result of any intermediate function by including its name in the `execute` call.
Here we just ask for the final result, but if you wanted to, you could ask for the outputs of any of the functions, which
you can then introspect or log for debugging/evaluation purposes. Note that if you want more platform integrations,
you can add adapters that will do this automatically for you, e.g. the `PrintLn` adapter used here.

```python
# import the module -- make sure it is importable as `conversational_rag`
from hamilton import driver
from hamilton import lifecycle

dr = (
    driver.Builder()
    .with_modules(conversational_rag)
    .with_config({})
    # this prints the inputs and outputs of each step.
    .with_adapters(lifecycle.PrintLn(verbosity=2))
    .build()
)
# no chat history -- nothing to rewrite
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
        "chat_history": [],
    },
)
print(result)

# this will now reword the question to then be
# used to query the vector store and the final LLM call.
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did he work?",
        "chat_history": [
            "Human: Who wrote this example?",
            "AI: Stefan",
        ],
    },
)
print(result)
```
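
To illustrate requesting intermediate results, you could ask for other functions' outputs in the same call -- a sketch reusing the driver built above:

```python
# Sketch: request intermediate results alongside the final answer.
results = dr.execute(
    ["standalone_question", "context", "conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did he work?",
        "chat_history": ["Human: Who wrote this example?", "AI: Stefan"],
    },
)
print(results["standalone_question"])  # the reworded, standalone question
print(results["context"])  # what was retrieved from the vector store
print(results["conversational_rag_response"])  # the final answer
```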

# How to extend this module
What you'd most likely want to do is:

1. Change the vector store (and how embeddings are generated).
2. Change the LLM provider.
3. Change the context and prompt.

With (1) you can import any vector store/library that you want. You should draw out
the process you would like, and that should then map to Hamilton functions.
With (2) you can import any LLM provider that you want; just use `@config.when` if you
want to switch between multiple providers (see the sketch below).
With (3) you can add more functions that create parts of the prompt.
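
For (2), a minimal sketch of switching providers with `@config.when` might look like the following. The `provider` config key is an assumption here -- you would pass it via `.with_config({"provider": ...})` when building the driver:

```python
# Sketch: swap LLM providers via @config.when. The `provider` config key is
# hypothetical -- pass it with .with_config({"provider": "openai"}) or similar.
import openai

from hamilton.function_modifiers import config


@config.when(provider="openai")
def llm_client__openai() -> openai.OpenAI:
    """LLM client used when the driver is configured with provider="openai"."""
    return openai.OpenAI()


@config.when(provider="azure")
def llm_client__azure() -> openai.AzureOpenAI:
    """LLM client used when the driver is configured with provider="azure".

    AzureOpenAI reads AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT from the environment.
    """
    return openai.AzureOpenAI(api_version="2024-02-01")
```

Both functions resolve to the `llm_client` node that the rest of the module already depends on, so nothing else needs to change.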

# Configuration Options
There is no configuration needed for this module.

# Limitations

You need to have the OPENAI_API_KEY in your environment.
It should be accessible from your code by doing `os.environ["OPENAI_API_KEY"]`.

The code does not check the context length, so it may fail if the context passed is too long
for the LLM you send it to.
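
If you need to guard against that, one option is to truncate the retrieved context to a token budget before building the prompt. A minimal sketch, assuming `tiktoken` is installed and using an arbitrary 3,000-token budget (not part of this module):

```python
# Sketch: cap the retrieved context at a token budget before prompting.
# The 3000-token budget is an arbitrary assumption, not part of this module.
import tiktoken


def truncated_context(context: str, model: str = "gpt-3.5-turbo", budget: int = 3000) -> str:
    """Returns `context` cut down to at most `budget` tokens for `model`."""
    enc = tiktoken.encoding_for_model(model)
    tokens = enc.encode(context)
    if len(tokens) <= budget:
        return context
    return enc.decode(tokens[:budget])
```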
@@ -0,0 +1,161 @@
import logging

logger = logging.getLogger(__name__)

from hamilton import contrib

with contrib.catch_import_errors(__name__, __file__, logger):
    import openai

    # use langchain implementation of vector store
    from langchain_community.vectorstores import FAISS
    from langchain_core.vectorstores import VectorStoreRetriever

    # use langchain embedding wrapper with vector store
    from langchain_openai import OpenAIEmbeddings


def standalone_question_prompt(chat_history: list[str], question: str) -> str:
    """Prompt for getting a standalone question given the chat history.

    This is then used to query the vector store with.

    :param chat_history: the history of the conversation.
    :param question: the current user question.
    :return: prompt to use.
    """
    chat_history_str = "\n".join(chat_history)
    return (
        "Given the following conversation and a follow up question, "
        "rephrase the follow up question to be a standalone question, "
        "in its original language.\n\n"
        "Chat History:\n"
        "{chat_history}\n"
        "Follow Up Input: {question}\n"
        "Standalone question:"
    ).format(chat_history=chat_history_str, question=question)


def standalone_question(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
    """Asks the LLM to create a standalone question from the prompt.

    :param standalone_question_prompt: the prompt with context.
    :param llm_client: the LLM client to use.
    :return: the standalone question.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": standalone_question_prompt}],
    )
    return response.choices[0].message.content


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """A vector store. This function populates and creates one for querying.

    This is a cute function encapsulating the creation of a vector store. In real life
    you could replace this with a more complex function, or one that returns a
    client to an existing vector store.

    :param input_texts: the input "text" i.e. documents to be stored.
    :return: a vector store that can be queried against.
    """
    vectorstore = FAISS.from_texts(input_texts, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    return retriever


def context(standalone_question: str, vector_store: VectorStoreRetriever, top_k: int = 5) -> str:
    """This function returns the string context to put into a prompt for the RAG model.

    It queries the provided vector store for information.

    :param standalone_question: the question to use to search the vector store against.
    :param vector_store: the vector store to search against.
    :param top_k: the number of results to return.
    :return: a string with all the context.
    """
    _results = vector_store.invoke(standalone_question, search_kwargs={"k": top_k})
    return "\n\n".join(map(lambda d: d.page_content, _results))


def answer_prompt(context: str, standalone_question: str) -> str:
    """Creates a prompt that includes the question and context for the LLM to make sense of.

    :param context: the information context to use.
    :param standalone_question: the user question the LLM should answer.
    :return: the full prompt.
    """
    template = (
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )

    return template.format(context=context, question=standalone_question)


def llm_client() -> openai.OpenAI:
    """The LLM client to use for the RAG model."""
    return openai.OpenAI()


def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -> str:
    """Creates the RAG response from the LLM model for the given prompt.

    :param answer_prompt: the prompt to send to the LLM.
    :param llm_client: the LLM client to use.
    :return: the response from the LLM.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": answer_prompt}],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    import __init__ as conversational_rag

    from hamilton import driver, lifecycle

    dr = (
        driver.Builder()
        .with_modules(conversational_rag)
        .with_config({})
        # this prints the inputs and outputs of each step.
        .with_adapters(lifecycle.PrintLn(verbosity=2))
        .build()
    )
    dr.display_all_functions("dag.png")

    # shows no question is reworded
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did stefan work?",
                "chat_history": [],
            },
        )
    )

    # this will now reword the question to then be
    # used to query the vector store.
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did he work?",
                "chat_history": ["Human: Who wrote this example?", "AI: Stefan"],
            },
        )
    )
@@ -0,0 +1,70 @@
from operator import itemgetter

from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

vectorstore = FAISS.from_texts(["harrison worked at kensho"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()

_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


_inputs = RunnableParallel(
    standalone_question=RunnablePassthrough.assign(
        chat_history=lambda x: get_buffer_string(x["chat_history"])
    )
    | CONDENSE_QUESTION_PROMPT
    | ChatOpenAI(temperature=0)
    | StrOutputParser(),
)
_context = {
    "context": itemgetter("standalone_question") | retriever | _combine_documents,
    "question": lambda x: x["standalone_question"],
}
conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()

print(
    conversational_qa_chain.invoke(
        {
            "question": "where did harrison work?",
            "chat_history": [],
        }
    )
)
print(
    conversational_qa_chain.invoke(
        {
            "question": "where did he work?",
            "chat_history": [
                HumanMessage(content="Who wrote this notebook?"),
                AIMessage(content="Harrison"),
            ],
        }
    )
)
@@ -0,0 +1,4 @@
faiss-cpu
langchain
langchain-community
langchain-openai
@@ -0,0 +1,7 @@
{
    "schema": "1.0",
    "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"],
    "secondary_tags": {
        "language": "English"
    }
}
@@ -0,0 +1 @@
{"description": "Default", "name": "default", "config": {}}
