Skip to content

Commit 3d8e506

Browse files
authored
Merge pull request #74 from amosproj/API_error_handling_branch
Api error handling branch and context awareness
2 parents ae52d2a + 94d7ac3 commit 3d8e506

4 files changed

Lines changed: 108 additions & 130 deletions

File tree

src/ChatUI_streamlit/LLMModel.py

Lines changed: 71 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
2-
# LLMModel
1+
#%%
32
import os
43
from langchain.chains import RetrievalQA
5-
64
from langchain.cache import InMemoryCache
75
from langchain.globals import set_llm_cache
86
from langchain.document_loaders.generic import GenericLoader
@@ -16,89 +14,82 @@
1614
from langchain.memory import ConversationSummaryMemory
1715
from langchain.vectorstores import FAISS
1816
from langchain.document_loaders.text import TextLoader
17+
from langchain.agents import AgentType, Tool, initialize_agent
18+
from langchain.memory import ConversationBufferMemory
19+
from langchain.agents import AgentExecutor
20+
# import faiss
21+
from langchain.vectorstores import FAISS as FAISS
22+
import faiss
1923

20-
openai_api_key = os.environ["OPENAI_API_KEY"]
24+
# Load the OpenAI API key
2125

22-
# check if the API key is loaded
26+
openai_api_key = os.environ["OPENAI_API_KEY"]
2327
assert openai_api_key is not None, "Failed to load the OpenAI API key from .env file. Please create .env file and add OPENAI_API_KEY = 'your key'"
2428

25-
26-
27-
28-
llm = ChatOpenAI(model_name='gpt-3.5-turbo',openai_api_key=openai_api_key) # Load the LLM model
29-
# set_llm_cache(InMemoryCache())
30-
31-
32-
embeddings = OpenAIEmbeddings(disallowed_special=(), openai_api_key=openai_api_key) # Load the embeddings
33-
#
34-
# # This is the root directory for the documents i want to create the RAG from
35-
# repo_path = '/Users/zainhazzouri/projects/amos2023ws05-pipeline-config-chat-ai/src/RAG'
36-
# loader = GenericLoader.from_filesystem(
37-
# repo_path,
38-
# glob="**/*",
39-
# suffixes=[".py"],
40-
# parser=LanguageParser(language=Language.PYTHON, parser_threshold=500),
41-
# )
42-
# documents = loader.load()
43-
#
44-
# python_splitter = RecursiveCharacterTextSplitter.from_language(
45-
# language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
46-
# )
47-
# texts = python_splitter.split_documents(documents)
48-
#
49-
#
50-
# db = Chroma.from_documents(texts, OpenAIEmbeddings(disallowed_special=()))
51-
# retriever = db.as_retriever(
52-
# search_type="mmr", # Also test "similarity"
53-
# search_kwargs={"k": 8},
54-
# )
55-
56-
########################################## the old version of RAG
57-
# This is the root directory for the documents i want to create the RAG from
58-
root_dir = os.path.join("..", "RAG")
59-
docs = [] # Create an empty list to store the docs
60-
61-
# Go through each folder to extract all the files
62-
for dirpath, dirnames, filenames in os.walk(root_dir):
63-
64-
# Go through each file
65-
for file in filenames:
66-
try:
67-
# Load up the file as a doc and split
68-
loader = TextLoader(os.path.join(dirpath, file), encoding='utf-8')
69-
docs.extend(loader.load_and_split())
70-
except Exception as e:
71-
pass
72-
73-
docsearch = FAISS.from_documents(docs, embeddings) # Create the FAISS index
74-
# source https://python.langchain.com/docs/integrations/vectorstores/faiss_async
75-
76-
77-
#memory = ConversationSummaryMemory(llm=llm, memory_key="chat_history", return_messages=True)
78-
# add caching to the memory
79-
80-
81-
RAG = RetrievalQA.from_chain_type(llm,chain_type="stuff" ,retriever=docsearch.as_retriever()) # the old chain for the retrieval
29+
# Initialize the language model
30+
llm = ChatOpenAI(model_name='gpt-3.5-turbo', openai_api_key=openai_api_key)
31+
32+
# Load the embeddings
33+
embeddings = OpenAIEmbeddings(disallowed_special=(), openai_api_key=openai_api_key)
34+
35+
# # Load and split documents
36+
# root_dir = '/Users/zainhazzouri/projects/amos2023ws05-pipeline-config-chat-ai/src/RAG/pipelines'
37+
# docs = []
38+
# for dirpath, dirnames, filenames in os.walk(root_dir):
39+
# for file in filenames:
40+
# try:
41+
# loader = TextLoader(os.path.join(dirpath, file), encoding='utf-8')
42+
# docs.extend(loader.load_and_split())
43+
# except Exception as e:
44+
# pass # Consider logging the exception for debugging
45+
46+
# # Create the FAISS index
47+
# docsearch = FAISS.from_documents(docs, embeddings)
48+
49+
#%%
50+
# save the vector store offline for later use
51+
# faiss.write_index(docsearch.index, '/Users/zainhazzouri/projects/amos2023ws05-pipeline-config-chat-ai/src/ChatUI_streamlit/faiss_index_file')
52+
# docsearch.save_local("/Users/zainhazzouri/projects/amos2023ws05-pipeline-config-chat-ai/src/ChatUI_streamlit/faiss_index")
53+
54+
#%%
55+
docsearch = FAISS.load_local("/Users/zainhazzouri/projects/amos2023ws05-pipeline-config-chat-ai/src/ChatUI_streamlit/faiss_index", embeddings)
56+
#%%
57+
# Initialize RetrievalQA
58+
RAG = RetrievalQA.from_chain_type(llm, chain_type="stuff", retriever=docsearch.as_retriever())
59+
60+
# Define tools
61+
tools = [
62+
Tool(
63+
name="RTDIP SDK",
64+
func=RAG.run,
65+
description="useful for when you need to answer questions about RTDIP",
66+
)
67+
]
68+
69+
# Initialize conversation memory
70+
conversation_memory = ConversationBufferMemory()
71+
72+
# Initialize Agent with conversation memory
73+
agent = initialize_agent(
74+
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, memory=conversation_memory, handle_parsing_errors=True
75+
)
76+
77+
# Set the LLM cache
8278
set_llm_cache(InMemoryCache())
8379

