-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
212 lines (178 loc) · 8.2 KB
/
Copy pathapp.py
File metadata and controls
212 lines (178 loc) · 8.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import pickle
import string

import boto3
import numpy as np
import openai
import streamlit as st
# Get credentials from environment variables
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
openai_api_key = os.environ.get('OPENAI_API_KEY')

# Setup OpenAI API Key: prefer the Streamlit secret, fall back to the env var.
try:
    openai.api_key = st.secrets.get("OPENAI_API_KEY", openai_api_key)
except Exception:
    # Deliberately broad: presumably guards against st.secrets raising when no
    # secrets file is configured (the exact exception type varies by Streamlit
    # version). The previous tuple already ended in Exception, so this catches
    # exactly the same set.
    openai.api_key = openai_api_key

# Constants
AWS_BUCKET_NAME = 'hcmbotknowledgesource'
DOCUMENT_STORE_FILE = 'document_store.pkl'
# Load the document store from S3 or local file
@st.cache_resource
def load_document_store():
    """Load the pickled document store, refreshing the local copy from S3 first.

    The S3 step is best-effort. It downloads ``DOCUMENT_STORE_FILE`` from
    ``AWS_BUCKET_NAME`` into the working directory. Any failure there, including
    failure to create the boto3 client, is printed and the existing local file
    is used instead.

    Returns:
        The unpickled document store, or ``{}`` if the local file is missing
        or cannot be loaded (the error is shown in the UI via ``st.error``).

    NOTE(review): the result is cached by ``st.cache_resource``, so a ``{}``
    returned after a failed load is presumably cached too and not retried
    until the cache is cleared or the app restarts.
    """
    # Client creation lives inside this try too, so a boto3/credential problem
    # falls back to the local file instead of aborting the whole load.
    try:
        s3 = boto3.client(
            's3',
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        s3.download_file(AWS_BUCKET_NAME, DOCUMENT_STORE_FILE, DOCUMENT_STORE_FILE)
        print("Downloaded document store from S3")
    except Exception as e:
        print(f"Could not download from S3: {e}")

    # Load from local file
    try:
        # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded
        # in the file. This is only safe while both the S3 bucket and the local
        # file are fully trusted; consider a data-only format instead.
        with open(DOCUMENT_STORE_FILE, 'rb') as f:
            document_store = pickle.load(f)
        print(f"Document store loaded with {len(document_store)} documents")
        return document_store
    except FileNotFoundError:
        st.error("Error: The document store file was not found.")
        return {}
    except Exception as e:
        st.error(f"Error loading document store: {e}")
        return {}


# Load document store once at script start (cached across reruns).
document_store = load_document_store()
# Cosine similarity between two embedding vectors, guarded against near-zero norms
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    Intended for 1-D embedding vectors (lists or NumPy arrays). If either
    vector's Euclidean norm is below ``1e-10``, the similarity is reported as
    ``0.0`` rather than dividing by (nearly) zero.
    """
    tiny = 1e-10
    len_a = np.linalg.norm(vec1)
    len_b = np.linalg.norm(vec2)
    degenerate = len_a < tiny or len_b < tiny
    return 0.0 if degenerate else np.dot(vec1, vec2) / (len_a * len_b)
def generate_embeddings(texts, batch_size=10, *, raise_on_error=False):
    """Embed ``texts`` with OpenAI's ``text-embedding-3-small`` model, in batches.

    Args:
        texts: Sliceable sequence of strings to embed.
        batch_size: Number of texts sent per API request; must be >= 1.
        raise_on_error: If True, re-raise the first API error instead of
            skipping the failed batch.

    Returns:
        A list of embedding vectors, in the order the API returns them for
        each batch (presumably input order). When a batch fails and
        ``raise_on_error`` is False, that batch is logged and skipped. The
        result can then be SHORTER than ``texts`` and no longer index-aligned
        with it, so callers must check the length.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # range() with step 0 raises an opaque error, and a negative step silently
    # yields nothing; reject both explicitly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    embeddings = []
    total = len(texts)
    for start in range(0, total, batch_size):
        # Clamp so the logged range is accurate for the final, shorter batch.
        stop = min(start + batch_size, total)
        batch = texts[start:stop]
        try:
            response = openai.Embedding.create(
                model="text-embedding-3-small",
                input=batch,
            )
            embeddings.extend(item["embedding"] for item in response["data"])
        except Exception as e:
            if raise_on_error:
                raise
            # Best-effort mode: log and drop this batch so later batches still run.
            print(f"Error generating embeddings for batch {start}-{stop}: {e}")
    return embeddings
# Function to check if content is relevant
def is_relevant_content(chunk, query):
"""Check if chunk contains actual relevant information rather than metadata."""
# Skip chunks that are mostly version numbers or deployment info
if chunk.count(':v') > 3 or chunk.count('-') > 10:
return False
# Extract key terms from query (excluding common words)
query_terms = set(term.lower() for term in query.split()
if term.lower() not in {'how', 'what', 'when', 'where', 'do', 'does', 'is', 'are', 'the'})
# Check if chunk contains any query terms
chunk_lower = chunk.lower()
terms_found = sum(1 for term in query_terms if term in chunk_lower)
return terms_found > 0
# Updated retrieve_relevant_chunks function with hybrid approach
def retrieve_relevant_chunks(query, top_k=5):
query_embedding = generate_embeddings([query])[0]
similarities = []
# Extract key terms from query for keyword matching
query_terms = set(term.lower() for term in query.split()
if term.lower() not in {'how', 'what', 'when', 'where', 'do', 'does', 'is', 'are', 'the'})
for doc_name, doc_data in document_store.items():
for chunk, chunk_embedding in zip(doc_data["chunks"], doc_data["embeddings"]):
# Calculate vector similarity
similarity = cosine_similarity(query_embedding, chunk_embedding)
# Calculate keyword match score
chunk_lower = chunk.lower()
keyword_matches = sum(1 for term in query_terms if term in chunk_lower)
# Boost score for slack PDFs if query seems to target them
source_boost = 0.1 if doc_data.get("source") == "slack_pdf" and "slack" in query.lower() else 0
# Combine scores (vector similarity is primary, keywords provide a boost)
combined_score = similarity + (keyword_matches * 0.05) + source_boost
similarities.append((chunk, combined_score, doc_name))
# Get top results by combined score
relevant_chunks = sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
return [(chunk, doc_name) for chunk, _, doc_name in relevant_chunks]
# System prompt that constrains the model to answer only from retrieved context.
_SYSTEM_PROMPT = "\n".join([
    "You are a precise assistant that answers questions based strictly on the provided context.",
    "Rules:",
    "1. Use ONLY information from the context",
    "2. Keep exact terminology and steps from the source",
    "3. If multiple sources have different information, specify which source you're using",
    '4. If information isn\'t in the context, say "I don\'t have enough information"',
    "5. For procedures, list exact steps in order",
    "6. Include specific buttons, links, and UI elements mentioned in the source",
])


def chat_with_assistant(query):
    """Answer ``query`` with gpt-3.5-turbo, grounded in retrieved document chunks.

    The chunks come from ``retrieve_relevant_chunks``. Each is labelled with
    its source document name and placed in the user message as context. The
    model's reply is returned with surrounding whitespace stripped. Spinners
    are shown in the Streamlit UI while retrieval and generation run.
    """
    with st.spinner("Retrieving relevant information..."):
        sources = retrieve_relevant_chunks(query)
        labelled = [f"Source ({doc}): {chunk}" for chunk, doc in sources]
        context = "\n\n".join(labelled)

    system_message = {"role": "system", "content": _SYSTEM_PROMPT}
    user_message = {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}

    with st.spinner("Generating response..."):
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[system_message, user_message],
            temperature=0.3,
            max_tokens=1000,
        )
    return completion.choices[0].message.content.strip()
