Skip to content
Open
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 76 additions & 45 deletions rag_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Uses a model from HuggingFace with optional 4-bit quantization

import os
import re
import argparse
import torch
from transformers import pipeline
Expand All @@ -14,34 +15,50 @@


PROMPT_TEMPLATE = """
You are a highly intelligent Python coding assistant built for kids using the Sugar Learning Platform.
1. Focus on coding-related problems, errors, and explanations.
2. Use the knowledge from the provided Pygame, GTK, and Sugar Toolkit documentation.
3. Provide complete, clear and concise answers.
4. Your answer must be easy to understand for kids.
5. Always include Sugar-specific guidance when relevant to the question.
You are a smart and helpful assistant designed to answer coding questions using context given.

1. The context above contains the relevant information needed to answer this specific question.
2. Use the information from this context to formulate your answer.
3. Prioritize and include any relevant details from the context.
4. Always answer in clear, complete and helpful way.

Context: {context}
Question: {question}

Answer:
"""

CHILD_FRIENDLY_PROMPT = """
Your task is to answer children's questions using simple language.
You will be given an answer, you will have to paraphrase it.
Explain any difficult words in a way a 5-12-years-old can understand.
You are a helpful assistant who rewrites answers so that children aged 3 to 10 can understand them.

Original answer: {original_answer}
Your task:
- ONLY rewrite the Original answer using simple words and short sentences.
- Do NOT explain what the answer means.
- Do NOT explain what you're doing.
- Do NOT say anything extra.
- Do NOT repeat ideas.
- Do NOT add tips or encouragement.

Child-friendly answer:
"""
Just give the rewritten answer. Nothing else.

Original answer:
{original_answer}

Child-friendly answer:
"""

def format_docs(docs):
    """Concatenate the page_content of every document, separated by a blank line."""
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)


def trim_incomplete_sentence(text):
    """Drop a trailing incomplete sentence from generated text.

    Model output truncated by the token limit often ends mid-sentence;
    this keeps everything up to the last sentence terminator followed by
    whitespace. If the text already ends with a terminator, it is complete
    and returned as-is (the original implementation wrongly trimmed the
    final complete sentence in that case, because ``r'\.\s'`` never matches
    a period at end-of-string). Recognizes '.', '!' and '?' rather than
    only '.'.

    Args:
        text: Raw generated text (may be empty).

    Returns:
        The text with any trailing sentence fragment removed, stripped of
        surrounding whitespace. If no terminator is found at all, the
        stripped text is returned unchanged.
    """
    text = text.strip()
    # Already ends on a full sentence (or is empty) -> nothing to trim.
    if not text or text[-1] in ".!?":
        return text
    # Lookahead keeps the terminator but not the following whitespace.
    matches = list(re.finditer(r"[.!?](?=\s)", text))
    if matches:
        return text[: matches[-1].end()].strip()
    return text

def combine_messages(x):
"""
If 'x' has a method to_messages, combine message content with newline.
Expand Down Expand Up @@ -92,23 +109,26 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
"text-generation",
model=model_obj,
tokenizer=tokenizer,
max_length=1024,
truncation=True,
max_new_tokens=1024,
temperature=0.3,
truncation=True
)

tokenizer2 = AutoTokenizer.from_pretrained(model)
self.simplify_model = pipeline(
"text-generation",
model=model_obj,
tokenizer=tokenizer2,
max_length=1024,
truncation=True,
max_new_tokens=1024,
temperature=0.3,
truncation=True
)
else:
self.model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
Expand All @@ -117,7 +137,8 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
self.simplify_model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
Expand All @@ -133,15 +154,17 @@ def set_model(self, model):
self.model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16
)

self.simplify_model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16
)
Expand All @@ -166,53 +189,58 @@ def setup_vectorstore(self, file_paths):
retriever = vector_store.as_retriever()
return retriever

def get_relevant_document(self, query, threshold=0.5):
def get_relevant_document(self, query):
results = self.retriever.invoke(query)

print(f"[DEBUG] Retrieved results: {results}")

if results:
top_result = results[0]
score = top_result.metadata.get("score", 0.0)
if score >= threshold:
return top_result, score
return results[:2], 1.0
return None, 0.0

def run(self, question):
"""
Build the QA chain and process the output from model generation.
Apply double prompting to make answers child-friendly.
"""
# Build the chain components:
chain_input = {
"context": self.retriever | format_docs,
"question": RunnablePassthrough()
}
# The chain applies: prompt -> combine messages -> model ->
# extract answer from output.
doc_result, _ = self.get_relevant_document(question)

print(f"TOP RELEVANT DOCUMENT CONTENT: {doc_result}")

context_text = format_docs(doc_result) if doc_result else ""

if not context_text.strip():
return "I couldn't find an answer in the documents."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all answers are supposed to be in the documents, the documents are there for Sugar specific context not as the sole source of answers.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree — the model should still answer even if there's no doc context. That was a temporary check I forgot to remove; will fix it.


first_chain = (
chain_input
{"context": lambda x: context_text, "question": lambda x: x}
| self.prompt
| combine_messages
| self.model # Use the first model
| self.model
| extract_answer_from_output
)
doc_result, _ = self.get_relevant_document(question)
if doc_result:
first_response = first_chain.invoke({
"query": question,
"context": doc_result.page_content
})
else:
first_response = first_chain.invoke(question)
first_response = first_chain.invoke(question)

print(f"FIRST RESPONSE: {first_response}")

print(self.prompt.format(context=context_text, question=question))
# The chain applies: prompt -> combine messages -> model ->
# extract answer from output.

second_chain = (
{"original_answer": lambda x: x}
| self.child_prompt
| combine_messages
| self.simplify_model
| extract_answer_from_output
)

print(f"CHILD PROMPT: {self.child_prompt.format(original_answer=first_response)}")

final_response = second_chain.invoke(first_response)
return final_response

print(f"FINAL: {final_response}")

return trim_incomplete_sentence(final_response)


def main():
Expand Down Expand Up @@ -257,7 +285,10 @@ def main():
print("Response:", response)
except Exception as e:
print(f"An error occurred: {e}")
import traceback
traceback.print_exc()


if __name__ == "__main__":
main()