-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathassistant.py
More file actions
559 lines (463 loc) · 25.8 KB
/
assistant.py
File metadata and controls
559 lines (463 loc) · 25.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
import re
import logging
import traceback
import json
import time
from typing import Dict, List, Any, Optional, Union
from db_handler import get_db_connection, extract_schema, execute_query
from groq import GroqError, APIError, APIConnectionError
import os
from dotenv import load_dotenv
from pinecone import Pinecone
import hashlib
# Populate os.environ from a .env file (if one exists) before the Pinecone
# settings are read in VoiceAssistant.__init__.
load_dotenv()

# Send INFO-and-above log records both to assistant.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("assistant.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
class VoiceAssistantError(Exception):
    """Common base class for every custom exception defined by the assistant module."""
class DatabaseConnectionError(VoiceAssistantError):
    """Signals a failure to connect to the database (not raised within this module)."""
class SchemaExtractionError(VoiceAssistantError):
    """Signals that the database schema could not be extracted, or came back empty."""
class EmbeddingError(VoiceAssistantError):
    """Signals that embedding or upserting the schema document into Pinecone failed."""
class TranscriptionError(VoiceAssistantError):
    """Signals an audio transcription failure (not raised within this module)."""
class QueryGenerationError(VoiceAssistantError):
    """Signals that a SQL/MongoDB query could not be generated from the user's question."""
class QueryExecutionError(VoiceAssistantError):
    """Signals that running a generated query failed (not raised within this module)."""
