-
Notifications
You must be signed in to change notification settings - Fork 43
feat: improve rag responses with better prompt and gen params. #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
0652404
684a604
e7bc03a
9053d59
c0bb209
7fd33c6
415cac4
396edd4
ed2c83a
c394eb8
03132cb
2f8c373
8061058
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| # Uses a model from HuggingFace with optional 4-bit quantization | ||
|
|
||
| import os | ||
| import re | ||
| import argparse | ||
| import torch | ||
| from transformers import pipeline | ||
|
|
@@ -14,34 +15,46 @@ | |
|
|
||
|
|
||
| PROMPT_TEMPLATE = """ | ||
| You are a highly intelligent Python coding assistant built for kids using the Sugar Learning Platform. | ||
| 1. Focus on coding-related problems, errors, and explanations. | ||
| 2. Use the knowledge from the provided Pygame, GTK, and Sugar Toolkit documentation. | ||
| 3. Provide complete, clear and concise answers. | ||
| 4. Your answer must be easy to understand for kids. | ||
| 5. Always include Sugar-specific guidance when relevant to the question. | ||
|
|
||
| You are a smart and helpful assistant designed to answer coding questions using the Sugar Learning Platform. | ||
| Instructions: | ||
| 1. You must ONLY use the information from the provided context to answer the question. | ||
| 2. You must NOT use outside knowledge if the context provides an answer. If the context is empty or unrelated, use your general knowledge. | ||
| 3. Do NOT mention the context, documents, or how the answer was generated. Just provide the answer naturally and clearly. | ||
| 4. When possible, prioritize and include any relevant details from the context. | ||
| 5. Always answer in a concise, accurate, and helpful manner. | ||
|
|
||
| Context: {context} | ||
| Question: {question} | ||
|
|
||
| Answer: | ||
| """ | ||
|
|
||
|
|
||
| CHILD_FRIENDLY_PROMPT = """ | ||
| Your task is to answer children's questions using simple language. | ||
| You will be given an answer, you will have to paraphrase it. | ||
| Explain any difficult words in a way a 5-12-years-old can understand. | ||
| You are a friendly teacher talking to a child aged 3 to 10 years old. | ||
|
|
||
| Original answer: {original_answer} | ||
| Rewrite the answer below using simple words and short sentences so a young child can understand it. | ||
|
|
||
| Child-friendly answer: | ||
| """ | ||
| Include examples if needed. Stay close to the original meaning. Do not add extra commentary or explanation about what you are doing. | ||
|
|
||
| Here is the answer to simplify: | ||
| {original_answer} | ||
|
|
||
| Child-friendly answer: | ||
| """ | ||
|
|
||
| def format_docs(docs): | ||
| """Return all document content separated by two newlines.""" | ||
| return "\n\n".join(doc.page_content for doc in docs) | ||
|
|
||
|
|
||
| def trim_incomplete_sentence(text): | ||
| matches = list(re.finditer(r'\.\s', text)) | ||
| if matches: | ||
| last_complete = matches[-1].end() | ||
| return text[:last_complete].strip() | ||
| else: | ||
| return text.strip() | ||
|
|
||
| def combine_messages(x): | ||
| """ | ||
| If 'x' has a method to_messages, combine message content with newline. | ||
|
|
@@ -92,23 +105,41 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct", | |
| "text-generation", | ||
| model=model_obj, | ||
| tokenizer=tokenizer, | ||
| max_length=1024, | ||
| truncation=True, | ||
| max_new_tokens=512, | ||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True | ||
| ) | ||
|
|
||
| tokenizer2 = AutoTokenizer.from_pretrained(model) | ||
| self.simplify_model = pipeline( | ||
| "text-generation", | ||
| model=model_obj, | ||
| tokenizer=tokenizer2, | ||
| max_length=1024, | ||
| truncation=True, | ||
| max_new_tokens=512, | ||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True | ||
| ) | ||
| else: | ||
| self.model = pipeline( | ||
| "text-generation", | ||
| model=model, | ||
| max_length=1024, | ||
| max_new_tokens=512, | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True, | ||
| torch_dtype=torch.float16, | ||
| device=0 if torch.cuda.is_available() else -1, | ||
|
|
@@ -117,7 +148,13 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct", | |
| self.simplify_model = pipeline( | ||
| "text-generation", | ||
| model=model, | ||
| max_length=1024, | ||
| max_new_tokens=512, | ||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True, | ||
| torch_dtype=torch.float16, | ||
| device=0 if torch.cuda.is_available() else -1, | ||
|
|
@@ -133,15 +170,27 @@ def set_model(self, model): | |
| self.model = pipeline( | ||
| "text-generation", | ||
| model=model, | ||
| max_length=1024, | ||
| max_new_tokens=512, | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True, | ||
| torch_dtype=torch.float16 | ||
| ) | ||
|
|
||
| self.simplify_model = pipeline( | ||
| "text-generation", | ||
| model=model, | ||
| max_length=1024, | ||
| max_new_tokens=512, | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return_full_text=False, | ||
| do_sample=False, | ||
| temperature=None, | ||
| top_p=None, | ||
| top_k=None, | ||
| repetition_penalty=1.2, | ||
| truncation=True, | ||
| torch_dtype=torch.float16 | ||
| ) | ||
|
|
@@ -167,52 +216,80 @@ def setup_vectorstore(self, file_paths): | |
| return retriever | ||
|
|
||
| def get_relevant_document(self, query, threshold=0.5): | ||
| results = self.retriever.invoke(query) | ||
| if results: | ||
| top_result = results[0] | ||
| score = top_result.metadata.get("score", 0.0) | ||
| if score >= threshold: | ||
| return top_result, score | ||
| return None, 0.0 | ||
| try: | ||
| if isinstance(query, dict): | ||
| if "query" in query: | ||
| query = query["query"] | ||
| else: | ||
| query = next(iter(query.values())) | ||
|
|
||
| query = str(query) | ||
|
|
||
| results = self.retriever.invoke(query) | ||
|
|
||
| print(f"[DEBUG] Retrieved results: {results}") | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if results and len(results) > 0: | ||
| return results[0], 1.0 | ||
| return None, 0.0 | ||
|
|
||
| except Exception as e: | ||
| print(f"Error in get_relevant_document: {e}") | ||
| return None, 0.0 | ||
|
|
||
| def run(self, question): | ||
| """ | ||
| Build the QA chain and process the output from model generation. | ||
| Apply double prompting to make answers child-friendly. | ||
| Print the actual prompts sent to the models. | ||
| """ | ||
| # Build the chain components: | ||
| chain_input = { | ||
| "context": self.retriever | format_docs, | ||
| "question": RunnablePassthrough() | ||
| } | ||
| # The chain applies: prompt -> combine messages -> model -> | ||
| # extract answer from output. | ||
| first_chain = ( | ||
| chain_input | ||
| | self.prompt | ||
| | combine_messages | ||
| | self.model # Use the first model | ||
| | extract_answer_from_output | ||
| ) | ||
| doc_result, _ = self.get_relevant_document(question) | ||
| if doc_result: | ||
| try: | ||
|
|
||
| doc_result, _ = self.get_relevant_document(question) | ||
| context = doc_result.page_content if doc_result else "No relevant documents were found. So, context is empty" | ||
|
|
||
| def print_prompt_before_model(x): | ||
| prompt_text = combine_messages(x) | ||
| print("\nPrompt sent to main model:\n" + "-" * 40) | ||
| print(prompt_text) | ||
| print("-" * 40 + "\n") | ||
| return prompt_text | ||
|
|
||
| first_chain = ( | ||
| self.prompt | ||
| | print_prompt_before_model | ||
| | self.model | ||
| | extract_answer_from_output | ||
| ) | ||
|
|
||
| first_response = first_chain.invoke({ | ||
| "query": question, | ||
| "context": doc_result.page_content | ||
| "question": question, | ||
|
||
| "context": context | ||
| }) | ||
| else: | ||
| first_response = first_chain.invoke(question) | ||
|
|
||
| second_chain = ( | ||
| {"original_answer": lambda x: x} | ||
| | self.child_prompt | ||
| | combine_messages | ||
| | self.simplify_model | ||
| | extract_answer_from_output | ||
| ) | ||
|
|
||
| final_response = second_chain.invoke(first_response) | ||
| return final_response | ||
|
|
||
| def print_child_prompt_before_model(x): | ||
| child_prompt_text = combine_messages(x) | ||
| print("\nPrompt sent to child model:\n" + "-" * 40) | ||
| print(child_prompt_text) | ||
| print("-" * 40 + "\n") | ||
| return child_prompt_text | ||
|
|
||
| second_chain = ( | ||
| self.child_prompt | ||
| | print_child_prompt_before_model | ||
| | self.simplify_model | ||
| | extract_answer_from_output | ||
| ) | ||
|
|
||
| final_response = second_chain.invoke({ | ||
| "original_answer": first_response | ||
| }) | ||
|
|
||
| return trim_incomplete_sentence(final_response) | ||
|
|
||
| except Exception as e: | ||
| print(f"Error in run method: {e}") | ||
|
|
||
| return f"Encountered an error: {str(e)}" | ||
|
|
||
|
|
||
| def main(): | ||
|
|
@@ -223,7 +300,8 @@ def main(): | |
| choices=[ | ||
| 'bigscience/bloom-1b1', | ||
| 'facebook/opt-350m', | ||
| 'EleutherAI/gpt-neo-1.3B' | ||
| 'EleutherAI/gpt-neo-1.3B', | ||
| 'Qwen/Qwen2-1.5B-Instruct' | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ], | ||
| default='bigscience/bloom-1b1', | ||
| help='Model name to use for text generation' | ||
|
|
@@ -257,7 +335,9 @@ def main(): | |
| print("Response:", response) | ||
| except Exception as e: | ||
| print(f"An error occurred: {e}") | ||
| import traceback | ||
| traceback.print_exc() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| main() | ||
Khan-Ramsha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.