Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 34 additions & 66 deletions rag_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,82 +184,50 @@ def setup_vectorstore(self, file_paths):
retriever = vector_store.as_retriever()
return retriever

def get_relevant_document(self, query, threshold=0.5):
try:
if isinstance(query, dict):
if "query" in query:
query = query["query"]
else:
query = next(iter(query.values()))

query = str(query)
def get_relevant_document(self, query):
results = self.retriever.invoke(query)

results = self.retriever.invoke(query)
print(f"[DEBUG] Retrieved results: {results}")

print(f"[DEBUG] Retrieved results: {results}")

if results and len(results) > 0:
return results[0], 1.0
return None, 0.0

except Exception as e:
print(f"Error in get_relevant_document: {e}")
return None, 0.0
if results:
return results[:2], 1.0
return None, 0.0

def run(self, question):
"""
Build the QA chain and process the output from model generation.
Apply double prompting to make answers child-friendly.
Print the actual prompts sent to the models.
"""
try:
doc_result, _ = self.get_relevant_document(question)

context_text = format_docs(doc_result) if doc_result else ""

doc_result, _ = self.get_relevant_document(question)
context = doc_result.page_content if doc_result else "No relevant documents were found. So, context is empty"

def print_prompt_before_model(x):
prompt_text = combine_messages(x)
print("\nPrompt sent to main model:\n" + "-" * 40)
print(prompt_text)
print("-" * 40 + "\n")
return prompt_text

first_chain = (
self.prompt
| print_prompt_before_model
| self.model
| extract_answer_from_output
)

first_response = first_chain.invoke({
"question": question,
"context": context
})

def print_child_prompt_before_model(x):
child_prompt_text = combine_messages(x)
print("\nPrompt sent to child model:\n" + "-" * 40)
print(child_prompt_text)
print("-" * 40 + "\n")
return child_prompt_text

second_chain = (
self.child_prompt
| print_child_prompt_before_model
| self.simplify_model
| extract_answer_from_output
)

final_response = second_chain.invoke({
"original_answer": first_response
})

return trim_incomplete_sentence(final_response)

except Exception as e:
print(f"Error in run method: {e}")
if not context_text.strip():
return "I couldn't find an answer in the documents."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all answers are supposed to be in the documents, the documents are there for Sugar specific context not as the sole source of answers.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree the model should still answer even if there's no doc context. That was a temp check I forgot to remove. Will fix it


first_chain = (
{"context": lambda x: context_text, "question": lambda x: x}
| self.prompt
| combine_messages
| self.model
| extract_answer_from_output
)
first_response = first_chain.invoke(question)

print(self.prompt.format(context=context_text, question=question))
# The chain applies: prompt -> combine messages -> model ->
# extract answer from output.

second_chain = (
{"original_answer": lambda x: x}
| self.child_prompt
| combine_messages
| self.simplify_model
| extract_answer_from_output
)

return f"Encountered an error: {str(e)}"
final_response = second_chain.invoke(first_response)
return trim_incomplete_sentence(final_response)


def main():
Expand Down