diff --git a/src/mvt/config.yaml b/src/mvt/config.yaml
index ab35349..b5b8c78 100644
--- a/src/mvt/config.yaml
+++ b/src/mvt/config.yaml
@@ -12,10 +12,13 @@ html_files: "html_files"
 persist_directory: "faiss_index"
 host: "127.0.0.1"
 port: 8080
-system_prompt: "You are an assistant for question-answering tasks that are related to Linux Foundation Decentralized Trust former Hyperledger blockchain. Use the following pieces of retrieved context to answer the question. If you don't know the answer, say that you don't know."
 use_query_rewriting: false # Set to true to enable query rewriting
-query_rewriting_prompt: "You are an assistant helping to rewrite user queries to make them more specific and effective for searching documents. The context is the Linux Foundation Decentralized Trust former Hyperledger. Please rewrite the human query to be more specific, detailed, and optimized for document retrieval, considering the context mentioned."
 logo_pth: "https://upload.wikimedia.org/wikipedia/en/thumb/e/e2/The_Founder_Institute_Logo.png/250px-The_Founder_Institute_Logo.png"
 nr_retrieved_documents: 5
 max_download_retries: 3 # Number of retry attempts for failed downloads
-retry_delay_seconds: 2 # Base delay between retries (increases with each attempt)
\ No newline at end of file
+retry_delay_seconds: 2 # Base delay between retries (increases with each attempt)
+
+# Prompt templates loaded from files
+prompts:
+  system_prompt: src/mvt/prompts/system_prompt.txt
+  query_rewriting_prompt: src/mvt/prompts/query_rewriting_prompt.txt
\ No newline at end of file
diff --git a/src/mvt/main.py b/src/mvt/main.py
index c443fdc..5a94bce 100644
--- a/src/mvt/main.py
+++ b/src/mvt/main.py
@@ -7,10 +7,18 @@
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.chains import create_retrieval_chain
 from langchain_mistralai.embeddings import MistralAIEmbeddings
-#from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain_core.documents import Document
 from typing import List
 
+def load_prompt_from_file(file_path):
+    """Load a prompt template from a file."""
+    try:
+        with open(file_path, 'r') as file:
+            prompt_text = file.read().strip()
+            return prompt_text
+    except Exception as e:
+        print(f"Error loading prompt from {file_path}: {e}")
+        return None
 
 def get_ragchain(filter):
     # Read config data with database prompt overrides
@@ -42,15 +50,19 @@ def get_ragchain(filter):
         model=config_data["model_name"],
         temperature=0.7
     )
-    
+
     # Load local vector db
     docsearch = FAISS.load_local(config_data["persist_directory"], embeddings, allow_dangerous_deserialization=True)
 
     # Define a retriever interface
     retriever = docsearch.as_retriever(search_kwargs={"k": config_data["nr_retrieved_documents"], "filter": filter})
 
-    # read prompt string from config file
-    prompt_str = config_data["system_prompt"]
+    # Load system prompt from file or use the one in config
+    prompt_path = config_data.get("prompts", {}).get("system_prompt")
+    if prompt_path and os.path.exists(prompt_path):
+        prompt_str = load_prompt_from_file(prompt_path)
+    else:
+        prompt_str = config_data["system_prompt"]
 
     # Answer question
     qa_system_prompt = (
@@ -65,6 +77,7 @@ def get_ragchain(filter):
             ("user", "{input}"),
         ]
     )
+
     question_answer_chain = create_stuff_documents_chain(model, qa_prompt)
     rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 
diff --git a/src/mvt/prompts/query_rewriting_prompt.txt b/src/mvt/prompts/query_rewriting_prompt.txt
new file mode 100644
index 0000000..5d0f2f7
--- /dev/null
+++ b/src/mvt/prompts/query_rewriting_prompt.txt
@@ -0,0 +1 @@
+You are an assistant helping to rewrite user queries to make them more specific and effective for searching documents. The context is the Linux Foundation Decentralized Trust former Hyperledger. Please rewrite the human query to be more specific, detailed, and optimized for document retrieval, considering the context mentioned.
\ No newline at end of file
diff --git a/src/mvt/prompts/system_prompt.txt b/src/mvt/prompts/system_prompt.txt
new file mode 100644
index 0000000..0896e34
--- /dev/null
+++ b/src/mvt/prompts/system_prompt.txt
@@ -0,0 +1 @@
+You are an assistant for question-answering tasks that are related to Linux Foundation Decentralized Trust former Hyperledger blockchain. Use the following pieces of retrieved context to answer the question. If you don't know the answer, say that you don't know.
\ No newline at end of file
diff --git a/src/mvt/query_rewriting.py b/src/mvt/query_rewriting.py
index ff39ded..fb3b7a6 100644
--- a/src/mvt/query_rewriting.py
+++ b/src/mvt/query_rewriting.py
@@ -2,6 +2,7 @@
 from utils import load_yaml_file_with_db_prompts
 from dotenv import load_dotenv, find_dotenv
 from langchain_mistralai.chat_models import ChatMistralAI
+from main import load_prompt_from_file
 
 def query_rewriting_llm(user_query, context="Founder Institute Keystone Chapter"):
     """
@@ -38,7 +39,12 @@ def query_rewriting_llm(user_query, context="Founder Institute Keystone Chapter"
         temperature=0.7
     )
 
-    query_rewriting_prompt = config_data["query_rewriting_prompt"]
+    # Load query rewriting prompt from file or use the one in config
+    prompt_path = config_data.get("prompts", {}).get("query_rewriting_prompt")
+    if prompt_path and os.path.exists(prompt_path):
+        query_rewriting_prompt = load_prompt_from_file(prompt_path)
+    else:
+        query_rewriting_prompt = config_data["query_rewriting_prompt"]
 
     messages = [
         ("system", query_rewriting_prompt),
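A minimal sketch of the prompt-resolution fallback this patch introduces (prefer the file listed under the new prompts: mapping, otherwise fall back to the inline config value). The resolve_prompt helper and the direct yaml.safe_load call are illustrative stand-ins, not code from the repository, which loads its config via load_yaml_file_with_db_prompts.

import os
import yaml  # PyYAML; stand-in for the project's config loader

def load_prompt_from_file(file_path):
    """Read a prompt template from disk, returning None on failure (mirrors the helper added in main.py)."""
    try:
        with open(file_path, 'r') as file:
            return file.read().strip()
    except Exception as e:
        print(f"Error loading prompt from {file_path}: {e}")
        return None

def resolve_prompt(config_data, key):
    """Illustrative helper: prefer the file path under the 'prompts' mapping, else the inline config value."""
    prompt_path = config_data.get("prompts", {}).get(key)
    if prompt_path and os.path.exists(prompt_path):
        return load_prompt_from_file(prompt_path)
    return config_data.get(key)  # inline value, if the config still carries one

if __name__ == "__main__":
    with open("src/mvt/config.yaml") as f:
        config_data = yaml.safe_load(f)
    print(resolve_prompt(config_data, "system_prompt"))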