diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..77a4175
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+results/*
+*.pyc
+*.md
+*.txt
+!*.py
diff --git a/main.py b/main.py
index b4b5bc6..426647a 100644
--- a/main.py
+++ b/main.py
@@ -7,24 +7,85 @@
 from search_session import SearchSession

+
 def load_config(config_path):
     if not os.path.isfile(config_path):
         return {}
     with open(config_path, "r") as f:
         return yaml.safe_load(f)

+
 def main():
-    parser = argparse.ArgumentParser(description="Multi-step RAG pipeline with depth-limited searching.")
+    parser = argparse.ArgumentParser(
+        description="Multi-step RAG pipeline with depth-limited searching."
+    )
     parser.add_argument("--query", type=str, required=True, help="Initial user query")
-    parser.add_argument("--config", type=str, default="config.yaml", help="Path to YAML configuration file")
-    parser.add_argument("--corpus_dir", type=str, default=None, help="Path to local corpus folder")
-    parser.add_argument("--device", type=str, default="cpu", help="Device for retrieval model (cpu or cuda)")
-    parser.add_argument("--retrieval_model", type=str, choices=["colpali", "all-minilm"], default="colpali")
-    parser.add_argument("--top_k", type=int, default=3, help="Number of local docs to retrieve")
-    parser.add_argument("--web_search", action="store_true", default=False, help="Enable web search")
-    parser.add_argument("--personality", type=str, default=None, help="Optional personality for Gemma (e.g. cheerful)")
-    parser.add_argument("--rag_model", type=str, default="gemma", help="Which model to use for final RAG steps")
-    parser.add_argument("--max_depth", type=int, default=1, help="Depth limit for subquery expansions")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="config.yaml",
+        help="Path to YAML configuration file",
+    )
+    parser.add_argument(
+        "--corpus_dir", type=str, default=None, help="Path to local corpus folder"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="Device for retrieval model (cpu or cuda)",
+    )
+    parser.add_argument(
+        "--retrieval_model",
+        type=str,
+        choices=["colpali", "all-minilm"],
+        default="colpali",
+    )
+    parser.add_argument(
+        "--top_k", type=int, default=3, help="Number of local docs to retrieve"
+    )
+    parser.add_argument(
+        "--web_search", action="store_true", default=False, help="Enable web search"
+    )
+    parser.add_argument(
+        "--ddg_proxy",
+        type=str,
+        default=None,
+        help="Proxy for DuckDuckGo searches (format: http://user:pass@host:port)",
+    )
+    parser.add_argument(
+        "--personality",
+        type=str,
+        default=None,
+        help="Optional personality for Gemma (e.g. cheerful)",
+    )
+    parser.add_argument(
+        "--base_url",
+        type=str,
+        default="https://api.openai.com/v1",
+        help="Base URL for API (default: OpenAI official)",
+    )
+    parser.add_argument(
+        "--rag_model",
+        type=str,
+        default="gemma",
+        help="Model name (e.g. 'gemma', 'gpt-4-turbo', 'google/gemini-2.0-flash-001')",
+    )
+    parser.add_argument(
+        "--max_depth", type=int, default=1, help="Depth limit for subquery expansions"
+    )
+    parser.add_argument(
+        "--ollama_model",
+        type=str,
+        default="gemma2:2b",
+        help="Ollama model for non-final tasks (query enhancement, summarization)",
+    )
+    parser.add_argument(
+        "--max_context",
+        type=int,
+        default=24000,  # ~16k tokens
+        help="Max context size in characters for final aggregation (default: 24000 ~16k tokens)",
+    )
     args = parser.parse_args()

     config = load_config(args.config)
@@ -39,7 +100,11 @@ def main():
         web_search_enabled=args.web_search,
         personality=args.personality,
         rag_model=args.rag_model,
-        max_depth=args.max_depth
+        max_depth=args.max_depth,
+        base_url=args.base_url,
+        ddg_proxy=args.ddg_proxy,
+        ollama_model=args.ollama_model,
+        max_context=args.max_context,
     )

     loop = asyncio.get_event_loop()
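The new CLI surface above is easiest to see end to end with a concrete invocation. The following sketch is illustrative only and not part of the patch; the query, model name, and endpoint values are placeholders.

# Illustrative invocation of main.py exercising the new flags (placeholder values).
import subprocess

subprocess.run(
    [
        "python", "main.py",
        "--query", "history of the transistor",
        "--web_search",
        "--rag_model", "gpt-4-turbo",               # non-local model -> OpenAI-compatible client
        "--base_url", "https://api.openai.com/v1",  # any OpenAI-compatible endpoint
        "--ollama_model", "gemma2:2b",              # local model for subqueries and summaries
        "--max_context", "24000",
    ],
    check=True,
)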
diff --git a/requirements.txt b/requirements.txt
index cf8f512..2e0cd98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,26 @@
 # Core Dependencies
-torch
-transformers
-sentence-transformers
-numpy
+torch>=2.3.0
+transformers>=4.40.0
+sentence-transformers>=3.0.0
+numpy>=1.26.4
 pyyaml

 # Web and Parsing
-aiohttp
-duckduckgo_search
-beautifulsoup4
-pymupdf  # PyMuPDF for PDF handling
+aiohttp>=3.9.3
+duckduckgo-search>=3.9.2
+beautifulsoup4>=4.12.3
+pymupdf>=1.24.0  # PyMuPDF (fitz)
 pytesseract
-Pillow  # for image handling
+Pillow>=10.3.0

 # Optional LLM Integration
-ollama
+ollama>=0.1.14
+openai>=1.30.1
+tiktoken>=0.7.0  # For OpenAI token counting
+
+# New packages
+python-dotenv>=1.0.1
+
+# Added from the code block
+playwright>=1.46.0
+html2text>=2020.1.16
diff --git a/search_session.py b/search_session.py
index 57603a8..686e371 100644
--- a/search_session.py
+++ b/search_session.py
@@ -7,9 +7,22 @@
 import re
 import random
 import yaml
-
-from knowledge_base import KnowledgeBase, late_interaction_score, load_corpus_from_dir, load_retrieval_model, embed_text
-from web_search import download_webpages_ddg, parse_html_to_text, group_web_results_by_domain, sanitize_filename
+from dotenv import load_dotenv
+import sys
+
+from knowledge_base import (
+    KnowledgeBase,
+    late_interaction_score,
+    load_corpus_from_dir,
+    load_retrieval_model,
+    embed_text,
+)
+from web_search import (
+    download_webpages_ddg,
+    parse_html_to_text,
+    group_web_results_by_domain,
+    sanitize_filename,
+)
 from aggregator import aggregate_results

 #############################################
@@ -17,8 +30,19 @@
 #############################################

 from ollama import chat, ChatResponse
+from openai import OpenAI
+
+load_dotenv(os.path.expanduser("~/.env"))  # Proper home directory expansion
+load_dotenv()  # Load local .env
+
+# Check API key early if using non-local model
+if os.getenv("NANOSAGE_API_KEY") is None:
+    print("[WARN] NANOSAGE_API_KEY not found in environment variables or .env file")
+    print("[WARN] This will cause errors if using external API models")

-def call_gemma(prompt, model="gemma2:2b", personality=None):
+
+def call_gemma(prompt, model=None, personality=None):
+    model = model or "gemma2:2b"  # Default if not specified
     system_message = ""
     if personality:
         system_message = f"You are a {personality} assistant.\n\n"
@@ -29,62 +53,78 @@
     response: ChatResponse = chat(model=model, messages=messages)
     return response.message.content

+
 def extract_final_query(text):
     marker = "Final Enhanced Query:"
     if marker in text:
         return text.split(marker)[-1].strip()
     return text.strip()

-def chain_of_thought_query_enhancement(query, personality=None):
+
+def chain_of_thought_query_enhancement(query, personality=None, model=None):
+    model = model or "gemma2:2b"
     prompt = (
         "You are an expert search strategist. Think step-by-step through the implications and nuances "
         "of the following query and produce a final, enhanced query that covers more angles.\n\n"
-        f"Query: \"{query}\"\n\n"
+        f'Query: "{query}"\n\n'
         "After your reasoning, output only the final enhanced query on a single line - SHORT AND CONCISE.\n"
         "Provide your reasoning, and at the end output the line 'Final Enhanced Query:' followed by the enhanced query."
     )
-    raw_output = call_gemma(prompt, personality=personality)
+    raw_output = call_gemma(prompt, personality=personality, model=model)
     return extract_final_query(raw_output)

-def summarize_text(text, max_chars=6000, personality=None):
+
+def summarize_text(text, max_chars=24000, personality=None, model=None):
+    model = model or "gemma2:2b"
     if len(text) <= max_chars:
         prompt = f"Please summarize the following text succinctly:\n\n{text}"
-        return call_gemma(prompt, personality=personality)
+        return call_gemma(prompt, personality=personality, model=model)
     # If text is longer than max_chars, chunk it
-    chunks = [text[i:i+max_chars] for i in range(0, len(text), max_chars)]
+    chunks = [text[i : i + max_chars] for i in range(0, len(text), max_chars)]
     summaries = []
     for i, chunk in enumerate(chunks):
         prompt = f"Summarize part {i+1}/{len(chunks)}:\n\n{chunk}"
-        summary = call_gemma(prompt, personality=personality)
+        summary = call_gemma(prompt, personality=personality, model=model)
         summaries.append(summary)
         time.sleep(1)
     combined = "\n".join(summaries)
     if len(combined) > max_chars:
         prompt = f"Combine these summaries into one concise summary:\n\n{combined}"
-        combined = call_gemma(prompt, personality=personality)
+        combined = call_gemma(prompt, personality=personality, model=model)
     return combined

-def rag_final_answer(aggregation_prompt, rag_model="gemma", personality=None):
-    print("[INFO] Performing final RAG generation using model:", rag_model)
+
+def rag_final_answer(
+    aggregation_prompt, rag_model="gemma", personality=None, base_url=None
+):
+    print(
+        f"[INFO] Performing final RAG generation using model: {rag_model} ({base_url})"
+    )
+    # save the prompt to a file
+    with open("aggregation_prompt.txt", "w") as f:
+        f.write(aggregation_prompt)
     if rag_model == "gemma":
         return call_gemma(aggregation_prompt, personality=personality)
     elif rag_model == "pali":
         modified_prompt = f"PALI mode analysis:\n\n{aggregation_prompt}"
         return call_gemma(modified_prompt, personality=personality)
     else:
-        return call_gemma(aggregation_prompt, personality=personality)
+        return call_openai(aggregation_prompt, model=rag_model, base_url=base_url)
+
+
+def follow_up_conversation(follow_up_prompt, personality=None, model="gemma2:2b"):
+    return call_gemma(follow_up_prompt, personality=personality, model=model)

-def follow_up_conversation(follow_up_prompt, personality=None):
-    return call_gemma(follow_up_prompt, personality=personality)

 def clean_search_query(query):
-    query = re.sub(r'[\*\_`]', '', query)
-    query = re.sub(r'\s+', ' ', query)
+    query = re.sub(r"[\*\_`]", "", query)
+    query = re.sub(r"\s+", " ", query)
     return query.strip()

+
 def split_query(query, max_len=200):
-    query = query.replace('"', '').replace("'", "")
-    sentences = query.split('.')
+    query = query.replace('"', "").replace("'", "")
+    sentences = query.split(".")
     subqueries = []
     current = ""
     for sentence in sentences:
@@ -102,19 +142,21 @@ def split_query(query, max_len=200):
         subqueries.append(current)
     return [sq for sq in subqueries if sq.strip()]

+
 ##############################################
 # TOC Node: Represents a branch in the search tree
 ##############################################

+
 class TOCNode:
     def __init__(self, query_text, depth=1):
-        self.query_text = query_text # The subquery text for this branch
-        self.depth = depth # Depth level in the tree
-        self.summary = "" # Summary of findings for this branch
-        self.web_results = [] # Web search results for this branch
-        self.corpus_entries = [] # Corpus entries generated from this branch
-        self.children = [] # Child TOCNode objects for further subqueries
-        self.relevance_score = 0.0 # Relevance score relative to the overall query
+        self.query_text = query_text  # The subquery text for this branch
+        self.depth = depth  # Depth level in the tree
+        self.summary = ""  # Summary of findings for this branch
+        self.web_results = []  # Web search results for this branch
+        self.corpus_entries = []  # Corpus entries generated from this branch
+        self.children = []  # Child TOCNode objects for further subqueries
+        self.relevance_score = 0.0  # Relevance score relative to the overall query

     def add_child(self, child_node):
         self.children.append(child_node)
@@ -122,6 +164,7 @@ def add_child(self, child_node):
     def __repr__(self):
         return f"TOCNode(query_text='{self.query_text}', depth={self.depth}, relevance_score={self.relevance_score:.2f}, children={len(self.children)})"

+
 def build_toc_string(toc_nodes, indent=0):
     """
     Recursively build a string representation of the TOC tree.
@@ -132,19 +175,62 @@
         summary_snippet = (node.summary[:150] + "...") if node.summary else "No summary"
         toc_str += f"{prefix}{node.query_text} (Relevance: {node.relevance_score:.2f}, Summary: {summary_snippet})\n"
         if node.children:
-            toc_str += build_toc_string(node.children, indent=indent+1)
+            toc_str += build_toc_string(node.children, indent=indent + 1)
     return toc_str

+
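As a quick illustration of the tree structure above, the sketch below builds a two-node TOC by hand and renders it with build_toc_string. It is not part of the patch; the query text, score, and summary are invented, and it assumes the names are importable from search_session.

# Illustrative sketch: hand-built TOC tree rendered with build_toc_string.
from search_session import TOCNode, build_toc_string

root = TOCNode(query_text="history of the transistor", depth=1)
child = TOCNode(query_text="Bell Labs point-contact transistor, 1947", depth=2)
child.relevance_score = 0.82  # placeholder relevance score
child.summary = "Invented at Bell Labs in 1947 by Bardeen, Brattain, and Shockley."
root.add_child(child)

print(build_toc_string([root]))  # indented tree with relevance scores and summary snippets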
 #########################################################
 # The "SearchSession" class: orchestrate the entire pipeline,
 # including optional Monte Carlo subquery sampling, recursive web search,
 # TOC tracking, and relevance scoring.
 #########################################################

+
+def check_api_health(base_url: str, model: str):
+    """Perform a basic API health check"""
+    try:
+        nanosage_api_key = os.getenv("NANOSAGE_API_KEY")
+        if not nanosage_api_key:
+            print("NANOSAGE_API_KEY not found in environment variables")
+            sys.exit(1)
+        else:
+            print("NANOSAGE_API_KEY found in environment variables")
+        client = OpenAI(base_url=base_url, api_key=nanosage_api_key)
+        # Test with a tiny prompt
+        test_prompt = "Respond only with 'pong'"
+        response = client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": test_prompt}],
+            max_tokens=5,
+            timeout=5,
+        )
+        if response.choices[0].message.content.strip().lower() != "pong":
+            raise ValueError("Unexpected API response")
+        print(f"[INFO] API health check passed for {base_url}")
+    except Exception as e:
+        print(f"[FATAL] API health check failed: {e}")
+        print("Verify your NANOSAGE_API_KEY and base URL")
+        sys.exit(1)
+
+
 class SearchSession:
-    def __init__(self, query, config, corpus_dir=None, device="cpu",
-                 retrieval_model="colpali", top_k=3, web_search_enabled=False,
-                 personality=None, rag_model="gemma", max_depth=1):
+    def __init__(
+        self,
+        query,
+        config,
+        corpus_dir=None,
+        device="cpu",
+        retrieval_model="colpali",
+        top_k=3,
+        web_search_enabled=False,
+        personality=None,
+        rag_model="gemma",
+        max_depth=1,
+        base_url="https://api.openai.com/v1",
+        ddg_proxy=None,
+        ollama_model="gemma2:2b",
+        max_context=24000,
+    ):
         """
         :param max_depth: Maximum recursion depth for subquery expansion.
         """
@@ -158,37 +244,73 @@ def __init__(self, query, config, corpus_dir=None, device="cpu",
         self.personality = personality
         self.rag_model = rag_model
         self.max_depth = max_depth
+        self.base_url = base_url
+        self.ddg_proxy = ddg_proxy
+        self.ollama_model = ollama_model
+        self.max_context = max_context

         self.query_id = str(uuid.uuid4())[:8]
-        self.base_result_dir = os.path.join(self.config.get("results_base_dir", "results"), self.query_id)
+        self.base_result_dir = os.path.join(
+            self.config.get("results_base_dir", "results"), self.query_id
+        )
         os.makedirs(self.base_result_dir, exist_ok=True)

         print(f"[INFO] Initializing SearchSession for query_id={self.query_id}")

+        # Strict API key check for external services
+        if self.rag_model not in ["gemma", "pali"] and not os.getenv(
+            "NANOSAGE_API_KEY"
+        ):
+            print(
+                "[FATAL] NANOSAGE_API_KEY required for external models. "
+                f"Current model: {self.rag_model} using base URL: {self.base_url}"
+            )
+            sys.exit(1)
+
+        # After API key check
+        if self.rag_model not in ["gemma", "pali"]:
+            print(f"[INFO] Performing API health check to {self.base_url}...")
+            check_api_health(self.base_url, self.rag_model)
+
         # Enhance the query via chain-of-thought.
-        self.enhanced_query = chain_of_thought_query_enhancement(self.query, personality=self.personality)
+        self.enhanced_query = chain_of_thought_query_enhancement(
+            self.query, personality=self.personality, model=self.ollama_model
+        )
         if not self.enhanced_query:
             self.enhanced_query = self.query

         # Load retrieval model.
         self.model, self.processor, self.model_type = load_retrieval_model(
-            model_choice=self.retrieval_model,
-            device=self.device
+            model_choice=self.retrieval_model, device=self.device
         )

         # Compute the overall enhanced query embedding once.
print("[INFO] Computing embedding for enhanced query...") - self.enhanced_query_embedding = embed_text(self.enhanced_query, self.model, self.processor, self.model_type, self.device) + self.enhanced_query_embedding = embed_text( + self.enhanced_query, + self.model, + self.processor, + self.model_type, + self.device, + ) # Create a knowledge base. print("[INFO] Creating KnowledgeBase...") - self.kb = KnowledgeBase(self.model, self.processor, model_type=self.model_type, device=self.device) + self.kb = KnowledgeBase( + self.model, self.processor, model_type=self.model_type, device=self.device + ) # Load local corpus if available. self.corpus = [] if self.corpus_dir: print(f"[INFO] Loading local documents from {self.corpus_dir}") - local_docs = load_corpus_from_dir(self.corpus_dir, self.model, self.processor, self.device, self.model_type) + local_docs = load_corpus_from_dir( + self.corpus_dir, + self.model, + self.processor, + self.device, + self.model_type, + ) self.corpus.extend(local_docs) self.kb.add_documents(self.corpus) @@ -202,21 +324,33 @@ async def run_session(self): """ Main entry point: perform recursive web search (if enabled) and then local retrieval. """ - print(f"[INFO] Starting session with query_id={self.query_id}, max_depth={self.max_depth}") + print( + f"[INFO] Starting session with query_id={self.query_id}, max_depth={self.max_depth}" + ) plain_enhanced_query = clean_search_query(self.enhanced_query) # 1) Generate subqueries from the enhanced query - initial_subqueries = split_query(plain_enhanced_query, max_len=self.config.get("max_query_length", 200)) - print(f"[INFO] Generated {len(initial_subqueries)} initial subqueries from the enhanced query.") + initial_subqueries = split_query( + plain_enhanced_query, max_len=self.config.get("max_query_length", 200) + ) + print( + f"[INFO] Generated {len(initial_subqueries)} initial subqueries from the enhanced query." + ) # 2) Optionally do a Monte Carlo approach to sample subqueries if self.config.get("monte_carlo_search", True): print("[INFO] Using Monte Carlo approach to sample subqueries.") - initial_subqueries = self.perform_monte_carlo_subqueries(plain_enhanced_query, initial_subqueries) + initial_subqueries = self.perform_monte_carlo_subqueries( + plain_enhanced_query, initial_subqueries + ) # 3) If web search is enabled and max_depth >= 1, do the recursive expansion if self.web_search_enabled and self.max_depth >= 1: - web_results, web_entries, grouped, toc_nodes = await self.perform_recursive_web_searches(initial_subqueries, current_depth=1) + web_results, web_entries, grouped, toc_nodes = ( + await self.perform_recursive_web_searches( + initial_subqueries, current_depth=1 + ) + ) self.web_results = web_results self.grouped_web_results = grouped self.toc_tree = toc_nodes @@ -224,7 +358,9 @@ async def run_session(self): self.corpus.extend(web_entries) self.kb.add_documents(web_entries) else: - print("[INFO] Web search is disabled or max_depth < 1, skipping web expansion.") + print( + "[INFO] Web search is disabled or max_depth < 1, skipping web expansion." + ) # 4) Local retrieval print(f"[INFO] Retrieving top {self.top_k} local documents for final answer.") @@ -244,25 +380,31 @@ def perform_monte_carlo_subqueries(self, parent_query, subqueries): 2) Weighted random selection of a subset (k=3) based on relevance scores. 
""" max_subqs = self.config.get("monte_carlo_samples", 3) - print(f"[DEBUG] Monte Carlo: randomly picking up to {max_subqs} subqueries from {len(subqueries)} total.") + print( + f"[DEBUG] Monte Carlo: randomly picking up to {max_subqs} subqueries from {len(subqueries)} total." + ) scored_subqs = [] for sq in subqueries: sq_clean = clean_search_query(sq) if not sq_clean: continue - node_emb = embed_text(sq_clean, self.model, self.processor, self.model_type, self.device) + node_emb = embed_text( + sq_clean, self.model, self.processor, self.model_type, self.device + ) score = late_interaction_score(self.enhanced_query_embedding, node_emb) scored_subqs.append((sq_clean, score)) if not scored_subqs: - print("[WARN] No valid subqueries found for Monte Carlo. Returning original list.") + print( + "[WARN] No valid subqueries found for Monte Carlo. Returning original list." + ) return subqueries # Weighted random choice chosen = random.choices( population=scored_subqs, weights=[s for (_, s) in scored_subqs], - k=min(max_subqs, len(scored_subqs)) + k=min(max_subqs, len(scored_subqs)), ) # Return just the chosen subqueries chosen_sqs = [ch[0] for ch in chosen] @@ -288,21 +430,42 @@ async def perform_recursive_web_searches(self, subqueries, current_depth=1): # Create a TOC node toc_node = TOCNode(query_text=sq_clean, depth=current_depth) # Relevance - node_embedding = embed_text(sq_clean, self.model, self.processor, self.model_type, self.device) - relevance = late_interaction_score(self.enhanced_query_embedding, node_embedding) + node_embedding = embed_text( + sq_clean, self.model, self.processor, self.model_type, self.device + ) + relevance = late_interaction_score( + self.enhanced_query_embedding, node_embedding + ) toc_node.relevance_score = relevance if relevance < min_relevance: - print(f"[INFO] Skipping branch '{sq_clean}' due to low relevance ({relevance:.2f} < {min_relevance}).") + print( + f"[INFO] Skipping branch '{sq_clean}' due to low relevance ({relevance:.2f} < {min_relevance})." + ) continue # Create subdirectory safe_subquery = sanitize_filename(sq_clean)[:30] subquery_dir = os.path.join(self.base_result_dir, f"web_{safe_subquery}") os.makedirs(subquery_dir, exist_ok=True) - print(f"[DEBUG] Searching web for subquery '{sq_clean}' at depth={current_depth}...") - - pages = await download_webpages_ddg(sq_clean, limit=self.config.get("web_search_limit", 5), output_dir=subquery_dir) + print( + f"[DEBUG] Searching web for subquery '{sq_clean}' at depth={current_depth}..." + ) + + # Get retry parameters from config + ddg_retries = self.config.get("ddg_search_retries", 3) + ddg_timeout = self.config.get("ddg_search_timeout", 15) + ddg_retry_delay = self.config.get("ddg_retry_delay", 3) + + pages = await download_webpages_ddg( + sq_clean, + limit=self.config.get("web_search_limit", 5), + output_dir=subquery_dir, + proxy=self.ddg_proxy, + ddg_timeout=ddg_timeout, + ddg_retries=ddg_retries, + ddg_retry_delay=ddg_retry_delay, + ) branch_web_results = [] branch_corpus_entries = [] for page in pages: @@ -315,11 +478,16 @@ async def perform_recursive_web_searches(self, subqueries, current_depth=1): raw_text = parse_html_to_text(file_path) if not raw_text.strip(): continue - snippet = raw_text[:100].replace('\n', ' ') + "..." + snippet = raw_text[:100].replace("\n", " ") + "..." 
                limited_text = raw_text[:2048]
                try:
                    if self.model_type == "colpali":
-                        inputs = self.processor(text=[limited_text], truncation=True, max_length=512, return_tensors="pt").to(self.device)
+                        inputs = self.processor(
+                            text=[limited_text],
+                            truncation=True,
+                            max_length=512,
+                            return_tensors="pt",
+                        ).to(self.device)
                        outputs = self.model(**inputs)
                        emb = outputs.embeddings.mean(dim=1).squeeze(0)
                    else:
@@ -330,8 +498,8 @@ async def perform_recursive_web_searches(self, subqueries, current_depth=1):
                            "file_path": file_path,
                            "type": "webhtml",
                            "snippet": snippet,
-                            "url": url
-                        }
+                            "url": url,
+                        },
                    }
                    branch_corpus_entries.append(entry)
                    branch_web_results.append({"url": url, "snippet": snippet})
@@ -339,19 +507,35 @@
                    print(f"[WARN] Error embedding page '{url}': {e}")

            # Summarize
-            branch_snippets = " ".join([r.get("snippet", "") for r in branch_web_results])
-            toc_node.summary = summarize_text(branch_snippets, personality=self.personality)
+            branch_snippets = " ".join(
+                [r.get("snippet", "") for r in branch_web_results]
+            )
+            toc_node.summary = summarize_text(
+                branch_snippets,
+                max_chars=self.max_context,
+                personality=self.personality,
+                model=self.ollama_model,
+            )
            toc_node.web_results = branch_web_results
            toc_node.corpus_entries = branch_corpus_entries

            additional_subqueries = []
            if current_depth < self.max_depth:
-                additional_query = chain_of_thought_query_enhancement(sq_clean, personality=self.personality)
+                additional_query = chain_of_thought_query_enhancement(
+                    sq_clean, personality=self.personality, model=self.ollama_model
+                )
                if additional_query and additional_query != sq_clean:
-                    additional_subqueries = split_query(additional_query, max_len=self.config.get("max_query_length", 200))
+                    additional_subqueries = split_query(
+                        additional_query,
+                        max_len=self.config.get("max_query_length", 200),
+                    )

            if additional_subqueries:
-                deeper_web_results, deeper_corpus_entries, _, deeper_toc_nodes = await self.perform_recursive_web_searches(additional_subqueries, current_depth=current_depth+1)
+                deeper_web_results, deeper_corpus_entries, _, deeper_toc_nodes = (
+                    await self.perform_recursive_web_searches(
+                        additional_subqueries, current_depth=current_depth + 1
+                    )
+                )
                branch_web_results.extend(deeper_web_results)
                branch_corpus_entries.extend(deeper_corpus_entries)
                for child_node in deeper_toc_nodes:
@@ -362,8 +546,14 @@ async def perform_recursive_web_searches(self, subqueries, current_depth=1):
            toc_nodes.append(toc_node)

        grouped = group_web_results_by_domain(
-            [{"url": r["url"], "file_path": e["metadata"]["file_path"], "content_type": e["metadata"].get("content_type", "")}
-             for r, e in zip(aggregated_web_results, aggregated_corpus_entries)]
+            [
+                {
+                    "url": r["url"],
+                    "file_path": e["metadata"]["file_path"],
+                    "content_type": e["metadata"].get("content_type", ""),
+                }
+                for r, e in zip(aggregated_web_results, aggregated_corpus_entries)
+            ]
        )
        return aggregated_web_results, aggregated_corpus_entries, grouped, toc_nodes
@@ -371,27 +561,45 @@
    def _summarize_web_results(self, web_results):
        lines = []
        reference_urls = []
        for w in web_results:
-            url = w.get('url')
-            snippet = w.get('snippet')
+            url = w.get("url")
+            snippet = w.get("snippet")
            lines.append(f"URL: {url} - snippet: {snippet}")
            reference_urls.append(url)
        text = "\n".join(lines)
        # We'll store reference URLs in self._reference_links for final prompt
        self._reference_links = list(set(reference_urls))  # unique
-        return summarize_text(text, personality=self.personality)
+        return summarize_text(
+            text,
+            max_chars=self.max_context,
+            personality=self.personality,
+            model=self.ollama_model,
+        )

    def _summarize_local_results(self, local_results):
        lines = []
        for doc in local_results:
-            meta = doc.get('metadata', {})
-            file_path = meta.get('file_path')
-            snippet = meta.get('snippet', '')
+            meta = doc.get("metadata", {})
+            file_path = meta.get("file_path")
+            snippet = meta.get("snippet", "")
            lines.append(f"File: {file_path} snippet: {snippet}")
        text = "\n".join(lines)
-        return summarize_text(text, personality=self.personality)
+        return summarize_text(
+            text,
+            max_chars=self.max_context,
+            personality=self.personality,
+            model=self.ollama_model,
+        )

-    def _build_final_answer(self, summarized_web, summarized_local, previous_results_content="", follow_up_convo=""):
-        toc_str = build_toc_string(self.toc_tree) if self.toc_tree else "No TOC available."
+    def _build_final_answer(
+        self,
+        summarized_web,
+        summarized_local,
+        previous_results_content="",
+        follow_up_convo="",
+    ):
+        toc_str = (
+            build_toc_string(self.toc_tree) if self.toc_tree else "No TOC available."
+        )
        # Build a reference links string from _reference_links, if available
        reference_links = ""
        if hasattr(self, "_reference_links"):
@@ -427,11 +635,20 @@ def _build_final_answer(self, summarized_web, summarized_local, previous_results
Report:
"""
        print("[DEBUG] Final RAG prompt constructed. Passing to rag_final_answer()...")
-        final_answer = rag_final_answer(aggregation_prompt, rag_model=self.rag_model, personality=self.personality)
+        final_answer = rag_final_answer(
+            aggregation_prompt,
+            rag_model=self.rag_model,
+            personality=self.personality,
+            base_url=self.base_url,
+        )
        return final_answer

    def save_report(self, final_answer, previous_results=None, follow_up_convo=None):
        print("[INFO] Saving final report to disk...")
+        follow_up_prompt = follow_up_convo or ""
+        follow_up_convo = follow_up_conversation(
+            follow_up_prompt, personality=self.personality, model=self.ollama_model
+        )
        return aggregate_results(
            self.query_id,
            self.enhanced_query,
@@ -441,5 +658,32 @@ def save_report(self, final_answer, previous_results=None, follow_up_convo=None)
            self.config,
            grouped_web_results=self.grouped_web_results,
            previous_results=previous_results,
-            follow_up_conversation=follow_up_convo
+            follow_up_conversation=follow_up_convo,
+        )
+
+
+def call_openai(
+    prompt, model="gpt-4-turbo", base_url="https://api.openai.com/v1", temperature=0.7
+):
+    """Call OpenAI-compatible API endpoint"""
+    api_key = os.getenv("NANOSAGE_API_KEY")
+    if not api_key:
+        raise ValueError(
+            "NANOSAGE_API_KEY not found in environment variables or .env file"
+        )
+
+    client = OpenAI(base_url=base_url, api_key=api_key)
+
+    try:
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are an expert research assistant."},
+                {"role": "user", "content": prompt},
+            ],
+            temperature=temperature,
         )
+        return response.choices[0].message.content
+    except Exception as e:
+        print(f"API error ({base_url}): {e}")
+        return ""
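A minimal usage sketch for the call_openai helper added above. It is illustrative only: the key, prompt, model name, and base URL are placeholders, and in normal use NANOSAGE_API_KEY would come from ~/.env or a local .env rather than being set in code.

import os
from search_session import call_openai

os.environ.setdefault("NANOSAGE_API_KEY", "replace-with-a-real-key")  # placeholder

answer = call_openai(
    "Summarize the invention of the transistor in two sentences.",
    model="gpt-4-turbo",                   # any model name your endpoint accepts
    base_url="https://api.openai.com/v1",  # any OpenAI-compatible base URL
)
print(answer)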
"".join(char if char.isalnum() or char in "._-" else "_" for char in filename) + return "".join( + char if char.isalnum() or char in "._-" else "_" for char in filename + ) + def sanitize_path(path): """ @@ -22,30 +30,123 @@ def sanitize_path(path): else: return os.sep.join(sanitized_parts) + +# Update the User-Agent +MODERN_USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36" + + async def download_page(session, url, headers, timeout, file_path): try: + headers = { + "User-Agent": MODERN_USER_AGENT, + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + "Accept-Language": "en-US,en;q=0.5", + } + async with session.get(url, headers=headers, timeout=timeout) as response: response.raise_for_status() - content_type = response.headers.get('Content-Type', '') + content_type = response.headers.get("Content-Type", "") # If it's a PDF or image (or other binary content), read as binary. - if ('application/pdf' in content_type) or file_path.lower().endswith('.pdf') or \ - ('image/' in content_type): + if ( + ("application/pdf" in content_type) + or file_path.lower().endswith(".pdf") + or ("image/" in content_type) + ): content = await response.read() - mode = 'wb' + mode = "wb" open_kwargs = {} else: content = await response.text() - mode = 'w' - open_kwargs = {'encoding': 'utf-8'} # write text as UTF-8 to avoid charmap errors + mode = "w" + open_kwargs = { + "encoding": "utf-8" + } # write text as UTF-8 to avoid charmap errors with open(file_path, mode, **open_kwargs) as f: f.write(content) print(f"[INFO] Saved '{url}' -> '{file_path}'") - return {'url': url, 'file_path': file_path, 'content_type': content_type} + return {"url": url, "file_path": file_path, "content_type": content_type} + except Exception as e: + print(f"[WARN] Initial fetch failed for '{url}': {e}, trying browser render...") + return await download_with_browser(url, file_path) + + +async def download_with_browser(url, file_path): + try: + async with async_playwright() as p: + browser = await p.chromium.launch() + context = await browser.new_context( + user_agent=MODERN_USER_AGENT, java_script_enabled=True + ) + page = await context.new_page() + + await page.goto(url, wait_until="domcontentloaded", timeout=15000) + + # Handle PDFs differently + if url.lower().endswith(".pdf"): + content = await page.content() + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + return { + "url": url, + "file_path": file_path, + "content_type": "application/pdf", + } + + # Convert HTML to clean text + content = await page.content() + text_maker = html2text.HTML2Text() + text_maker.ignore_links = False + text = text_maker.handle(content) + + with open(file_path, "w", encoding="utf-8") as f: + f.write(text) + + print(f"[INFO] Browser-rendered page saved: {file_path}") + return {"url": url, "file_path": file_path, "content_type": "text/html"} + except Exception as e: - print(f"[WARN] Couldn't fetch '{url}': {e}") + print(f"[WARN] Browser render failed for '{url}': {e}") return None -async def download_webpages_ddg(keyword, limit=5, output_dir='downloaded_webpages'): + +def retry_with_timeout(max_retries: int = 3, timeout: int = 10, retry_delay: int = 2): + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + for attempt in range(1, max_retries + 1): + try: + # Run synchronous DDGS calls in thread pool + return await asyncio.wait_for( + asyncio.to_thread(func, 
-async def download_webpages_ddg(keyword, limit=5, output_dir='downloaded_webpages'):
+
+def retry_with_timeout(max_retries: int = 3, timeout: int = 10, retry_delay: int = 2):
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        async def wrapper(*args, **kwargs) -> Any:
+            for attempt in range(1, max_retries + 1):
+                try:
+                    # Run synchronous DDGS calls in thread pool
+                    return await asyncio.wait_for(
+                        asyncio.to_thread(func, *args, **kwargs), timeout=timeout
+                    )
+                except Exception as e:
+                    print(
+                        f"[WARN] DDG search attempt {attempt}/{max_retries} failed: {e}"
+                    )
+                    if attempt < max_retries:
+                        print(f"Retrying in {retry_delay} seconds...")
+                        await asyncio.sleep(retry_delay)
+            raise Exception(f"DDG search failed after {max_retries} attempts")
+
+        return wrapper
+
+    return decorator
+
+
+@retry_with_timeout(max_retries=3, timeout=15, retry_delay=3)
+def ddg_search_with_retry(ddgs, keyword, limit):
+    return ddgs.text(keyword, max_results=limit)
+
+
+async def download_webpages_ddg(
+    keyword,
+    limit=5,
+    output_dir="downloaded_webpages",
+    proxy=None,
+    ddg_timeout=15,
+    ddg_retries=3,
+    ddg_retry_delay=3,
+):
     """
     Perform a DuckDuckGo text search and download pages asynchronously.
     Returns a list of dicts with 'url', 'file_path', and optionally 'content_type'.
@@ -53,21 +154,21 @@ async def download_webpages_ddg(keyword, limit=5, output_dir='downloaded_webpage
     # Sanitize the output directory
     output_dir = sanitize_path(output_dir)
     os.makedirs(output_dir, exist_ok=True)
-    
+
     results_info = []
     if not keyword.strip():
         print("[WARN] Empty keyword provided to DuckDuckGo search; skipping search.")
         return []
-    
-    with DDGS() as ddgs:
-        results = ddgs.text(keyword, max_results=limit)
+
+    with DDGS(proxy=proxy) as ddgs:
+        results = await ddg_search_with_retry(ddgs, keyword, limit)
         if not results:
             print(f"[WARN] No results found for '{keyword}'.")
             return []
-    
-    headers = {'User-Agent': 'Mozilla/5.0'}
+
+    headers = {"User-Agent": "Mozilla/5.0"}
     timeout = aiohttp.ClientTimeout(total=10)
-    
+
     async with aiohttp.ClientSession(timeout=timeout) as session:
         tasks = []
         for idx, result in enumerate(results):
@@ -89,6 +190,7 @@ async def download_webpages_ddg(keyword, limit=5, output_dir='downloaded_webpage
             results_info.append(page)
     return results_info

+
 def parse_pdf_to_text(pdf_file_path, max_pages=10):
     """
     Extract text from a PDF using PyMuPDF.
@@ -107,11 +209,13 @@ def parse_pdf_to_text(pdf_file_path, max_pages=10):
             print(f"[INFO] Extracted text from PDF: {pdf_file_path}")
             return text
         else:
-            print(f"[INFO] No text found in PDF: {pdf_file_path}, converting pages to images")
+            print(
+                f"[INFO] No text found in PDF: {pdf_file_path}, converting pages to images"
+            )
             for i in range(min(max_pages, doc.page_count)):
                 page = doc.load_page(i)
                 pix = page.get_pixmap()
-                image_file = pdf_file_path.replace('.pdf', f'_page_{i+1}.png')
+                image_file = pdf_file_path.replace(".pdf", f"_page_{i+1}.png")
                 pix.save(image_file)
                 print(f"[INFO] Saved page {i+1} as image: {image_file}")
             return ""
@@ -119,6 +223,7 @@ def parse_pdf_to_text(pdf_file_path, max_pages=10):
         print(f"[WARN] Failed to parse PDF {pdf_file_path}: {e}")
         return ""

+
 def parse_html_to_text(file_path, max_pdf_pages=10):
     """
     If the file is HTML, parse it and return its plain text.
@@ -126,25 +231,26 @@
     If it's a PDF, parse with PyMuPDF.
     If the PDF has little or no text, convert up to max_pdf_pages to images.
""" try: - if file_path.lower().endswith('.pdf'): + if file_path.lower().endswith(".pdf"): return parse_pdf_to_text(file_path, max_pages=max_pdf_pages) - with open(file_path, 'r', encoding='utf-8', errors='replace') as f: + with open(file_path, "r", encoding="utf-8", errors="replace") as f: html_data = f.read() - soup = BeautifulSoup(html_data, 'html.parser') + soup = BeautifulSoup(html_data, "html.parser") for tag in soup(["script", "style"]): tag.decompose() - return soup.get_text(separator=' ', strip=True) + return soup.get_text(separator=" ", strip=True) except Exception as e: print(f"[WARN] Failed to parse HTML {file_path}: {e}") return "" + def group_web_results_by_domain(web_results): """ Takes a list of dicts, each with 'url', 'file_path', 'content_type', and groups them by domain. """ grouped = {} for item in web_results: - url = item.get('url') + url = item.get("url") if not url: continue domain = urlparse(url).netloc