Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 142 additions & 62 deletions rag_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Uses a model from HuggingFace with optional 4-bit quantization

import os
import re
import argparse
import torch
from transformers import pipeline
Expand All @@ -14,34 +15,46 @@


# Main RAG prompt. Filled with the retrieved "context" and the user's
# "question"; instructs the model to stay grounded in the context.
PROMPT_TEMPLATE = """
You are a smart and helpful assistant designed to answer coding questions using the Sugar Learning Platform.
Instructions:
1. You must ONLY use the information from the provided context to answer the question.
2. You must NOT use outside knowledge if the context provides an answer. If the context is empty or unrelated, use your general knowledge.
3. Do NOT mention the context, documents, or how the answer was generated. Just provide the answer naturally and clearly.
4. When possible, prioritize and include any relevant details from the context.
5. Always answer in a concise, accurate, and helpful manner.

Context: {context}
Question: {question}

Answer:
"""


# Second-pass prompt: rewrites the first model's answer ("original_answer")
# into language suitable for a young child.
CHILD_FRIENDLY_PROMPT = """
You are a friendly teacher talking to a child aged 3 to 10 years old.

Rewrite the answer below using simple words and short sentences so a young child can understand it.
Include examples if needed. Stay close to the original meaning. Do not add extra commentary or explanation about what you are doing.

Here is the answer to simplify:
{original_answer}

Child-friendly answer:
"""

def format_docs(docs):
    """Concatenate the page_content of each document, separated by a blank line."""
    chunks = []
    for doc in docs:
        chunks.append(doc.page_content)
    return "\n\n".join(chunks)


def trim_incomplete_sentence(text):
    """Drop a trailing incomplete sentence from generated text.

    Cuts *text* after the last sentence terminator ('.', '!' or '?') that is
    followed by whitespace or the end of the string. If no terminator is
    found, the text is returned unchanged (only stripped), so short answers
    without punctuation are not discarded.

    The original pattern ``r'\\.\\s'`` had two defects: it required trailing
    whitespace after the final '.', so a text ending exactly with '.' lost
    its last complete sentence, and it ignored '!' / '?' terminators.
    """
    # Lookahead keeps the terminator itself while accepting end-of-string.
    matches = list(re.finditer(r'[.!?](?=\s|$)', text))
    if matches:
        return text[:matches[-1].end()].strip()
    # No complete sentence detected: return the text as-is.
    return text.strip()

