Skip to content

Commit 3a0a680

Browse files
committed
Fix streamlit app + add cloud option
1 parent 0745057 commit 3a0a680

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

python/rubin/rag/app.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,21 @@
2828
from pathlib import Path
2929

3030
import streamlit as st
31-
from chatbot import configure_retriever, create_qa_chain, handle_user_input
3231
from dotenv import load_dotenv
3332
from langchain_community.chat_message_histories import (
3433
StreamlitChatMessageHistory,
3534
)
36-
from layout import setup_header_and_footer, setup_landing_page, setup_sidebar
35+
36+
from rubin.rag.chatbot import (
37+
configure_retriever_cloud,
38+
create_qa_chain,
39+
handle_user_input,
40+
)
41+
from rubin.rag.layout import (
42+
setup_header_and_footer,
43+
setup_landing_page,
44+
setup_sidebar,
45+
)
3746

3847
# Load environment variables from .env file
3948
load_dotenv()
@@ -56,7 +65,7 @@
5665
st.session_state.message_sent = False
5766

5867
# Configure the Weaviate retriever and QA chain
59-
retriever = configure_retriever()
68+
retriever = configure_retriever_cloud()
6069
qa_chain = create_qa_chain(retriever)
6170

6271
# Enable dynamic filtering based on user input

python/rubin/rag/chatbot.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
import streamlit as st
3030
import weaviate
31-
from custom_weaviate_vector_store import CustomWeaviateVectorStore
3231
from langchain.chains import create_retrieval_chain
3332
from langchain.chains.combine_documents import create_stuff_documents_chain
3433
from langchain.prompts.chat import (
@@ -40,18 +39,21 @@
4039
StreamlitChatMessageHistory,
4140
)
4241
from langchain_core.prompts import MessagesPlaceholder
42+
from langchain_core.vectorstores.base import VectorStoreRetriever
4343
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
44-
from streamlit_callback import get_streamlit_cb
4544
from weaviate.classes.init import Auth
4645

46+
from .custom_weaviate_vector_store import CustomWeaviateVectorStore
47+
from .streamlit_callback import get_streamlit_cb
48+
4749

4850
def submit_text() -> None:
4951
"""Submit the user input."""
5052
st.session_state.message_sent = True
5153

5254

5355
@st.cache_resource(ttl="1h")
54-
def configure_retriever() -> CustomWeaviateVectorStore:
56+
def configure_retriever() -> VectorStoreRetriever:
5557
"""Configure the Weaviate retriever."""
5658
openai_api_key = os.getenv("OPENAI_API_KEY")
5759
weaviate_api_key = os.getenv("WEAVIATE_API_KEY")
@@ -93,8 +95,43 @@ def configure_retriever() -> CustomWeaviateVectorStore:
9395
)
9496

9597

98+
@st.cache_resource(ttl="1h")
99+
def configure_retriever_cloud() -> VectorStoreRetriever:
100+
"""Configure the Weaviate retriever."""
101+
openai_api_key = os.getenv("OPENAI_API_KEY_PAID")
102+
weaviate_api_key = os.getenv("WEAVIATE_API_KEY_CLOUD")
103+
weaviate_url = os.getenv("WEAVIATE_URL")
104+
105+
if openai_api_key is None:
106+
raise ValueError("OPENAI_API_KEY environment variable is not set")
107+
if weaviate_api_key is None:
108+
raise ValueError("WEAVIATE_API_KEY environment variable is not set")
109+
if weaviate_url is None:
110+
raise ValueError("WEAVIATE_URL environment variable is not set")
111+
112+
client = weaviate.connect_to_weaviate_cloud(
113+
cluster_url=weaviate_url,
114+
auth_credentials=Auth.api_key(
115+
weaviate_api_key
116+
), # The API key to use for authentication
117+
headers={"X-OpenAI-Api-Key": openai_api_key},
118+
skip_init_checks=True,
119+
)
120+
121+
return CustomWeaviateVectorStore(
122+
client=client,
123+
index_name="LangChain_cloudtest1",
124+
text_key="text",
125+
embedding=OpenAIEmbeddings(),
126+
attributes=["source", "source_key"],
127+
).as_retriever(
128+
search_type="similarity",
129+
search_kwargs={"k": 6, "return_metadata": ["score"]},
130+
)
131+
132+
96133
def create_qa_chain(
97-
retriever: CustomWeaviateVectorStore,
134+
retriever: VectorStoreRetriever,
98135
) -> ChatPromptTemplate:
99136
"""Create a QA chain for the chatbot."""
100137
# Setup ChatOpenAI (Language Model)
@@ -138,8 +175,8 @@ def handle_user_input(
138175
# Define avatars for user and assistant messages
139176
avatars = {"human": "user", "ai": "assistant"}
140177
avatar_images = {
141-
"human": "./static/user_avatar.png",
142-
"ai": "./static/rubin_avatar_bw.png",
178+
"human": "../../../static/user_avatar.png",
179+
"ai": "../../../static/rubin_avatar_bw.png",
143180
}
144181

145182
for msg in msgs.messages:

python/rubin/rag/layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def setup_landing_page() -> None:
5151
with st.container():
5252
# Add logo (Make sure the logo is in your
5353
# working directory or provide the full path)
54-
st.image("./static/rubin_avatar_bw.png", clamp=True)
54+
st.image("../../../static/rubin_avatar_bw.png", clamp=True)
5555

5656
# Centered title and message
5757
st.markdown(
@@ -60,7 +60,7 @@ def setup_landing_page() -> None:
6060
)
6161
st.markdown(
6262
(
63-
"<h4 class='h4-landing-page'>Your dedicated"
63+
"<h4 class='h4-landing-page'>Your dedicated "
6464
"Rubin Observatory bot.</h4>"
6565
),
6666
unsafe_allow_html=True,

0 commit comments

Comments
 (0)