
Commit 33c9e36

Adds simple RAG example to contrib (#673)
This is a basic example showing the mechanics of a RAG pipeline. It uses an in-memory FAISS vector store for similarity search.
1 parent 3ea0068 commit 33c9e36

File tree

6 files changed, +197 -0 lines changed

@@ -0,0 +1,80 @@
# Purpose of this module

This module shows a simple retrieval augmented generation (RAG) example using
Hamilton. It shows you how you might structure your code with Hamilton to
create a simple RAG pipeline.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store and the OpenAI LLM provider.
The FAISS vector store is used via its LangChain wrapper,
because that was the simplest way to get this example up without requiring
someone to host and manage a proper vector store.

## Example Usage

### Inputs
These are the defined inputs:

- *input_texts*: A list of strings. Each string will be encoded into a vector and stored in the vector store.
- *question*: A string. This is the question you want to ask the LLM; it is also used to query the vector store for context.
- *top_k*: An integer. This is the number of vectors to retrieve from the vector store. Defaults to 5.

### Overrides
With Hamilton you can easily override a function and provide a value for it directly. For example, if you're
iterating you might just want to override these two values before modifying the functions (see the sketch after this list):

- *context*: if you want to skip going to the vector store and provide the context directly, provide this override.
- *rag_prompt*: if you want to provide the prompt to pass to the LLM yourself, pass it in as an override.
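
For example, here's a minimal sketch of passing an override (it reuses the driver `dr` built in the Execution section below; the override value is illustrative, and since `context` is overridden the vector store should not need to run):

```python
result = dr.execute(
    ["rag_response"],
    inputs={"question": "where did stefan work?"},
    # skip retrieval entirely and supply the context directly
    overrides={"context": "stefan worked at Stitch Fix"},
)
```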

### Execution
You can ask to get back the result of any intermediate function by providing its name in the `execute` call.
Here we just ask for the final result, but if you wanted to, you could ask for the outputs of any of the functions, which
you can then introspect or log for debugging/evaluation purposes. Note that if you want more platform integrations,
you can add adapters that will do this automatically for you, e.g. the `PrintLn` adapter used here.

```python
# import the module that defines the pipeline (shown further below)
import faiss_rag

from hamilton import driver
from hamilton import lifecycle

dr = (
    driver.Builder()
    .with_modules(faiss_rag)
    .with_config({})
    # this prints the inputs and outputs of each step.
    .with_adapters(lifecycle.PrintLn(verbosity=2))
    .build()
)
result = dr.execute(
    ["rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
    },
)
print(result)
```
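
If you want to inspect intermediate values, just add their function names to the list of requested outputs; `execute` returns a dictionary keyed by those names. A minimal sketch reusing the driver `dr` from above:

```python
results = dr.execute(
    ["context", "rag_prompt", "rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
    },
)
print(results["rag_prompt"])    # the prompt that was sent to the LLM
print(results["rag_response"])  # the final answer
```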

# How to extend this module
What you'd most likely want to do is:

1. Change the vector store (and how embeddings are generated).
2. Change the LLM provider.
3. Change the context and prompt.

With (1) you can import any vector store/library that you want. You should draw out
the process you would like, and that should then map to Hamilton functions (sketched below).
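
For instance, here's a minimal sketch of (1) that swaps FAISS for Chroma via its LangChain wrapper (assuming you add `chromadb` to the requirements); only `vector_store` needs to change:

```python
from langchain_community.vectorstores import Chroma
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """Populates an in-memory Chroma vector store instead of FAISS."""
    vectorstore = Chroma.from_texts(input_texts, embedding=OpenAIEmbeddings())
    return vectorstore.as_retriever()
```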
With (2) you can import any LLM provider that you want; just use `@config.when` if you
want to switch between multiple providers, as sketched below.
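
For example, a minimal sketch of (2) using `@config.when` (the second provider here is just an illustrative placeholder):

```python
import openai

from hamilton.function_modifiers import config


@config.when(llm_provider="openai")
def llm_client__openai() -> openai.OpenAI:
    """LLM client used when the driver is built with llm_provider='openai'."""
    return openai.OpenAI()


@config.when(llm_provider="other")
def llm_client__other() -> object:
    """Placeholder for another provider's client -- swap in the real client here."""
    raise NotImplementedError("Instantiate your other provider's client here.")
```

You would then pick the implementation by building the driver with `.with_config({"llm_provider": "openai"})`.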
With (3) you can add more functions that create parts of the prompt.

# Configuration Options
There is no configuration needed for this module.

# Limitations

You need to have the `OPENAI_API_KEY` set in your environment.
It should be accessible from your code via `os.environ["OPENAI_API_KEY"]`.

The code does not check the context length, so it may fail if the context passed is too long
for the LLM you send it to.
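
If you need to guard against that, one option is to count tokens before calling the LLM. A minimal sketch (assuming you add `tiktoken` as a dependency; the function name is just illustrative):

```python
import tiktoken


def context_token_count(context: str) -> int:
    """Counts how many tokens the retrieved context will consume,
    so you can raise or truncate before calling the LLM."""
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoding.encode(context))
```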
@@ -0,0 +1,105 @@
import logging

logger = logging.getLogger(__name__)

from hamilton import contrib

with contrib.catch_import_errors(__name__, __file__, logger):
    import openai

    # use langchain implementation of vector store
    from langchain_community.vectorstores import FAISS
    from langchain_core.vectorstores import VectorStoreRetriever

    # use langchain embedding wrapper with vector store
    from langchain_openai import OpenAIEmbeddings


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """A Vector store. This function populates and creates one for querying.

    This is a cute function encapsulating the creation of a vector store. In real life
    you could replace this with a more complex function, or one that returns a
    client to an existing vector store.

    :param input_texts: the input "text" i.e. documents to be stored.
    :return: a vector store that can be queried against.
    """
    vectorstore = FAISS.from_texts(input_texts, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    return retriever


def context(question: str, vector_store: VectorStoreRetriever, top_k: int = 5) -> str:
    """This function returns the string context to put into a prompt for the RAG model.

    :param question: the user question to use to search the vector store against.
    :param vector_store: the vector store to search against.
    :param top_k: the number of results to return.
    :return: a string with all the context.
    """
    _results = vector_store.invoke(question, search_kwargs={"k": top_k})
    return "\n".join(map(lambda d: d.page_content, _results))


def rag_prompt(context: str, question: str) -> str:
    """Creates a prompt that includes the question and context for the LLM to make sense of.

    :param context: the information context to use.
    :param question: the user question the LLM should answer.
    :return: the full prompt.
    """
    template = (
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )

    return template.format(context=context, question=question)


def llm_client() -> openai.OpenAI:
    """The LLM client to use for the RAG model."""
    return openai.OpenAI()


def rag_response(rag_prompt: str, llm_client: openai.OpenAI) -> str:
    """Creates the RAG response from the LLM model for the given prompt.

    :param rag_prompt: the prompt to send to the LLM.
    :param llm_client: the LLM client to use.
    :return: the response from the LLM.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": rag_prompt}],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    import __init__ as hamilton_faiss_rag

    from hamilton import driver, lifecycle

    dr = (
        driver.Builder()
        .with_modules(hamilton_faiss_rag)
        .with_config({})
        # this prints the inputs and outputs of each step.
        .with_adapters(lifecycle.PrintLn(verbosity=2))
        .build()
    )
    dr.display_all_functions("dag.png")
    print(
        dr.execute(
            ["rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did stefan work?",
            },
        )
    )
@@ -0,0 +1,4 @@
faiss-cpu
langchain
langchain-community
langchain-openai
@@ -0,0 +1,7 @@
{
    "schema": "1.0",
    "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"],
    "secondary_tags": {
        "language": "English"
    }
}
@@ -0,0 +1 @@
{"description": "Default", "name": "default", "config": {}}

0 commit comments