class VoiceAssistant:
    """
    Natural-language database assistant backed by a Groq LLM and Pinecone.

    ``get_response`` extracts the database schema, asks the LLM to translate
    the user's question into a SQL or MongoDB query, runs it with
    ``execute_query`` and asks the LLM to render the result as a small
    Tailwind-styled HTML snippet. The Pinecone index stores the schema as a
    single embedding (``_setup_pinecone_vectors``) and can score questions
    against it (``_similarity_search_pinecone``).
    """

    # Pure-greeting detector (used with fullmatch): the WHOLE message must be
    # greeting/pleasantry phrases plus punctuation, so "Hi, how many students
    # are there?" still reaches the query pipeline. Whitespace between phrases
    # is consumed only by the trailing character class, which avoids the
    # catastrophic backtracking a leading \s* inside the repeated group causes.
    _GREETING_RE = re.compile(
        r"\s*(?:(?:hi|hello|hey|greetings|howdy|thank\s+you|thanks"
        r"|good\s+(?:morning|afternoon|evening))"
        r"(?:\s+(?:there|so\s+much|very\s+much|a\s+lot))?[\s!.,?]*)+",
        re.IGNORECASE,
    )
    # Whole-word singular "student" only, so "students" is left untouched.
    _STUDENT_RE = re.compile(r"\bstudent\b", re.IGNORECASE)
    # Reasoning blocks some models emit before their actual answer.
    _THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
    # A reply wrapped entirely in a Markdown code fence, optionally tagged (```json, ```html).
    _CODE_FENCE_RE = re.compile(r"^\s*```(?:[\w-]+[ \t]*\n|\s*)(.*?)\s*```\s*$", re.DOTALL)

    class _RetrievedDocument:
        """Minimal stand-in for a retrieved document; exposes only ``page_content``."""

        def __init__(self, content):
            self.page_content = content

    def __init__(self, groq_model, max_retries=3, retry_delay=1, model_name="llama3-8b-8192"):
        """
        Initialize the VoiceAssistant and connect to the Pinecone index.

        Args:
            groq_model: Groq client (must expose ``chat.completions.create``)
                used for query generation and response formatting
            max_retries: Maximum number of attempts for external calls; values
                below 1 are treated as a single attempt
            retry_delay: Delay between retries in seconds
            model_name: Chat model name passed to the Groq client

        Raises:
            Exception: Whatever the Pinecone client raises while creating the
                client or opening the index named by PINECONE_INDEX_NAME
                (logged, then re-raised unchanged)
        """
        self.llm_model = groq_model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.model_name = model_name
        self.pinecone_api_key = os.environ.get('PINECONE_API_KEY')
        self.pinecone_index_name = os.environ.get('PINECONE_INDEX_NAME')

        try:
            self.pc = Pinecone(api_key=self.pinecone_api_key)
            self.index = self.pc.Index(self.pinecone_index_name)
            logger.info("Pinecone client initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Pinecone: {str(e)}")
            raise
        # Schema vectors are not created here; call _setup_pinecone_vectors(db_config)
        # once a database configuration is available.

    # ------------------------------------------------------------------ helpers

    def _attempts(self) -> int:
        """Return how many times an external call should be tried (always at least one)."""
        return max(1, self.max_retries)

    @staticmethod
    def _join_schema(schema) -> str:
        """
        Flatten a schema (as returned by ``extract_schema``) into one prompt-ready string.

        A list/tuple of schema fragments is joined with newlines, the same
        combination used for the stored Pinecone vector; ``None`` becomes an
        empty string and anything else is converted with ``str``.
        """
        if isinstance(schema, (list, tuple)):
            return "\n".join(str(part) for part in schema)
        return "" if schema is None else str(schema)

    @classmethod
    def _is_greeting(cls, text: str) -> bool:
        """Return True when ``text`` consists only of greetings/pleasantries."""
        return cls._GREETING_RE.fullmatch(text) is not None

    @classmethod
    def _preprocess_question(cls, question: str) -> str:
        """
        Normalize a question before it is sent to the LLM.

        Rewrites the whole word "student" to "students" (presumably the table
        name in the target schema — TODO confirm) and trims surrounding
        whitespace. Everything else keeps its original case, so literal values
        such as names are not altered before they reach a WHERE clause.
        """
        return cls._STUDENT_RE.sub("students", question).strip()

    @classmethod
    def _strip_code_fences(cls, text: str) -> str:
        """Remove a Markdown code fence wrapping the whole of ``text``, then trim it."""
        match = cls._CODE_FENCE_RE.match(text)
        return match.group(1).strip() if match else text.strip()

    @classmethod
    def _parse_json_object(cls, text: str) -> Optional[dict]:
        """
        Parse an LLM reply that should contain a single JSON object.

        ``<think>`` blocks and a surrounding code fence are removed first. If
        the remainder is not valid JSON, the span from the first ``{`` to the
        last ``}`` is tried, which tolerates short prose around the object.

        Returns:
            The parsed dict, or None if no JSON object could be parsed.
        """
        cleaned = cls._strip_code_fences(cls._THINK_RE.sub("", text))
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if start == -1 or end <= start:
                return None
            try:
                parsed = json.loads(cleaned[start:end + 1])
            except json.JSONDecodeError:
                return None
        return parsed if isinstance(parsed, dict) else None

    # ------------------------------------------------------------ Pinecone store

    def _setup_pinecone_vectors(self, db_config):
        """
        Store the combined database schema as a single Pinecone vector.

        The vector ID is the MD5 of the combined schema text, so an unchanged
        schema is detected via ``fetch`` and not re-embedded.

        Args:
            db_config: Database configuration passed through to ``extract_schema``

        Returns:
            The vector ID under which the combined schema is stored

        Raises:
            SchemaExtractionError: If schema extraction fails or yields nothing
            EmbeddingError: If embedding or upserting the schema fails
            VoiceAssistantError: For any other error while talking to Pinecone
        """
        try:
            schema_texts = self._extract_schema_with_retry(db_config)
            if not schema_texts:
                raise SchemaExtractionError("Failed to extract schema or schema is empty")

            combined_schema_text = self._join_schema(schema_texts)
            logger.info(f"Combined schema text length: {len(combined_schema_text)} characters")

            vector_id = self._generate_vector_id(combined_schema_text)
            result = self.index.fetch(ids=[vector_id])
            if result.vectors:
                logger.info("Schema vector already exists in Pinecone")
                return vector_id

            self._upsert_with_inference_api([combined_schema_text], vector_id)
            logger.info("Successfully set up Pinecone vector for combined schema")
            return vector_id
        except (SchemaExtractionError, EmbeddingError):
            raise
        except Exception as e:
            logger.error(f"Pinecone setup error: {str(e)}\n{traceback.format_exc()}")
            raise VoiceAssistantError(f"Error setting up Pinecone vectors: {str(e)}") from e

    def _upsert_with_inference_api(self, schema_texts: List[str], vector_id):
        """
        Embed the combined schema with Pinecone's Inference API and upsert it.

        Args:
            schema_texts: List containing exactly one combined schema text
            vector_id: ID to store the vector under

        Raises:
            EmbeddingError: If ``schema_texts`` does not hold exactly one text,
                or if embedding/upserting fails
        """
        if len(schema_texts) != 1:
            raise EmbeddingError("Expected a single combined schema text")

        combined_schema_text = schema_texts[0]
        logger.debug(f"Combined schema text: {combined_schema_text}")
        try:
            response = self.pc.inference.embed(
                model="llama-text-embed-v2",  # must match the model used in _similarity_search_pinecone
                inputs=[combined_schema_text],
                parameters={"input_type": "passage"},
            )
            embedding = response.data[0]['values']
            logger.info(f"Generated embedding for combined schema: dimension {len(embedding)}")

            vector = {
                'id': vector_id,
                'values': embedding,
                'metadata': {
                    'text': combined_schema_text,
                    'type': 'schema',
                },
            }
            self.index.upsert(vectors=[vector])
            logger.info("Successfully upserted combined schema vector")
        except Exception as e:
            logger.error(f"Upsert failed: {str(e)}")
            raise EmbeddingError(f"Failed to upsert combined schema vector: {str(e)}") from e

    def _generate_vector_id(self, text: str) -> str:
        """Return a deterministic vector ID: the hex MD5 digest of ``text``."""
        return hashlib.md5(text.encode()).hexdigest()

    def _check_existing_vectors(self, schema_texts: List[str]) -> bool:
        """
        Return True if the vector for these schema texts already exists in Pinecone.

        The ID is derived from the newline-joined texts, matching the ID that
        ``_setup_pinecone_vectors`` stores. An empty list or any Pinecone error
        yields False (errors are logged as warnings).
        """
        if not schema_texts:
            return False
        try:
            vector_id = self._generate_vector_id(self._join_schema(schema_texts))
            result = self.index.fetch(ids=[vector_id])
            return bool(result.vectors)
        except Exception as e:
            logger.warning(f"Error checking existing vectors: {str(e)}")
            return False

    def _cosine_similarity(self, vec1, vec2):
        """
        Return the cosine similarity of two equal-length numeric vectors.

        Returns 0.0 when either vector has zero norm.

        Raises:
            ValueError: If the vectors differ in length
        """
        if len(vec1) != len(vec2):
            raise ValueError("Vector lengths must match")
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm_a = sum(a * a for a in vec1) ** 0.5
        norm_b = sum(b * b for b in vec2) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot_product / (norm_a * norm_b)

    def _similarity_search_pinecone(self, query: str, k: int = 1, vector_id=None):
        """
        Score ``query`` against the stored schema vector ``vector_id``.

        The query is embedded with Pinecone's Inference API and compared to the
        single stored vector by locally computed cosine similarity. ``k`` is
        accepted for interface compatibility but not used.

        Returns:
            ``[(document, score)]`` where ``document.page_content`` is the stored
            schema text, or ``[]`` if no ID is given, the vector is missing, or
            any error occurs (errors are logged).
        """
        if vector_id is None:
            logger.warning("No vector ID supplied for similarity search")
            return []
        try:
            response = self.pc.inference.embed(
                model="llama-text-embed-v2",
                inputs=[query],
                parameters={"input_type": "query"},
            )
            query_embedding = response.data[0]['values']

            fetch_result = self.index.fetch(ids=[vector_id])
            if vector_id not in fetch_result.vectors:
                logger.warning(f"No vector found for ID: {vector_id}")
                return []

            stored_vector = fetch_result.vectors[vector_id]
            # The fetched entries are Vector objects: read .values/.metadata as attributes.
            stored_embedding = stored_vector.values
            metadata = stored_vector.metadata or {}
            schema_text = metadata.get('text', '')

            score = self._cosine_similarity(query_embedding, stored_embedding)
            logger.info(f"Similarity score: {score}")
            return [(self._RetrievedDocument(schema_text), score)]
        except Exception as e:
            logger.error(f"Pinecone similarity search error: {str(e)}")
            return []

    # ------------------------------------------------------------ query pipeline

    def _extract_schema_with_retry(self, db_config) -> List[str]:
        """
        Extract the database schema, retrying on failure.

        Args:
            db_config: Database configuration passed to ``extract_schema``

        Returns:
            Whatever ``extract_schema`` returns (a list of schema text strings)

        Raises:
            SchemaExtractionError: If every attempt fails
        """
        attempts = self._attempts()
        for attempt in range(1, attempts + 1):
            try:
                return extract_schema(db_config)
            except Exception as e:
                logger.warning(f"Schema extraction attempt {attempt} failed: {str(e)}")
                if attempt == attempts:
                    logger.error(f"Schema extraction failed after {attempts} attempts")
                    raise SchemaExtractionError(
                        f"Failed to extract schema after {attempts} attempts: {str(e)}"
                    ) from e
                time.sleep(self.retry_delay)

    def get_response(self, question: str, vector_id: str, db_config) -> str:
        """
        Answer a user question against the configured database.

        Args:
            question: User question or prompt
            vector_id: ID of the stored schema vector (not used by the current
                pipeline; kept for callers)
            db_config: Database configuration; ``db_config['db_type']`` selects
                the query dialect

        Returns:
            The LLM-rendered HTML snippet on success, otherwise a plain-text
            message describing what went wrong. Ordinary exceptions are caught,
            logged and turned into such messages.
        """
        try:
            if not question or not isinstance(question, str):
                return "I need a valid question to assist you. Could you please try again?"

            # Only messages that are nothing but a greeting get the canned reply.
            if self._is_greeting(question):
                return "Hello! I'm your database assistant. How can I help with your database queries today?"

            # Relevance to the schema is judged by the LLM during query
            # generation (it answers "generated": false for off-topic questions).
            try:
                schema_texts = self._extract_schema_with_retry(db_config)
            except SchemaExtractionError as e:
                logger.error(f"Schema extraction error: {str(e)}")
                return f"I encountered an error extracting the database schema: {str(e)}. Could you try again?"
            except Exception as e:
                logger.error(f"Unexpected schema extraction error: {str(e)}\n{traceback.format_exc()}")
                return "I had an unexpected issue extracting the database schema. Please try again with a clearer question about your database."

            try:
                query = self._generate_query(question, self._join_schema(schema_texts), db_config)
                logger.info(f"Generated query: {query}")
                if not query:
                    return "I couldn't generate a valid database query from your question. Could you rephrase it to be more specific about the database information you're looking for?"
            except QueryGenerationError as e:
                logger.error(f"Query generation error: {str(e)}")
                return f"I encountered an error generating a database query: {str(e)}. Could you try rephrasing your question?"
            except Exception as e:
                logger.error(f"Unexpected query generation error: {str(e)}\n{traceback.format_exc()}")
                return "I had an unexpected issue generating a database query. Please try again with a clearer question about your database."

            try:
                sql_response = execute_query(db_config, query)
            except Exception as e:
                logger.error(f"Query execution error: {str(e)}\n{traceback.format_exc()}")
                return f"I encountered an error executing the database query. The database returned: {str(e)}"

            try:
                return self._generate_final_response(sql_response, question)
            except Exception as e:
                logger.error(f"Response generation error: {str(e)}\n{traceback.format_exc()}")
                # Fall back to the raw result if formatting fails.
                return f"Here are the raw results from your query (I had trouble formatting them nicely): {sql_response}"
        except Exception as e:
            logger.error(f"Unexpected error in get_response: {str(e)}\n{traceback.format_exc()}")
            return "I encountered an unexpected error while processing your question. Please try again or check your database connection."

    def _generate_query(self, question: str, schema_text: str, db_config) -> str:
        """
        Ask the LLM to translate ``question`` into a query for ``db_config['db_type']``.

        The model is instructed to answer with a JSON object
        ``{"generated": bool, "query": ...}``; ``"generated": false`` means the
        question is not database-related or does not fit the schema.

        Args:
            question: User question
            schema_text: Combined database schema text (a list of fragments is
                also accepted and joined with newlines)
            db_config: Database configuration; ``db_type`` must be one of
                mysql, postgresql, sqlite, sqlserver or mongodb

        Returns:
            The generated query string, or ``""`` when the model declined or its
            output could not be parsed as a JSON object. A MongoDB query that the
            model returns as a JSON object is serialized with ``json.dumps``.

        Raises:
            QueryGenerationError: If the database type is unsupported, every LLM
                call fails, or any other unexpected error occurs
        """
        try:
            question = self._preprocess_question(question)
            logger.info(f"Preprocessed question: {question}")
            schema_text = self._join_schema(schema_text)

            # Both prompts go through str.format, so literal braces are doubled.
            db_type = db_config['db_type']
            if db_type in ('mysql', 'postgresql', 'sqlite', 'sqlserver'):
                system_prompt = """
                You are an expert SQL query generator that generates SQL for database-related questions.
                IMPORTANT:
                - The schema provided contains the ENTIRE database schema (all tables and columns).
                - First, determine if the user question is about querying a database (retrieving, filtering, aggregating data, etc.).
                - Second, check if the question is relevant to the provided schema.
                - If the question mentions entities, tables, or columns NOT present in the schema, consider it irrelevant.
                - Example: If schema is about 'students' and the question is about 'animals', mark it as irrelevant.
                - If the question is NOT database-related OR is irrelevant to the schema:
                    - Respond with:
                    {{
                        "generated": false,
                        "query": null
                    }}
                - If the question IS database-related AND relevant to the schema:
                    - Respond with:
                    {{
                        "generated": true,
                        "query": "<SQL query>"
                    }}
                - Use the schema EXACTLY, referencing only existing tables and columns.
                - For queries about 'all students' or similar, assume the 'students' table is relevant unless otherwise specified.
                - The output must always be in strict JSON format with keys "generated" and "query".
                - Output ONLY the JSON object: no Markdown fences, no explanations.
                Schema:
                {schema}
                User question:
                {question}
                """
            elif db_type == 'mongodb':
                system_prompt = """
                You are an expert MongoDB query generator that generates queries for database-related questions.
                IMPORTANT:
                - The schema provided contains the ENTIRE database schema (all collections and fields).
                - Determine if the question is about querying the database (e.g., retrieving, filtering, or aggregating data).
                - If the question is NOT database-related (e.g., 'make tea', 'what's the weather') or refers to collections or fields NOT present in the schema:
                    - Respond with:
                    {{
                        "generated": false,
                        "query": null
                    }}
                - If the question IS database-related AND relevant to the schema:
                    - Respond with:
                    {{
                        "generated": true,
                        "query": {{"collection": "...", "operation": "find", "query": {{...}}, "projection": {{...}}}}
                    }}
                - Use the schema EXACTLY, referencing only existing collections and fields.
                - The output must always be in strict JSON format with keys "generated" and "query".
                - Output ONLY the JSON object: no Markdown fences, no explanations.
                Schema:
                {schema}
                User question:
                {question}
                """
            else:
                raise QueryGenerationError(f"Unsupported database type: {db_type}")

            formatted_prompt = system_prompt.format(schema=schema_text, question=question)

            # Only the API call is retried; a well-formed reply is final.
            attempts = self._attempts()
            raw_output = ""
            for attempt in range(1, attempts + 1):
                try:
                    response = self.llm_model.chat.completions.create(
                        model=self.model_name,
                        messages=[
                            {"role": "system", "content": formatted_prompt},
                            {"role": "user", "content": "Generate the query based on the above question and schema."},
                        ],
                        temperature=0.1,
                        max_tokens=1000,
                        top_p=1,
                        stream=False,
                    )
                    raw_output = response.choices[0].message.content or ""
                    break
                except Exception as e:
                    if attempt == attempts:
                        raise QueryGenerationError(
                            f"Failed to generate query after {attempts} attempts: {str(e)}"
                        ) from e
                    logger.warning(f"Query generation attempt {attempt} failed: {str(e)}")
                    time.sleep(self.retry_delay)

            parsed = self._parse_json_object(raw_output)
            if parsed is None:
                logger.error(f"Failed to parse model output as JSON: {raw_output}")
                return ""
            if not parsed.get("generated"):
                logger.info("Model determined the query is not database-related")
                return ""

            query = parsed.get("query")
            if isinstance(query, (dict, list)):
                # NOTE(review): confirm execute_query accepts a JSON-string MongoDB query.
                query = json.dumps(query)
            return str(query).strip() if query else ""
        except QueryGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error in query generation: {str(e)}\n{traceback.format_exc()}")
            raise QueryGenerationError(f"Unexpected error generating query: {str(e)}") from e

    def _generate_final_response(self, sql_data: Union[str, Dict, List], question: str) -> str:
        """
        Ask the LLM to render query results as a Tailwind-styled HTML snippet.

        Args:
            sql_data: Query result data (dicts/lists are serialized as JSON)
            question: Original user question

        Returns:
            The rendered HTML snippet (``<think>`` blocks and a wrapping code
            fence removed). If the Groq API keeps failing, or any other error
            occurs, a plain-text message containing the raw data instead.
        """
        from string import Template

        try:
            if isinstance(sql_data, (dict, list)):
                # default=str: json cannot natively serialize values such as
                # datetime or Decimal, which database rows may contain.
                sql_response_str = json.dumps(sql_data, indent=2, default=str)
            else:
                sql_response_str = str(sql_data)

            prompt_template = Template("""
            You are a rendering assistant. Produce ONLY a minimal HTML snippet styled with Tailwind CSS classes.
            Rules (mandatory):
            - Output ONLY raw HTML. No markdown fences, no explanations, no comments.
            - Do NOT include templating logic or placeholders of any kind (e.g., {{ row.name }}, {% ... %}).
            - Do NOT include <script>, <style>, external CSS/JS links, or inline JavaScript.
            - Keep it small and readable: prefer a single <div> container, short headings, and either a simple <table> or a compact <ul>.
            - Use Tailwind utility classes only (e.g., p-4, text-gray-700, table-auto, border, rounded, shadow).
            - If data appears tabular, render a table with a header row derived from the keys in the first object (if present). If keys are unknown/inconsistent, render a two-column “Key / Value” table for each row.
            - If there are no results, render a compact empty-state card that says “No results found.”
            - Never fabricate columns or values not present in the provided data. If you cannot determine column names, use generic headers like “Column 1”, “Column 2”, etc.
            - Absolutely NO extra prose before or after the HTML.
            Context for rendering (use this to decide headings and rows):
            SQL Query Results:
            $sql_response
            User Question:
            $question
            """)
            formatted_prompt = prompt_template.substitute(
                sql_response=sql_response_str,
                question=question,
            )
            logger.debug(f"Formatted rendering prompt: {formatted_prompt}")

            # Only Groq API errors are retried; anything else falls through to
            # the outer handler and the raw-data fallback.
            attempts = self._attempts()
            for attempt in range(1, attempts + 1):
                try:
                    response = self.llm_model.chat.completions.create(
                        model=self.model_name,
                        messages=[
                            {"role": "system", "content": formatted_prompt},
                            {"role": "user", "content": "Render the results as HTML with Tailwind CSS only, no explanations or markdown."},
                        ],
                        temperature=0.1,
                        max_tokens=1000,
                        top_p=1,
                        stream=False,
                    )
                except (GroqError, APIError, APIConnectionError) as e:
                    if attempt < attempts:
                        logger.warning(f"Response generation attempt {attempt} failed: {str(e)}")
                        time.sleep(self.retry_delay)
                        continue
                    logger.error(f"Final response generation failed after {attempts} attempts: {str(e)}")
                    return f"Here are the results from your database query (I had trouble creating a detailed explanation):\n\n{sql_response_str}"

                if not getattr(response, 'choices', None):
                    raise ValueError("Invalid response format from LLM")
                response_content = response.choices[0].message.content or ""

                cleaned_response = self._strip_code_fences(self._THINK_RE.sub("", response_content))
                if not cleaned_response:
                    raise ValueError("Generated response is empty after cleaning")

                logger.info("Successfully generated final response")
                logger.debug(f"Rendered response: {cleaned_response}")
                return cleaned_response
        except Exception as e:
            logger.error(f"Error generating final response: {str(e)}\n{traceback.format_exc()}")
            return f"Here are the raw results from your query (I encountered an error while formatting the response):\n\n{sql_data}"