80+
# Function to update and retrieve conversation context
81+
def update_and_get_context(user_input, conversation_memory):
    """Record the user's message and return the conversation as one prompt string.

    Args:
        user_input: The text the user just entered.
        conversation_memory: A LangChain chat memory such as
            ``ConversationBufferMemory``. Its ``chat_memory`` history is
            appended to in place.

    Returns:
        The content of every stored message, oldest first, joined with
        newlines. The new ``user_input`` is the last line and appears once.
    """
    # ConversationBufferMemory exposes no add_user_input()/get_conversation();
    # the message history lives on its ``chat_memory`` attribute.
    history = conversation_memory.chat_memory
    history.add_user_message(user_input)
    # The history already ends with user_input, so it is not appended again.
    return "\n".join(str(message.content) for message in history.messages)
8486

85-
#RAG = ConversationalRetrievalChain.from_llm(llm,chain_type="stuff", retriever=docsearch.as_retriever()) # the new chain for the retrieval
86-
87+
# Example usage (commented out for testing)
88+
# user_input = "What's the weather like today?"
89+
# model_input = update_and_get_context(user_input, conversation_memory)
90+
# response = llm.run(model_input)
91+
# print(response)
8792

88-
##### this code for testing the model don't delete it --
93+
# Note: You can uncomment and modify the testing code as per your use case.
8994

90-
# question1 = " Hello , my name is Zain"
91-
# question2 = " what's my name?"
92-
#question3 = "I would like to use RTDIP components to read from an eventhub using ‘connection string’ as the connection string, and ‘consumer group’ as the consumer group, transform using binary to string, and edge x transformer then write to delta, return only the python code "
93-
#
94-
# result = RAG(question1)
95-
# result["answer"]
96-
# print(result["answer"])
97-
#
98-
# result = RAG(question2)
99-
# result["answer"]
100-
# print(result["answer"])
101-
#
102-
#result = RAG(question3)
103-
#result["answer"]
104-
#print(result["answer"])
95+
# %%

src/ChatUI_streamlit/app.py

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
1-
2-
# app
31
import streamlit as st
4-
import replicate
52
import os
63
import time
7-
8-
9-
10-
# App title
4+
import requests
5+
import openai
6+
7+
class InvalidAPIKeyException(Exception):
    """Raised when the OpenAI API key entered by the user fails validation."""
9+
10+
# Function to check API key validity
11+
def is_valid_api_key(key, timeout=10):
    """Check whether an OpenAI API key is accepted by the OpenAI REST API.

    Sends an authenticated GET request for a single model; the key is treated
    as valid only when the API answers with HTTP 200.

    Args:
        key: The API key to test.
        timeout: Seconds to wait for the HTTP request before giving up, so an
            unresponsive endpoint cannot block the app indefinitely.

    Returns:
        True if the request succeeded with status 200. False for a blank key,
        any other status code, or a failed request.
    """
    # A blank key can never authenticate; skip the network round trip.
    if not key or not key.strip():
        return False

    url = "https://api.openai.com/v1/models/gpt-3.5-turbo-instruct"
    headers = {"Authorization": f"Bearer {key}"}
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except (requests.RequestException, ValueError) as e:
        # RequestException covers connection, timeout and invalid-header
        # failures; ValueError covers keys that cannot be encoded into a header.
        print(f"An error occurred: {e}")
        return False
    return response.status_code == 200