def combine_messages(x):
"""
If 'x' has a method to_messages, combine message content with newline.
Expand Down Expand Up @@ -92,23 +105,41 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
"text-generation",
model=model_obj,
tokenizer=tokenizer,
max_length=1024,
truncation=True,
max_new_tokens=512,
return_full_text=False,
do_sample=False,
temperature=None,
top_p=None,
top_k=None,
repetition_penalty=1.2,
truncation=True
)

tokenizer2 = AutoTokenizer.from_pretrained(model)
self.simplify_model = pipeline(
"text-generation",
model=model_obj,
tokenizer=tokenizer2,
max_length=1024,
truncation=True,
max_new_tokens=512,
return_full_text=False,
do_sample=False,
temperature=None,
top_p=None,
top_k=None,
repetition_penalty=1.2,
truncation=True
)
else:
self.model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=512,
return_full_text=False,
do_sample=False,
temperature=None,
top_p=None,
top_k=None,
repetition_penalty=1.2,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
Expand All @@ -117,7 +148,13 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
self.simplify_model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=512,
return_full_text=False,
do_sample=False,
temperature=None,
top_p=None,
top_k=None,
repetition_penalty=1.2,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
def set_model(self, model):
    """Swap both text-generation pipelines to *model*.

    Rebuilds the main answering pipeline and the child-friendly
    simplification pipeline with identical, deterministic decoding
    settings (greedy decoding with a repetition penalty).
    """
    self.model = pipeline(
        "text-generation",
        model=model,
        max_new_tokens=512,
        return_full_text=False,
        # Greedy decoding: sampling disabled, so sampling knobs are cleared.
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
        repetition_penalty=1.2,
        truncation=True,
        torch_dtype=torch.float16
    )

    self.simplify_model = pipeline(
        "text-generation",
        model=model,
        max_new_tokens=512,
        return_full_text=False,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
        repetition_penalty=1.2,
        truncation=True,
        torch_dtype=torch.float16
    )
Expand All @@ -167,52 +216,80 @@ def setup_vectorstore(self, file_paths):
return retriever

def get_relevant_document(self, query, threshold=0.5):
    """Return ``(document, score)`` for the top retrieval hit, else ``(None, 0.0)``.

    Args:
        query: The search query; either a plain string or a dict produced by
            a chain step (the "query" key is preferred, otherwise the first
            value is used).
        threshold: Unused; kept for backward compatibility with callers.

    Any retrieval failure is logged and reported as "no document" rather
    than propagated, so callers can fall back to an empty context.
    """
    try:
        # Chain steps may pass a dict instead of the raw question text.
        if isinstance(query, dict):
            if "query" in query:
                query = query["query"]
            else:
                query = next(iter(query.values()))

        query = str(query)

        results = self.retriever.invoke(query)

        print(f"[DEBUG] Retrieved results: {results}")
        if results:
            # Retriever returns hits ranked best-first; report a nominal 1.0.
            return results[0], 1.0
        return None, 0.0

    except Exception as e:
        # Best-effort retrieval: log and signal "no context found".
        print(f"Error in get_relevant_document: {e}")
        return None, 0.0

def run(self, question):
    """Answer *question* with retrieval-augmented generation, then simplify it.

    Pipeline: retrieve context for the question, render the main prompt and
    generate an answer, then feed that answer through the child-friendly
    prompt and the simplification model. Each rendered prompt is printed
    for debugging. Returns the simplified answer trimmed to its last
    complete sentence, or an error string if any step fails.
    """
    try:
        doc_result, _ = self.get_relevant_document(question)
        context = (doc_result.page_content if doc_result
                   else "No relevant documents were found. So, context is empty")

        def print_prompt_before_model(x):
            # Render the prompt messages to plain text and echo for debugging.
            prompt_text = combine_messages(x)
            print("\nPrompt sent to main model:\n" + "-" * 40)
            print(prompt_text)
            print("-" * 40 + "\n")
            return prompt_text

        first_chain = (
            self.prompt
            | print_prompt_before_model
            | self.model
            | extract_answer_from_output
        )

        # PROMPT_TEMPLATE expects the keys "question" and "context".
        first_response = first_chain.invoke({
            "question": question,
            "context": context
        })

        def print_child_prompt_before_model(x):
            child_prompt_text = combine_messages(x)
            print("\nPrompt sent to child model:\n" + "-" * 40)
            print(child_prompt_text)
            print("-" * 40 + "\n")
            return child_prompt_text

        second_chain = (
            self.child_prompt
            | print_child_prompt_before_model
            | self.simplify_model
            | extract_answer_from_output
        )

        final_response = second_chain.invoke({
            "original_answer": first_response
        })

        # Generation may stop mid-sentence at the token limit; tidy the tail.
        return trim_incomplete_sentence(final_response)

    except Exception as e:
        print(f"Error in run method: {e}")

        return f"Encountered an error: {str(e)}"


def main():
Expand All @@ -223,7 +300,8 @@ def main():
choices=[
'bigscience/bloom-1b1',
'facebook/opt-350m',
'EleutherAI/gpt-neo-1.3B'
'EleutherAI/gpt-neo-1.3B',
'Qwen/Qwen2-1.5B-Instruct'
],
default='bigscience/bloom-1b1',
help='Model name to use for text generation'
Expand Down Expand Up @@ -257,7 +335,9 @@ def main():
print("Response:", response)
except Exception as e:
print(f"An error occurred: {e}")
import traceback
traceback.print_exc()


# Script entry point: run the CLI only when executed directly.
if __name__ == "__main__":
    main()