|
28 | 28 |
|
29 | 29 | import streamlit as st |
30 | 30 | import weaviate |
31 | | -from custom_weaviate_vector_store import CustomWeaviateVectorStore |
32 | 31 | from langchain.chains import create_retrieval_chain |
33 | 32 | from langchain.chains.combine_documents import create_stuff_documents_chain |
34 | 33 | from langchain.prompts.chat import ( |
|
40 | 39 | StreamlitChatMessageHistory, |
41 | 40 | ) |
42 | 41 | from langchain_core.prompts import MessagesPlaceholder |
| 42 | +from langchain_core.vectorstores.base import VectorStoreRetriever |
43 | 43 | from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
44 | | -from streamlit_callback import get_streamlit_cb |
45 | 44 | from weaviate.classes.init import Auth |
46 | 45 |
|
| 46 | +from .custom_weaviate_vector_store import CustomWeaviateVectorStore |
| 47 | +from .streamlit_callback import get_streamlit_cb |
| 48 | + |
47 | 49 |
|
48 | 50 | def submit_text() -> None: |
49 | 51 | """Submit the user input.""" |
50 | 52 | st.session_state.message_sent = True |
51 | 53 |
|
52 | 54 |
|
53 | 55 | @st.cache_resource(ttl="1h") |
54 | | -def configure_retriever() -> CustomWeaviateVectorStore: |
| 56 | +def configure_retriever() -> VectorStoreRetriever: |
55 | 57 | """Configure the Weaviate retriever.""" |
56 | 58 | openai_api_key = os.getenv("OPENAI_API_KEY") |
57 | 59 | weaviate_api_key = os.getenv("WEAVIATE_API_KEY") |
@@ -93,8 +95,43 @@ def configure_retriever() -> CustomWeaviateVectorStore: |
93 | 95 | ) |
94 | 96 |
|
95 | 97 |
|
| 98 | +@st.cache_resource(ttl="1h") |
| 99 | +def configure_retriever_cloud() -> VectorStoreRetriever: |
| 100 | + """Configure the Weaviate retriever.""" |
| 101 | + openai_api_key = os.getenv("OPENAI_API_KEY_PAID") |
| 102 | + weaviate_api_key = os.getenv("WEAVIATE_API_KEY_CLOUD") |
| 103 | + weaviate_url = os.getenv("WEAVIATE_URL") |
| 104 | + |
| 105 | + if openai_api_key is None: |
| 106 | + raise ValueError("OPENAI_API_KEY environment variable is not set") |
| 107 | + if weaviate_api_key is None: |
| 108 | + raise ValueError("WEAVIATE_API_KEY environment variable is not set") |
| 109 | + if weaviate_url is None: |
| 110 | + raise ValueError("WEAVIATE_URL environment variable is not set") |
| 111 | + |
| 112 | + client = weaviate.connect_to_weaviate_cloud( |
| 113 | + cluster_url=weaviate_url, |
| 114 | + auth_credentials=Auth.api_key( |
| 115 | + weaviate_api_key |
| 116 | + ), # The API key to use for authentication |
| 117 | + headers={"X-OpenAI-Api-Key": openai_api_key}, |
| 118 | + skip_init_checks=True, |
| 119 | + ) |
| 120 | + |
| 121 | + return CustomWeaviateVectorStore( |
| 122 | + client=client, |
| 123 | + index_name="LangChain_cloudtest1", |
| 124 | + text_key="text", |
| 125 | + embedding=OpenAIEmbeddings(), |
| 126 | + attributes=["source", "source_key"], |
| 127 | + ).as_retriever( |
| 128 | + search_type="similarity", |
| 129 | + search_kwargs={"k": 6, "return_metadata": ["score"]}, |
| 130 | + ) |
| 131 | + |
| 132 | + |
96 | 133 | def create_qa_chain( |
97 | | - retriever: CustomWeaviateVectorStore, |
| 134 | + retriever: VectorStoreRetriever, |
98 | 135 | ) -> ChatPromptTemplate: |
99 | 136 | """Create a QA chain for the chatbot.""" |
100 | 137 | # Setup ChatOpenAI (Language Model) |
@@ -138,8 +175,8 @@ def handle_user_input( |
138 | 175 | # Define avatars for user and assistant messages |
139 | 176 | avatars = {"human": "user", "ai": "assistant"} |
140 | 177 | avatar_images = { |
141 | | - "human": "./static/user_avatar.png", |
142 | | - "ai": "./static/rubin_avatar_bw.png", |
| 178 | + "human": "../../../static/user_avatar.png", |
| 179 | + "ai": "../../../static/rubin_avatar_bw.png", |
143 | 180 | } |
144 | 181 |
|
145 | 182 | for msg in msgs.messages: |
|
0 commit comments