20+
21+
# Initialize page configuration once
1122
if 'page_config_set' not in st.session_state:
1223
st.set_page_config(page_title="RTDIP Pipeline Chatbot")
1324
st.session_state['page_config_set'] = True
1425

15-
# Use HTML/CSS to position the title and GitHub link on the same line
26+
# HTML/CSS for title and GitHub link
1627
st.markdown(
1728
'''
1829
<div style="display: flex; justify-content: space-between; align-items: center;">
@@ -21,24 +32,21 @@
2132
</div>
2233
''', unsafe_allow_html=True)
2334

35+
# Check if the OpenAI API key is already stored in the session
36+
if 'OPENAI_API_KEY' not in st.session_state:
37+
# If not, ask the user to input it
38+
openai_api_key = st.text_input('Enter OpenAI API Key:', type='password')
39+
if openai_api_key:
40+
try:
41+
if is_valid_api_key(openai_api_key):
42+
st.session_state['OPENAI_API_KEY'] = openai_api_key
43+
os.environ['OPENAI_API_KEY'] = openai_api_key
44+
st.success('API Key stored!')
45+
else:
46+
raise InvalidAPIKeyException
47+
except InvalidAPIKeyException:
48+
st.error('Invalid OpenAI API Key. Please enter a valid key.')
2449

25-
# Replicate Credentials
26-
api_key_container = st.empty()
27-
openai_api_key = api_key_container.text_input('Enter OpenAI API Key:', type='password')
28-
29-
# Check if OpenAI API Key is entered
30-
if openai_api_key:
31-
# Store the API key in the session state
32-
st.session_state['OPENAI_API_KEY'] = openai_api_key
33-
os.environ['OPENAI_API_KEY'] = openai_api_key
34-
success_message = st.success('API Key stored!')
35-
# Hide success message, input field, and chat messages after 3 seconds
36-
time.sleep(0)
37-
success_message.empty()
38-
api_key_container.empty()
39-
else:
40-
st.warning('Invalid OpenAI API Key. Please enter a valid key.')
41-
4250
# Store LLM generated responses
4351
if "conversations" not in st.session_state.keys():
4452
st.session_state.conversations = [{"title": "Default Conversation", "messages": [{"role": "assistant", "content": "How may I assist you today?"}]}]
@@ -49,48 +57,27 @@
4957
with st.chat_message(message["role"]):
5058
st.write(message["content"])
5159

52-
def clear_chat_history():
53-
st.session_state.conversations = [{"title": "Default Conversation", "messages": [{"role": "assistant", "content": "How may I assist you today?"}]}]
54-
#st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
55-
56-
5760
# User-provided prompt
5861
if 'OPENAI_API_KEY' in st.session_state and st.session_state['OPENAI_API_KEY']:
5962
from LLMModel import RAG as RAG
6063
if prompt := st.chat_input():
61-
# Get the conversation context
6264
conversation = st.session_state.conversations[-1]
63-
64-
# Use the entire conversation context as input
65-
#context = "\n".join([message["content"] for message in conversation["messages"]])
66-
67-
# Add the user's input to the conversation
65+
context = "\n".join([message["content"] for message in conversation["messages"]])
6866
conversation["messages"].append({"role": "user", "content": prompt})
69-
70-
# Display user's input in the chat
7167
with st.chat_message("user"):
7268
st.write(prompt)
73-
74-
# Generate a new response considering the entire conversation context
7569
with st.chat_message("assistant"):
76-
start_time = time.time() # to calculate the time taken to generate the response
70+
start_time = time.time()
7771
with st.spinner("Generating..."):
78-
response = RAG.run(prompt)
79-
end_time = time.time() # to calculate the time taken to generate the response
72+
response = RAG.run(context + "\n" + prompt)
73+
end_time = time.time()
8074
placeholder = st.empty()
8175
full_response = ''
8276
for item in response:
8377
full_response += item
8478
placeholder.markdown(full_response)
8579
placeholder.markdown(full_response)
86-
87-
# Calculate the time taken
8880
response_time = end_time - start_time
8981
st.write(f"Response generated in {response_time:.2f} seconds.")
90-
91-
# Add the assistant's response to the conversation
9282
message = {"role": "assistant", "content": full_response}
9383
conversation["messages"].append(message)
94-
95-
96-
1.28 MB
Binary file not shown.
560 KB
Binary file not shown.

0 commit comments

Comments
 (0)