# Streamlit interface
# Per-session click counters for the trending questions; created only on the
# first run of a session so later reruns keep the accumulated counts.
if "question_clicks" not in st.session_state:
    st.session_state["question_clicks"] = dict.fromkeys(
        (
            "What is Health Campaign Management?",
            "What are the steps involved in creating a KPI?",
        ),
        0,
    )
def handle_trending_click(question):
    """Record a click on a trending question and load it into the query box.

    Increments ``st.session_state.question_clicks[question]`` and stores the
    question under ``st.session_state.query``, the key of the question text
    input, so the input shows it.

    Returns:
        The same ``question`` string.

    Raises:
        KeyError: If ``question`` has no entry in ``question_clicks``.
    """
    state = st.session_state
    state.question_clicks[question] += 1
    state.query = question
    return question
# Main interface
st.image("egovlogo.png", width=200)
st.title("Health Campaign Management (HCM) Support Bot [Beta version]")

# Notes Section
st.subheader("Note:")
st.markdown(
    '<p style="color:red; font-size:16px;">Please try to be as in detail as possible with your prompt and use full forms for beta version, e.g., Health Campaign Management instead of HCM.</p>',
    unsafe_allow_html=True,
)

# Trending Questions Section
st.subheader("Trending Questions")
trending_questions = list(st.session_state.question_clicks)
col1, col2 = st.columns(2)
# First question in the left column, all remaining questions in the right one.
for column, questions in ((col1, trending_questions[:1]), (col2, trending_questions[1:])):
    with column:
        for question in questions:
            # A click stores the question in session_state["query"] before the
            # text input below is created on this rerun, pre-filling the box.
            # (The return value is not needed: the text input reads it back.)
            if st.button(f"📈 {question}", key=f"btn_{question}"):
                handle_trending_click(question)

# User input section
query = st.text_input("Ask a question:", key="query")
submit_button = st.button("Submit")
if submit_button:
    if query.strip():
        try:
            st.write("Query Received:", query)
            answer = chat_with_assistant(query)
            st.write(f"Assistant's answer: {answer}")
        except Exception as e:
            # Top-level UI boundary: show any retrieval/LLM failure to the user.
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a question before clicking Submit.")