From 6f3c16fe80d36766d459c4d9613898a479c5c9db Mon Sep 17 00:00:00 2001 From: Graham V Date: Mon, 18 Nov 2024 19:29:28 -0500 Subject: [PATCH 1/4] Two changes to utils.py - when fetching url_text_contents, if clipboard contains a url, then fetch the text contents of that url using jina.ai - update to token limit maintenance logic to ensure that the most recent message is never trimmed from context regardless of its token count --- utils.py | 60 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/utils.py b/utils.py index 1c107f8..79c416e 100644 --- a/utils.py +++ b/utils.py @@ -7,6 +7,8 @@ import json import os import time +import requests +from ip2geotools.databases.noncommercial import DbIpCity def read_clipboard(model_supports_images=True): """Read text or image from clipboard.""" @@ -25,11 +27,25 @@ def read_clipboard(model_supports_images=True): clipboard_content = clipboard.paste() if isinstance(clipboard_content, str) and clipboard_content: # It's text + url_pattern = r'^https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)$' + if re.search(url_pattern, clipboard_content): + clipboard_content = fetch_url_text_contents(clipboard_content) or clipboard_content return {'type': 'text', 'content': clipboard_content} print("No valid content found in clipboard.") return None +def fetch_url_text_contents(url): + try: + response = requests.get('https://r.jina.ai/' + url) + response.raise_for_status() + print(response.text[:200]) + return response.text + except requests.exceptions.RequestException as e: + print(f"Error fetching URL: {e}") + return None + + def to_clipboard(text): """ Copy the given text to the clipboard. @@ -67,7 +83,7 @@ def sanitize_text(text): def _trim_messages(messages, max_tokens): """ - Trim the messages to fit within the maximum token limit. + Trim the messages to fit within the maximum token limit while preserving the last message. Args: messages (list): A list of messages to be trimmed. @@ -76,25 +92,39 @@ def _trim_messages(messages, max_tokens): Returns: list: The trimmed list of messages. 
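+
+        Example (an illustrative sketch; exact trimming depends on tiktoken's
+        counts for each message):
+
+            messages = [
+                {'role': 'system', 'content': 'You are helpful.'},
+                {'role': 'user', 'content': 'old question'},
+                {'role': 'assistant', 'content': 'old answer'},
+                {'role': 'user', 'content': 'newest question'},
+            ]
+            trimmed = _trim_messages(messages, max_tokens=50)
+            # The system prompt and the newest user message survive; older
+            # turns are dropped oldest-first until the total fits the budget.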
""" - msg_token_count = 0 + if len(messages) <= 1: + return messages + + # Separate the last message from the rest + messages_without_last = messages[:-1] + last_message = messages[-1] + # Keep trimming messages until we're under the token limit or only system messages remain while True: - msg_token_count = _count_tokens(messages) - if msg_token_count <= max_tokens: + # Calculate total tokens including the last message + total_tokens = _count_tokens(messages_without_last + [last_message]) + if total_tokens <= max_tokens: break - # Remove the oldest non-system message - for i in range(len(messages)): - if messages[i].get('role') != 'system': - del messages[i] + + # Find the first non-system message to remove + for i in range(len(messages_without_last)): + if messages_without_last[i].get('role') != 'system': + del messages_without_last[i] break + else: # No more non-system messages to remove + break # Ensure the first non-system message is from the user - first_non_system_msg_index = next((i for i, message in enumerate(messages) if message.get('role') != 'system'), None) - while first_non_system_msg_index is not None and messages[first_non_system_msg_index].get('role') == 'assistant': - del messages[first_non_system_msg_index] - first_non_system_msg_index = next((i for i, message in enumerate(messages) if message.get('role') != 'system'), None) + first_non_system_msg_index = next((i for i, message in enumerate(messages_without_last) + if message.get('role') != 'system'), None) + while (first_non_system_msg_index is not None and + messages_without_last[first_non_system_msg_index].get('role') == 'assistant'): + del messages_without_last[first_non_system_msg_index] + first_non_system_msg_index = next((i for i, message in enumerate(messages_without_last) + if message.get('role') != 'system'), None) - return messages + # Combine the trimmed messages with the preserved last message + return messages_without_last + [last_message] def _count_tokens(messages, model="gpt-3.5-turbo"): """ @@ -128,6 +158,7 @@ def _count_tokens(messages, model="gpt-3.5-turbo"): def maintain_token_limit(messages, max_tokens): """ Maintain the token limit by trimming messages if the token count exceeds the maximum limit. + The most recent message (last in the array) will never be trimmed. Args: messages (list): A list of messages to maintain. 
@@ -239,4 +270,5 @@ def add_timestamp_to_message(message_content):
         message_content[-1]['text'] += timestamp
     else:
         message_content += timestamp
-    return message_content
\ No newline at end of file
+    return message_content
+

From a740826652087ce727e97f050dd7f29881652c61 Mon Sep 17 00:00:00 2001
From: Graham V
Date: Mon, 18 Nov 2024 19:53:08 -0500
Subject: [PATCH 2/4] feat: Add weather, news and search capabilities

- Add weather_skill.py for current weather and forecasts via OpenWeatherMap
- Add news_skill.py with multi-provider support and time-based queries
- Add search_skill.py for general search across multiple providers
- Implement search provider factory supporting Tavily, Bing, Brave and Exa
- Add natural language processing for query understanding

Include comprehensive test coverage for all new components

- Add custom action always_reddy_voice_assistant_nlp to utilize the above
---
 .../always_reddy_voice_assistant_nlp/main.py | 124 +++++++
 config_default.py                            |   6 +
 nlp_manager.py                               | 212 ++++++++++++
 search_providers/__init__.py                 |   5 +
 search_providers/base_provider.py            |  42 +++
 search_providers/bing_provider.py            | 200 ++++++++++++
 search_providers/brave_provider.py           | 308 ++++++++++++++++++
 search_providers/exa_provider.py             | 231 +++++++++++++
 search_providers/factory.py                  |  78 +++++
 search_providers/tavily_provider.py          | 160 +++++++++
 search_providers/trusted_news_sources.json   |  71 ++++
 skills/news_skill.py                         | 237 ++++++++++++++
 skills/search_skill.py                       | 136 ++++++++
 skills/weather_skill.py                      | 241 ++++++++++++++
 14 files changed, 2051 insertions(+)
 create mode 100644 actions/always_reddy_voice_assistant_nlp/main.py
 create mode 100644 nlp_manager.py
 create mode 100644 search_providers/__init__.py
 create mode 100644 search_providers/base_provider.py
 create mode 100644 search_providers/bing_provider.py
 create mode 100644 search_providers/brave_provider.py
 create mode 100644 search_providers/exa_provider.py
 create mode 100644 search_providers/factory.py
 create mode 100644 search_providers/tavily_provider.py
 create mode 100644 search_providers/trusted_news_sources.json
 create mode 100644 skills/news_skill.py
 create mode 100644 skills/search_skill.py
 create mode 100644 skills/weather_skill.py

diff --git a/actions/always_reddy_voice_assistant_nlp/main.py b/actions/always_reddy_voice_assistant_nlp/main.py
new file mode 100644
index 0000000..d01b9c7
--- /dev/null
+++ b/actions/always_reddy_voice_assistant_nlp/main.py
@@ -0,0 +1,124 @@
+import time
+from config_loader import config
+from actions.base_action import BaseAction
+from utils import to_clipboard, handle_clipboard_image, handle_clipboard_text, add_timestamp_to_message
+import prompt
+from nlp_manager import get_nlp_context
+
+class AlwaysReddyVoiceAssistant_nlp(BaseAction):
+    """Action for handling voice assistant functionality."""
+    def setup(self):
+        self.last_message_was_cut_off = False
+
+        if config.RECORD_WITH_NLP_HOTKEY:
+            self.AR.add_action_hotkey(
+                config.RECORD_WITH_NLP_HOTKEY,
+                pressed=self.handle_default_assistant_response,
+                held_release=self.handle_default_assistant_response,
+                double_tap=self.AR.save_clipboard_text
+            )
+
+            print(f"'{config.RECORD_WITH_NLP_HOTKEY}': Start/stop talking to voice assistant (press to toggle on and off, or hold and release)")
+            if "+" in config.RECORD_WITH_NLP_HOTKEY:
+                hotkey_start, hotkey_end = config.RECORD_WITH_NLP_HOTKEY.rsplit("+", 1)
+                print(f"\tHold down '{hotkey_start}' and double tap '{hotkey_end}' to send clipboard content to AlwaysReddy")
+            else:
+                print(f"\tDouble tap 
'{config.RECORD_WITH_NLP_HOTKEY}' to send clipboard content to AlwaysReddy") + + if config.NEW_CHAT_HOTKEY: + self.AR.add_action_hotkey(config.NEW_CHAT_HOTKEY, pressed=self.new_chat) + print(f"'{config.NEW_CHAT_HOTKEY}': New chat for voice assistant") + + self.messages = prompt.build_initial_messages(config.ACTIVE_PROMPT) + + def handle_default_assistant_response(self): + """Handle the response from the transcription and generate a completion.""" + try: + recording_filename = self.AR.toggle_recording(self.handle_default_assistant_response) + if not recording_filename: + return + message = self.AR.transcription_manager.transcribe_audio(recording_filename) + + if not self.AR.stop_action and message: + print("\nTranscript:\n", message) + + if len(self.messages) > 0 and self.messages[0]["role"] == "system": + self.messages[0]["content"] = prompt.get_system_prompt_message(config.ACTIVE_PROMPT) + if self.last_message_was_cut_off: + message = "--> USER CUT THE ASSISTANT'S LAST MESSAGE SHORT <--\n" + message + + new_message = {"role": "user", "content": message} + start_time = time.time() + + # Handle potential NLP context requests for weather, news and search + nlp_context = get_nlp_context(self.AR,message) + if nlp_context: + new_message['content'] = nlp_context + + # Handle clipboard image + clipboard_image_content = handle_clipboard_image(self.AR, message) + if clipboard_image_content: + new_message['content'] = clipboard_image_content + else: + # Handle clipboard text + new_message['content'] = handle_clipboard_text(self.AR, new_message['content']) + + # Add timestamp if configured + if config.TIMESTAMP_MESSAGES: + new_message['content'] = add_timestamp_to_message(new_message['content']) + + self.messages.append(new_message) + + if self.AR.stop_action: + return + + # Ensure there's at least one message + if not self.messages: + print("Error: No messages to send to the API.") + return + + stream = self.AR.completion_client.get_completion_stream( + self.messages, + config.COMPLETION_MODEL, + **config.COMPLETION_PARAMS + ) + + end_time = time.time() + if self.AR.verbose: print(f"Execution time: {end_time - start_time:.1f} seconds") + response = self.AR.completion_client.process_text_stream( + stream, + marker_tuples=[(config.CLIPBOARD_TEXT_START_SEQ, config.CLIPBOARD_TEXT_END_SEQ, to_clipboard)], + sentence_callback=self.AR.tts.run_tts + ) + + while self.AR.tts.running_tts: + time.sleep(0.001) + + if not response: + if self.AR.verbose: + print("No response generated.") + self.messages = self.messages[:-1] + return + + self.last_message_was_cut_off = False + + if self.AR.stop_action: + index = response.rfind(self.AR.tts.last_sentence_spoken) + if index != -1: + response = response[:index + len(self.AR.tts.last_sentence_spoken)] + self.last_message_was_cut_off = True + + self.messages.append({"role": "assistant", "content": response}) + print("\nResponse:\n", response) + + except Exception as e: + print(f"An error occurred in handle_default_assistant_response: {e}") + if self.AR.verbose: + import traceback + traceback.print_exc() + + def new_chat(self): + """Clear the message history and start a new chat session.""" + self.messages = prompt.build_initial_messages(config.ACTIVE_PROMPT) + self.last_message_was_cut_off = False + self.AR.last_clipboard_text = None + print("New chat session started.") \ No newline at end of file diff --git a/config_default.py b/config_default.py index ebc09d8..4f2130c 100644 --- a/config_default.py +++ b/config_default.py @@ -131,3 +131,9 @@ CANCEL_SOUND_VOLUME = 0.09 
MAX_RECORDING_DURATION= 600 # If you record for more than 10 minutes, the recording will stop automatically +### WEATHER,NEWS AND SEARCH SETTINGS ### +DEFAULT_UNITS = 'metric' #metric or imperial (or standard if you want your weather in Kelvin) +DEFAULT_LOCATION = '' +#Search and News providers can be Bing, Brave, Exa or Tavily, and don't need to be the same +SEARCH_PROVIDER = '' +NEWS_PROVIDER = '' \ No newline at end of file diff --git a/nlp_manager.py b/nlp_manager.py new file mode 100644 index 0000000..98efe32 --- /dev/null +++ b/nlp_manager.py @@ -0,0 +1,212 @@ +import spacy +from config_loader import config +from typing import List, Dict, Any, Tuple, Optional +from skills.weather_skill import WeatherSkill +from skills.news_skill import NewsSkill +from skills.search_skill import SearchSkill + +class NLPManager: + def __init__(self): + # Load spaCy model + """ + Initialize the NLPManager. + + This method loads the spaCy model and stores it in the class instance. + It also sets up the intent keywords and initializes the entity extractor. + """ + try: + self.nlp = spacy.load('en_core_web_md') + except OSError: + # Download if not available + spacy.cli.download('en_core_web_md') + self.nlp = spacy.load('en_core_web_md') + + # Intent keywords + self.intents = { + 'weather': ['weather', 'temperature', 'forecast', 'rain', 'snow', 'sunny', 'cloudy', 'humidity', 'heat', 'cold', 'hot', + 'wind', 'storm', 'precipitation', 'celsius', 'fahrenheit', 'degrees', 'sunrise', 'sunset'], + 'news': ['news', 'update', 'latest', 'current', 'happening', 'event', 'story', 'article', 'report'], + 'search': ['search', 'find', 'look up', 'lookup', 'research', 'information about', 'tell me about', 'what is', 'who is'] + } + + self.entity_extractor = EntityExtractor(self.nlp) + + def preprocess_text(self, text: str) -> List[str]: + # Process text using spaCy + """ + Process text using spaCy. + + This method takes a string of text and preprocesses it using the spaCy library. + It first converts the text to lowercase and then processes it using the + English language model. + + The method then filters out the tokens using spaCy's built-in attributes: + - is_stop: Stop words (e.g. "the", "a", etc.) + - is_punct: Punctuation (e.g. periods, commas, etc.) + - len(token.text) > 1: Tokens with more than one character + - token.text.isalpha(): Tokens that are alphabetic + + The method returns a list of lemmatized tokens (root words). + + :param text: The text to preprocess + :type text: str + :return: A list of lemmatized tokens + :rtype: List[str] + """ + doc = self.nlp(text.lower()) + + # Filter tokens using spaCy's built-in attributes + processed_tokens = [ + token.lemma_ for token in doc + if not token.is_stop + and not token.is_punct + and len(token.text) > 1 + and token.text.isalpha() + ] + + return processed_tokens + + def classify_intent(self, text: str) -> Tuple[str, float, Optional[Dict[str, Any]]]: + """ + Classify the intent of a given text. + + This method takes a string of text and classifies its intent into one of the + following categories: information_query, weather, news, search. + + The method preprocesses the text using the preprocess_text method and then + checks if any of the intent keywords are present in the tokens. If a keyword + is found, it calculates the confidence of the classification by dividing the + number of keywords found by the total number of tokens. If no keyword is found, + it defaults to an information_query intent with a confidence of 0.3. 
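+
+        Example (illustrative; confidence is the number of keyword matches
+        divided by the number of content tokens, so values are small):
+
+            nlp = NLPManager()
+            intent, confidence, info = nlp.classify_intent(
+                "What's the weather like in Buffalo today?")
+            # -> ('weather', roughly 0.3, None)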
+ + :param text: The text to classify + :type text: str + :return: A tuple containing the intent, confidence, and optional entity dictionary + :rtype: Tuple[str, float, Optional[Dict[str, Any]]] + """ + tokens = self.preprocess_text(text) + + # Intent classification + for intent, keywords in self.intents.items(): + # Special handling for multi-word keywords + text_lower = text.lower() + keyword_matches = sum(1 for keyword in keywords if keyword in text_lower) + if keyword_matches > 0: + confidence = keyword_matches / len(tokens) + return intent, confidence, None + + return 'information_query', 0.3, None + + +class EntityExtractor: + def __init__(self, nlp): + """ + Initialize the EntityExtractor. + + This method sets up the entity extractor with the provided spaCy NLP model. + + :param nlp: The spaCy NLP model used for entity extraction. + :type nlp: spacy.language.Language + """ + self.nlp = nlp + + def extract_entities(self, text: str) -> Dict[str, List[str]]: + doc = self.nlp(text) + entities = {} + for ent in doc.ents: + if ent.label_ not in entities: + entities[ent.label_] = [] + entities[ent.label_].append(ent.text) + return entities + +def get_nlp_context(AR, message_content: str) -> str: + """ + Add contextual information to a message based on the user's intent. + + If the intent is 'weather', add information about the user's preferred unit of measurement + for weather information, and the user's location if it can be detected. + + If the intent is 'news', add news search results related to the query. + + If the intent is 'search', add general search results related to the query. + + :param AR: The AlwaysReddy object + :param message_content: The message content + :return: The message content with additional context + :rtype: str + """ + nlp_manager = NLPManager() + intent, confidence, additional_info = nlp_manager.classify_intent(message_content) + + if intent == 'weather': + weather_skill = WeatherSkill() + message_content += f"\n\nTHE USER APPEARS TO HAVE A QUESTION ABOUT THE WEATHER, USE THIS DATA TO HELP YOU ANSWER IT:\n```{weather_skill.getLocationWeather(message_content)}```" + unit = str(config.DEFAULT_UNITS) or 'METRIC' + message_content += f"\n\nTHE USER PREFERS TO RECEIVE WEATHER INFORMATION IN A {unit} FORMAT. DO NOT ABBREVIATE UNITS OF MEASURE eg. USE 'CELSIUS' INSTEAD OF 'C'. KEEP THE FORECAST SOUNDING NATURAL AND NOT ROBOTIC." + + elif intent == 'news': + news_skill = NewsSkill(provider_type=config.NEWS_PROVIDER) + news_results = news_skill.search_news(message_content) + if 'error' not in news_results: + message_content += f"\n\nTHE USER APPEARS TO HAVE A QUESTION ABOUT NEWS, USE THIS DATA TO HELP YOU ANSWER IT:\n```{news_results}```" + message_content += "\n\nPLEASE PROVIDE A NATURAL SUMMARY OF THE NEWS, INCLUDING THE MOST RELEVANT AND RECENT INFORMATION. PROVIDE 5 SENTENCES OF DETAIL." + + elif intent == 'search': + search_skill = SearchSkill(provider_type=config.SEARCH_PROVIDER) + search_results = search_skill.search(message_content) + if 'error' not in search_results: + message_content += f"\n\nTHE USER APPEARS TO BE SEARCHING FOR INFORMATION, USE THIS DATA TO HELP YOU ANSWER IT:\n```{search_results}```" + message_content += "\n\nPLEASE PROVIDE A CLEAR AND CONCISE ANSWER BASED ON THE SEARCH RESULTS AND YOUR EXISTING KNOWLEDGE, FOCUSING ON THE MOST RELEVANT INFORMATION TO ANSWER THEIR QUESTION." 
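+
+    # Illustrative effect (assuming a weather query): the returned string is
+    # the original transcript with the raw skill output appended in a fenced
+    # block plus the styling instructions above; the caller sends the whole
+    # string to the completion model as the user turn.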
+ + return message_content + + + +def main(): + classifier = NLPManager() + + test_queries = [ + "What is the best restaurant in town?", + "Book a flight to New York", + "Compare iPhone and Android phones", + "Recommend a good book to read", + "What's the weather like in Buffalo today?", + "Will it rain tomorrow in New York?", + "How's the temperature in Los Angeles this weekend?", + "How's the weather in Los Angeles this weekend?", + "What is the weather in Miami?", + "Describe the wind conditions in Los Angeles for the rest of the week.", + "Give me a weather update for New York", + "Search for quantum computing", + "Look up the history of Rome", + "Find information about electric cars" + ] + + import time + for query in test_queries: + #start_time = time.time() + intent, confidence, additional_info = classifier.classify_intent(query) + entities = classifier.entity_extractor.extract_entities(query) + print(f"Query: '{query}'") + print(f"Intent: {intent}") + print(f"Confidence: {confidence:.2f}") + print(f"Entities: {entities}") + if additional_info: + print("Additional Information:") + for key, value in additional_info.items(): + print(f" {key}: {value}") + print() + start_time = time.time() + if intent == 'weather': + weather_skill = WeatherSkill() + print(str(weather_skill.getLocationWeather(query))[:200]) + elif intent == 'search': + search_skill = SearchSkill() + print(str(search_skill.search(query))[:200]) + print() + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") + print() + +if __name__ == "__main__": + main() diff --git a/search_providers/__init__.py b/search_providers/__init__.py new file mode 100644 index 0000000..a9450b3 --- /dev/null +++ b/search_providers/__init__.py @@ -0,0 +1,5 @@ +from .base_provider import BaseSearchProvider +from .tavily_provider import TavilySearchProvider +from .factory import SearchProviderFactory + +__all__ = ['BaseSearchProvider', 'TavilySearchProvider', 'SearchProviderFactory'] diff --git a/search_providers/base_provider.py b/search_providers/base_provider.py new file mode 100644 index 0000000..e98942e --- /dev/null +++ b/search_providers/base_provider.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +class BaseSearchProvider(ABC): + """ + Abstract base class for search providers. + All search providers must implement these methods. + """ + + @abstractmethod + def __init__(self, api_key: Optional[str] = None): + """ + Initialize the search provider. + + Args: + api_key: Optional API key for the search provider + """ + pass + + @abstractmethod + def search(self, query: str, **kwargs) -> Dict[str, Any]: + """ + Perform a search using the provider. + + Args: + query: The search query string + **kwargs: Additional search parameters specific to the provider + + Returns: + Dict containing the search results or error information + """ + pass + + @abstractmethod + def is_configured(self) -> bool: + """ + Check if the provider is properly configured (e.g., has valid API key). 
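+
+        Example (an illustrative check, assuming a configured factory):
+
+            provider = SearchProviderFactory.get_provider('tavily')
+            if not provider.is_configured():
+                raise RuntimeError('Set TAVILY_API_KEY before searching')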
+
+        Returns:
+            bool indicating if the provider is ready to use
+        """
+        pass

diff --git a/search_providers/bing_provider.py b/search_providers/bing_provider.py
new file mode 100644
index 0000000..2fe76ad
--- /dev/null
+++ b/search_providers/bing_provider.py
@@ -0,0 +1,200 @@
+from typing import Dict, Any, Optional
+import os
+import sys
+from pathlib import Path
+import requests
+from datetime import datetime, timedelta
+import json
+
+# Add parent directory to path for imports when running as script
+if __name__ == "__main__":
+    sys.path.append(str(Path(__file__).parent.parent))
+    from search_providers.base_provider import BaseSearchProvider
+else:
+    from .base_provider import BaseSearchProvider
+
+class BingSearchProvider(BaseSearchProvider):
+    """
+    Bing implementation of the search provider interface.
+    Handles both web and news-specific searches using Bing's APIs.
+    """
+
+    WEB_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
+    NEWS_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/news/search"
+
+    def __init__(self, api_key: Optional[str] = None):
+        """
+        Initialize the Bing search provider.
+
+        Args:
+            api_key: Optional Bing API key. If not provided, will try to get from environment.
+        """
+        self.api_key = api_key or os.getenv("BING_API_KEY")
+        self.headers = {
+            'Ocp-Apim-Subscription-Key': self.api_key,
+            'Accept': 'application/json'
+        } if self.api_key else None
+
+        # Load trusted news sources
+        self.trusted_sources = self._load_trusted_sources()
+
+    def _load_trusted_sources(self) -> list:
+        """Load the first 16 trusted news sources from the JSON file (Bing caps the number of site: operators per query)."""
+        try:
+            json_path = Path(__file__).parent / "trusted_news_sources.json"
+            with open(json_path) as f:
+                data = json.load(f)
+            # Only load the first 16 sources as per MSFT limits
+            return data.get("trusted_sources", [])[:16]
+        except Exception as e:
+            print(f"Warning: Could not load trusted news sources: {e}")
+            return []
+
+    def is_configured(self) -> bool:
+        """Check if Bing API is properly configured."""
+        return self.headers is not None
+
+    def search(self, query: str, **kwargs) -> Dict[str, Any]:
+        """
+        Perform a search using Bing API.
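+
+        Example (illustrative; assumes BING_API_KEY is set in the
+        environment):
+
+            provider = BingSearchProvider()
+            news = provider.search("chip exports", topic="news", days=7)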
+ + Args: + query: The search query string + **kwargs: Additional search parameters: + - topic: Optional search topic (e.g., "news") + - max_results: Maximum number of results (default: 10) + - market: Market code (default: "en-US") + - days: Number of days to look back (for news searches) + + Returns: + Dict containing search results or error information + """ + if not self.is_configured(): + return {'error': 'Bing API key not configured'} + + try: + # Set default search parameters + search_params = { + 'count': str(kwargs.get('max_results', 10)), # Changed default from 5 to 10 + 'mkt': kwargs.get('market', 'en-US'), + 'textFormat': 'Raw' + } + + # Determine if this is a news search + if kwargs.get('topic') == 'news': + # Add freshness parameter for news if days specified + if 'days' in kwargs: + # Bing API expects 'day', 'week', or 'month' + search_params['freshness'] = 'week' if kwargs['days'] >1 else 'day' + + # Add site: operators for trusted sources + if self.trusted_sources: + site_operators = " OR ".join(f'site:{source}' for source in self.trusted_sources) + search_params['q'] = f"({query}) ({site_operators})" + else: + search_params['q'] = f"latest headlines about the topic: {query}" + + response = requests.get( + self.NEWS_SEARCH_ENDPOINT, + headers=self.headers, + params=search_params + ) + else: + search_params['q'] = query + response = requests.get( + self.WEB_SEARCH_ENDPOINT, + headers=self.headers, + params=search_params + ) + + if response.status_code != 200: + return {'error': f'API request failed with status {response.status_code}: {response.text}'} + + response_data = response.json() + + # Process results based on search type + if kwargs.get('topic') == 'news': + return self._process_news_results( + response_data, + days=kwargs.get('days', 3), + topic=query + ) + else: + return self._process_general_results(response_data) + + except requests.exceptions.RequestException as e: + return {'error': f'API request failed: {str(e)}'} + except Exception as e: + return {'error': f'An unexpected error occurred: {str(e)}'} + + def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Process results for general web searches.""" + webpages = response.get('webPages', {}).get('value', []) + return { + 'results': [{ + 'title': result.get('name', ''), + 'url': result.get('url', ''), + 'content': result.get('snippet', ''), + 'score': 1.0 # Bing doesn't provide relevance scores + } for result in webpages[:10]] # Changed from 3 to 10 + } + + def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: + """Process results for news-specific searches.""" + articles = response.get('value', []) + return { + 'articles': [{ + 'title': article.get('name', ''), + 'url': article.get('url', ''), + 'published_date': article.get('datePublished', ''), + 'content': article.get('description', ''), + 'score': 1.0 # Bing doesn't provide relevance scores + } for article in articles], + 'time_period': f"Past {days} days", + 'topic': topic + } + +if __name__ == "__main__": + # Test code using actual API + provider = BingSearchProvider() + if not provider.is_configured(): + print("Error: Bing API key not configured") + exit(1) + + # Print loaded trusted sources + print("\n=== Loaded Trusted Sources ===") + print(provider.trusted_sources) + + # Test general search + print("\n=== Testing General Search ===") + general_result = provider.search( + "What is artificial intelligence?", + max_results=10 # Changed from 3 to 10 + ) + + if 'error' in 
general_result: + print(f"Error in general search: {general_result['error']}") + else: + print("\nTop Results:") + for idx, result in enumerate(general_result['results'], 1): + print(f"\n{idx}. {result['title']}") + print(f" URL: {result['url']}") + print(f" Preview: {result['content'][:400]}...") + + # Test news search + print("\n\n=== Testing News Search ===") + news_result = provider.search( + "mike tyson fight", + topic="news", + days=3 + ) + + if 'error' in news_result: + print(f"Error in news search: {news_result['error']}") + else: + print("\nRecent Articles:") + for idx, article in enumerate(news_result['articles'], 1): + print(f"\n{idx}. {article['title']}") + print(f" Published: {article['published_date']}") + print(f" URL: {article['url']}") + print(f" Preview: {article['content'][:400]}...") diff --git a/search_providers/brave_provider.py b/search_providers/brave_provider.py new file mode 100644 index 0000000..cca0a76 --- /dev/null +++ b/search_providers/brave_provider.py @@ -0,0 +1,308 @@ +from typing import Dict, Any, Optional +import os +import sys +from pathlib import Path +import requests +from datetime import datetime, timedelta +import json +from concurrent.futures import ThreadPoolExecutor + +# Add parent directory to path for imports when running as script +if __name__ == "__main__": + sys.path.append(str(Path(__file__).parent.parent)) + from search_providers.base_provider import BaseSearchProvider +else: + from .base_provider import BaseSearchProvider + +class BraveSearchProvider(BaseSearchProvider): + """ + Brave implementation of the search provider interface. + Handles both web and news-specific searches using Brave's APIs. + """ + + WEB_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/web/search" + NEWS_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/news/search" + SUMMARIZER_ENDPOINT = "https://api.search.brave.com/res/v1/summarizer/search" + + def __init__(self, api_key: Optional[str] = None): + """ + Initialize the Brave search provider. + + Args: + api_key: Optional Brave API key. If not provided, will try to get from environment. 
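+
+        Example (illustrative; assumes BRAVE_AI_API_KEY, and optionally
+        BRAVE_AI_PRO_API_KEY for AI summaries, are set in the environment):
+
+            provider = BraveSearchProvider()
+            results = provider.search("rust async runtimes", max_results=5)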
+ """ + self.api_key = api_key or os.getenv("BRAVE_AI_API_KEY") + self.pro_api_key = os.getenv("BRAVE_AI_PRO_API_KEY") #Optional, used for AI summary requests + self.headers = { + 'X-Subscription-Token': self.api_key, + 'Accept': 'application/json' + } if self.api_key else None + self.proheaders = { + 'X-Subscription-Token': self.pro_api_key, + 'Accept': 'application/json' + } if self.pro_api_key else None + def is_configured(self) -> bool: + """Check if Brave API is properly configured.""" + return self.headers is not None + + def get_brave_summary(self, query): + # Query parameters + params = { + "q": query, + "summary": 1 + } + + # Make the initial web search request to get summarizer key + search_response = requests.get(self.WEB_SEARCH_ENDPOINT, headers=self.proheaders, params=params) + + if search_response.status_code == 200: + data = search_response.json() + + if "summarizer" in data and "key" in data["summarizer"]: + summarizer_key = data["summarizer"]["key"] + + # Make request to summarizer endpoint + summarizer_params = { + "key": summarizer_key, + "entity_info": 1 + } + + summary_response = requests.get( + self.SUMMARIZER_ENDPOINT, + headers=self.proheaders, + params=summarizer_params + ) + + if summary_response.status_code == 200: + summary_data = summary_response.json() + try: + return summary_data['summary'][0]['data'] + except (KeyError, IndexError): + return None + + return None + + def search(self, query: str, **kwargs) -> Dict[str, Any]: + """ + Perform a search using Brave API. + + Args: + query: The search query string + **kwargs: Additional search parameters: + - topic: Optional search topic (e.g., "news") + - max_results: Maximum number of results (default: 10) + - market: Market code (default: "en-US") + - days: Number of days to look back (for news searches) + + Returns: + Dict containing search results or error information + """ + if not self.is_configured(): + return {'error': 'Brave API key not configured'} + + try: + # Set default search parameters + search_params = { + 'count': str(kwargs.get('max_results', 10)), + 'country': kwargs.get('market', 'us'), # Brave uses country code + 'q': query + } + + # Determine if this is a news search + if kwargs.get('topic') == 'news': + # Add freshness parameter for news if days specified + if 'days' in kwargs: + days = kwargs['days'] + if days <= 1: + search_params['freshness'] = 'pd' # past day + elif days <= 7: + search_params['freshness'] = 'pw' # past week + else: + search_params['freshness'] = 'pm' # past month + + response = requests.get( + self.NEWS_SEARCH_ENDPOINT, + headers=self.headers, + params=search_params + ) + + response_data = response.json() + result = self._process_news_results(response_data, days=kwargs.get('days', 3), topic=query) + else: + response = requests.get( + self.WEB_SEARCH_ENDPOINT, + headers=self.headers, + params=search_params + ) + response_data = response.json() + result = self._process_general_results(response_data) + + # Include summarizer response if it exists + summary_response = self.get_brave_summary(query) + if summary_response: + result['summarizer'] = summary_response + + return result + + except requests.exceptions.RequestException as e: + return {'error': f'API request failed: {str(e)}'} + except Exception as e: + return {'error': f'An unexpected error occurred: {str(e)}'} + + def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Process results for general web searches.""" + web_results = response.get('web', {}).get('results', []) + with 
ThreadPoolExecutor() as executor: + # Use index as key instead of the result dictionary + futures = {i: executor.submit(self.get_brave_summary, result.get('title', '')) + for i, result in enumerate(web_results[:2])} + + results = [] + for i, result in enumerate(web_results): + summary = None + if i < 2: + try: + summary = futures[i].result() + except Exception as e: + print(f"Error getting summary: {e}") + + processed_result = { + 'title': result.get('title', ''), + 'url': result.get('url', ''), + 'content': result.get('description', ''), + 'score': result.get('score', 1.0), + 'extra_snippets': None, + 'summary': None + } + if summary: + processed_result['summary'] = summary + else: + processed_result['extra_snippets'] = result.get('extra_snippets', []) + results.append(processed_result) + return {'results': results} + + def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: + """Process results for news-specific searches.""" + news_results = response.get('results', []) + def convert_age_to_minutes(age_str: str) -> int: + """ + Convert age string to minutes. + + Args: + age_str: Age string in the format of "X minutes", "X hours", "X days" + + Returns: + Age in minutes + """ + age_value = int(age_str.split()[0]) + age_unit = age_str.split()[1] + if age_unit == 'minutes': + return age_value + elif age_unit == 'hours': + return age_value * 60 + elif age_unit == 'days': + return age_value * 1440 # 24 hours * 60 minutes + else: + return 0 # Default to 0 if unknown unit + + # Sort news results based on the age field + news_results.sort(key=lambda x: convert_age_to_minutes(x.get('age', '0 minutes'))) + + with ThreadPoolExecutor() as executor: + # Use enumerate to create futures with index as key + futures = {i: executor.submit(self.get_brave_summary, article_data.get('title', '')) + for i, article_data in enumerate(news_results)} + + articles = [] + for i, article_data in enumerate(news_results): + try: + summary = futures[i].result() + except Exception as e: + print(f"Error getting summary: {e}") + summary = None + + article = { + 'title': article_data.get('title', ''), + 'url': article_data.get('url', ''), + 'published_date': article_data.get('age', ''), + 'breaking' : article_data.get('breaking', False), + 'content': article_data.get('description', ''), + 'extra_snippets': None, + 'summary': None, + 'score': article_data.get('score', 1.0) + } + if summary: + article['summary'] = summary + else: + article['extra_snippets'] = article_data.get('extra_snippets', []) + articles.append(article) + + return { + 'articles': articles, + 'time_period': f"Past {days} days", + 'topic': topic + } + +if __name__ == "__main__": + # Test code using actual API + provider = BraveSearchProvider() + if not provider.is_configured(): + print("Error: Brave API key not configured") + exit(1) + + # Test general search + print("\n=== Testing General Search ===") + general_result = provider.search( + "What is artificial intelligence?", + max_results=1 # Increased max_results to test summary limiting + ) + + if 'error' in general_result: + print(f"Error in general search: {general_result['error']}") + else: + print("\nTop Results:") + for idx, result in enumerate(general_result['results'], 1): + print(f"\n{idx}. 
{result['title']}") + print(f" URL: {result['url']}") + print(f" Preview: {result['content']}...") + print(f" Score: {result['score']}") + if result['extra_snippets']: + print(" Extra Snippets:") + for snippet in result['extra_snippets']: + print(f" - {snippet}") + if result['summary']: # Check if summary exists before printing + print(f" Summary: {result.get('summary', '')}...") + import time + time.sleep(1) + + # Test news search + print("\n\n=== Testing News Search ===") + import time + start_time = time.time() + news_result = provider.search( + "mike tyson fight", + topic="news", + days=3, + max_results=1 + ) + end_time = time.time() + + + if 'error' in news_result: + print(f"Error in news search: {news_result['error']}") + else: + print("\nRecent Articles:") + for idx, article in enumerate(news_result['articles'], 1): + print(f"\n{idx}. {article['title']}") + print(f" Published: {article['published_date']}") + print(f" Breaking: {article['breaking']}") + print(f" URL: {article['url']}") + print(f" Preview: {article['content'][:400]}...") + if article['extra_snippets']: + print(" Extra Snippets:") + for snippet in article['extra_snippets']: + print(f" - {snippet}") + if article['summary']: + print(f" Summary: {article.get('summary', '')}...") + + print(f"Execution time: {round(end_time - start_time, 1)} seconds") diff --git a/search_providers/exa_provider.py b/search_providers/exa_provider.py new file mode 100644 index 0000000..a20404b --- /dev/null +++ b/search_providers/exa_provider.py @@ -0,0 +1,231 @@ +from typing import Dict, Any, Optional +import os +import sys +import json +from pathlib import Path +import requests +from datetime import datetime, timedelta + +# Add parent directory to path for imports when running as script +if __name__ == "__main__": + sys.path.append(str(Path(__file__).parent.parent)) + from search_providers.base_provider import BaseSearchProvider +else: + from .base_provider import BaseSearchProvider + +class ExaSearchProvider(BaseSearchProvider): + """ + Exa.ai implementation of the search provider interface. + Handles web searches with optional full page content retrieval. + """ + + def __init__(self, api_key: Optional[str] = None): + """ + Initialize the Exa search provider. + + Args: + api_key: Optional Exa API key. If not provided, will try to get from environment. + """ + self.api_key = api_key or os.getenv("EXA_API_KEY") + self.base_url = "https://api.exa.ai/search" + self.trusted_sources = self._load_trusted_sources() + + def _load_trusted_sources(self) -> list: + """Load trusted news sources from JSON file.""" + try: + json_path = Path(__file__).parent / 'trusted_news_sources.json' + with open(json_path) as f: + data = json.load(f) + return data.get('trusted_sources', []) + except Exception as e: + print(f"Warning: Could not load trusted sources: {e}") + return [] + + def is_configured(self) -> bool: + """Check if Exa client is properly configured.""" + return bool(self.api_key) + + def search(self, query: str, **kwargs) -> Dict[str, Any]: + """ + Perform a search using Exa API. 
+ + Args: + query: The search query string + **kwargs: Additional search parameters: + - include_content: Whether to retrieve full page contents (default: False) + - max_results: Maximum number of results (default: 3) + - days: Number of days to look back (for news searches) + + Returns: + Dict containing search results or error information + """ + if not self.is_configured(): + return {'error': 'Exa API key not configured'} + + try: + # Set default search parameters + search_params = { + 'query': query, + 'type': 'neural', + 'useAutoprompt': True, + 'numResults': kwargs.get('max_results', 3), + } + + # Add optional parameters + if kwargs.get('include_content'): + search_params['contents'] = { + "highlights": True, + "summary": True + } + + if kwargs.get('days'): + # Convert days to timestamp for time-based filtering + date_limit = datetime.now() - timedelta(days=kwargs['days']) + search_params['startPublishedTime'] = date_limit.isoformat() + + # Add trusted domains for news searches + if kwargs.get('topic') == 'news' and self.trusted_sources: + search_params['includeDomains'] = self.trusted_sources + + # Make API request + headers = { + 'x-api-key': self.api_key, + 'Content-Type': 'application/json', + 'accept': 'application/json' + } + + response = requests.post( + self.base_url, + headers=headers, + json=search_params + ) + response.raise_for_status() + data = response.json() + + # Process results based on whether it's a news search + if kwargs.get('topic') == 'news': + return self._process_news_results( + data, + days=kwargs.get('days', 3), + topic=query + ) + else: + return self._process_general_results(data) + + except requests.exceptions.RequestException as e: + if e.response and e.response.status_code == 401: + return {'error': 'Invalid Exa API key'} + elif e.response and e.response.status_code == 429: + return {'error': 'Exa API rate limit exceeded'} + else: + return {'error': f'An error occurred while making the request: {str(e)}'} + except Exception as e: + return {'error': f'An unexpected error occurred: {str(e)}'} + + def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Process results for general searches.""" + results = [] + for result in response.get('results', []): + processed_result = { + 'title': result.get('title', ''), + 'url': result.get('url', ''), + 'highlights': result.get('highlights', []), + 'summary': result.get('summary', ''), + 'score': result.get('score', 0.0) + } + results.append(processed_result) + + return { + 'results': results, + 'autoprompt': response.get('autopromptString', '') + } + + def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: + """Process results for news-specific searches.""" + articles = [] + for article in response.get('results', []): + processed_article = { + 'title': article.get('title', ''), + 'url': article.get('url', ''), + 'published_date': article.get('publishedDate', ''), + 'highlights': article.get('highlights', []), + 'summary': article.get('summary', ''), + 'score': article.get('score', 0.0) + } + articles.append(processed_article) + + return { + 'articles': articles, + 'time_period': f"Past {days} days", + 'topic': topic, + 'autoprompt': response.get('autopromptString', '') + } + +if __name__ == "__main__": + # Test code for the Exa provider + provider = ExaSearchProvider() + if not provider.is_configured(): + print("Error: Exa API key not configured") + exit(1) + + # Test general search + print("\n=== Testing General Search ===") + import time + 
start_time = time.time() + general_result = provider.search( + "What is artificial intelligence?", + max_results=3, + include_content=True + ) + end_time = time.time() + + if 'error' in general_result: + print("Error:", general_result['error']) + else: + print("\nTop Results:") + print(f"Autoprompt: {general_result.get('autoprompt', '')}") + for idx, result in enumerate(general_result['results'], 1): + print(f"\n{idx}. {result['title']}") + print(f" URL: {result['url']}") + print(f" Score: {result['score']}") + print(f" Summary: {result['summary']}") + if result['highlights']: + print(" Highlights:") + for highlight in result['highlights']: + print(f" - {highlight}") + print(f"\n\nTime taken for general search: {end_time - start_time} seconds") + + # Test news search + print("\n\n=== Testing News Search ===") + start_time = time.time() + news_result = provider.search( + "Latest developments in AI", + topic="news", + days=3, + max_results=3, + include_content=True + ) + end_time = time.time() + + if 'error' in news_result: + print("Error:", news_result['error']) + else: + print("\nRecent Articles:") + print(f"Autoprompt: {news_result.get('autoprompt', '')}") + for idx, article in enumerate(news_result['articles'], 1): + print(f"\n{idx}. {article['title']}") + print(f" Published: {article['published_date']}") + print(f" URL: {article['url']}") + print(f" Score: {article['score']}") + print(f" Summary: {article['summary']}") + if article['highlights']: + print(" Highlights:") + for highlight in article['highlights']: + print(f" - {highlight}") + print(f"\n\nTime taken for news search: {end_time - start_time} seconds") + + # Test error handling + print("\n\n=== Testing Error Handling ===") + bad_provider = ExaSearchProvider(api_key="invalid_key") + error_result = bad_provider.search("test query") + print("\nExpected error with invalid API key:", error_result['error']) diff --git a/search_providers/factory.py b/search_providers/factory.py new file mode 100644 index 0000000..aaa2eab --- /dev/null +++ b/search_providers/factory.py @@ -0,0 +1,78 @@ +from typing import Optional, Dict, Type +from .base_provider import BaseSearchProvider +from .tavily_provider import TavilySearchProvider +from .bing_provider import BingSearchProvider +from .brave_provider import BraveSearchProvider +from .exa_provider import ExaSearchProvider + +class SearchProviderFactory: + """ + Factory class for creating search provider instances. + Supports multiple provider types and handles provider configuration. + """ + + # Registry of available providers + _providers: Dict[str, Type[BaseSearchProvider]] = { + 'tavily': TavilySearchProvider, + 'bing': BingSearchProvider, + 'brave': BraveSearchProvider, + 'exa': ExaSearchProvider + } + + @classmethod + def get_provider( + cls, + provider_type: Optional[str] = None, + api_key: Optional[str] = None + ) -> BaseSearchProvider: + """ + Get an instance of the specified search provider. 
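+
+        Example (illustrative):
+
+            provider = SearchProviderFactory.get_provider('brave')
+            # or fall back to the provider named in config.SEARCH_PROVIDER:
+            provider = SearchProviderFactory.get_provider()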
+ + Args: + provider_type: Type of provider to create (defaults to config.SEARCH_PROVIDER) + api_key: Optional API key for the provider + + Returns: + An instance of the specified provider + + Raises: + ValueError: If the specified provider type is not supported + """ + # Import config here to avoid circular imports + from config import SEARCH_PROVIDER + + # Use provided provider_type or fall back to config + provider_type = provider_type or SEARCH_PROVIDER + + provider_class = cls._providers.get(provider_type.lower()) + if not provider_class: + raise ValueError( + f"Unsupported provider type: {provider_type}. " + f"Available providers: {', '.join(cls._providers.keys())}" + ) + + return provider_class(api_key=api_key) + + @classmethod + def register_provider( + cls, + provider_type: str, + provider_class: Type[BaseSearchProvider] + ) -> None: + """ + Register a new provider type. + + Args: + provider_type: Name of the provider type + provider_class: Provider class that implements BaseSearchProvider + + Raises: + TypeError: If provider_class doesn't inherit from BaseSearchProvider + """ + if not issubclass(provider_class, BaseSearchProvider): + raise TypeError( + f"Provider class must inherit from BaseSearchProvider. " + f"Got {provider_class.__name__}" + ) + + cls._providers[provider_type.lower()] = provider_class diff --git a/search_providers/tavily_provider.py b/search_providers/tavily_provider.py new file mode 100644 index 0000000..043ef94 --- /dev/null +++ b/search_providers/tavily_provider.py @@ -0,0 +1,160 @@ +from typing import Dict, Any, Optional +import os +import sys +from pathlib import Path + +# Add parent directory to path for imports when running as script +if __name__ == "__main__": + sys.path.append(str(Path(__file__).parent.parent)) + from search_providers.base_provider import BaseSearchProvider +else: + from .base_provider import BaseSearchProvider + +from tavily import TavilyClient, MissingAPIKeyError, InvalidAPIKeyError, UsageLimitExceededError + +class TavilySearchProvider(BaseSearchProvider): + """ + Tavily implementation of the search provider interface. + Handles both general and news-specific searches. + """ + + def __init__(self, api_key: Optional[str] = None): + """ + Initialize the Tavily search provider. + + Args: + api_key: Optional Tavily API key. If not provided, will try to get from environment. + """ + self.api_key = api_key or os.getenv("TAVILY_API_KEY") + try: + self.client = TavilyClient(api_key=self.api_key) if self.api_key else None + except MissingAPIKeyError: + self.client = None + + def is_configured(self) -> bool: + """Check if Tavily client is properly configured.""" + return self.client is not None + + def search(self, query: str, **kwargs) -> Dict[str, Any]: + """ + Perform a search using Tavily API. 
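+
+        Example (illustrative; assumes TAVILY_API_KEY is set):
+
+            provider = TavilySearchProvider()
+            news = provider.search("AI regulation", topic="news", days=7)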
+ + Args: + query: The search query string + **kwargs: Additional search parameters: + - search_depth: "basic" or "advanced" (default: "basic") + - topic: Optional search topic (e.g., "news") + - max_results: Maximum number of results (default: 5) + - include_answer: Whether to include AI-generated answer (default: True) + - include_images: Whether to include images (default: False) + - days: Number of days to look back (for news searches) + + Returns: + Dict containing search results or error information + """ + if not self.is_configured(): + return {'error': 'Tavily API key not configured'} + + try: + # Set default search parameters + search_params = { + 'search_depth': "basic", + 'max_results': 5, + 'include_answer': True, + 'include_images': False + } + + # Update with any provided parameters + search_params.update(kwargs) + + # Execute search + response = self.client.search(query, **search_params) + + # Process results based on whether it's a news search + if kwargs.get('topic') == 'news': + return self._process_news_results( + response, + days=kwargs.get('days', 3), + topic=query + ) + else: + return self._process_general_results(response) + + except InvalidAPIKeyError: + return {'error': 'Invalid Tavily API key'} + except UsageLimitExceededError: + return {'error': 'Tavily API usage limit exceeded'} + except Exception as e: + return {'error': f'An unexpected error occurred: {e}'} + + def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Process results for general searches.""" + return { + 'answer': response.get('answer', ''), + 'results': [{ + 'title': result.get('title', ''), + 'url': result.get('url', ''), + 'content': result.get('content', '')[:500] + '...' if result.get('content') else '', + 'score': result.get('score', 0.0) + } for result in response.get('results', [])[:3]] + } + + def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: + """Process results for news-specific searches.""" + return { + 'answer': response.get('answer', ''), + 'articles': [{ + 'title': article.get('title', ''), + 'url': article.get('url', ''), + 'published_date': article.get('published_date', ''), + 'content': article.get('content', '')[:500] + '...' if article.get('content') else '', + 'score': article.get('score', 0.0) + } for article in response.get('results', [])], + 'time_period': f"Past {days} days", + 'topic': topic + } + +if __name__ == "__main__": + # Test code for the Tavily provider + provider = TavilySearchProvider() + if not provider.is_configured(): + print("Error: Tavily API key not configured") + exit(1) + + # Test general search + print("\n=== Testing General Search ===") + general_result = provider.search( + "What is artificial intelligence?", + search_depth="advanced", + max_results=3 + ) + print("\nQuery Answer:", general_result['answer']) + print("\nTop Results:") + for idx, result in enumerate(general_result['results'], 1): + print(f"\n{idx}. {result['title']}") + print(f" URL: {result['url']}") + print(f" Score: {result['score']}") + print(f" Preview: {result['content'][:200]}...") + + # Test news search + print("\n\n=== Testing News Search ===") + news_result = provider.search( + "Latest developments in AI", + topic="news", + days=3, + search_depth="advanced" + ) + print("\nNews Summary:", news_result['answer']) + print("\nRecent Articles:") + for idx, article in enumerate(news_result['articles'], 1): + print(f"\n{idx}. 
{article['title']}") + print(f" Published: {article['published_date']}") + print(f" URL: {article['url']}") + print(f" Score: {article['score']}") + print(f" Preview: {article['content'][:400]}...") + + # Test error handling + print("\n\n=== Testing Error Handling ===") + bad_provider = TavilySearchProvider(api_key="invalid_key") + error_result = bad_provider.search("test query") + print("\nExpected error with invalid API key:", error_result['error']) diff --git a/search_providers/trusted_news_sources.json b/search_providers/trusted_news_sources.json new file mode 100644 index 0000000..b5e3c77 --- /dev/null +++ b/search_providers/trusted_news_sources.json @@ -0,0 +1,71 @@ +{ + "trusted_sources": [ + "apnews.com", + "reuters.com", + "bbc.com", + "wsj.com", + "nytimes.com", + "economist.com", + "bloomberg.com", + "ft.com", + "aljazeera.com", + "afp.com", + "techcrunch.com", + "wired.com", + "arstechnica.com", + "theverge.com", + "cnet.com", + "theguardian.com", + "businessinsider.com", + "dw.com", + "time.com", + "afp.com", + "pbs.org", + "npr.org", + "cnbc.com", + "forbes.com", + "thehill.com", + "politico.com", + "axios.com", + "euronews.com", + "japantimes.co.jp", + "scmp.com", + "straitstimes.com", + "themoscowtimes.com", + "haaretz.com", + "timesofindia.com", + "globeandmail.com", + "abc.net.au", + "rte.ie", + "swissinfo.ch", + "thelocal.fr", + "thelocal.de", + "thelocal.se", + "kyivpost.com", + "arabnews.com", + "koreatimes.co.kr", + "bangkokpost.com", + "zdnet.com", + "cnet.com", + "engadget.com", + "gizmodo.com", + "thenextweb.com", + "venturebeat.com", + "techradar.com", + "tomshardware.com", + "anandtech.com", + "slashdot.org", + "techspot.com", + "phoronix.com", + "404media.co", + "theregister.com", + "techdirt.com", + "techrepublic.com", + "mit.edu", + "protocol.com", + "theinformation.com", + "restofworld.org", + "news.ycombinator.com" + ] + } + \ No newline at end of file diff --git a/skills/news_skill.py b/skills/news_skill.py new file mode 100644 index 0000000..b0e343d --- /dev/null +++ b/skills/news_skill.py @@ -0,0 +1,237 @@ +from nltk import word_tokenize, pos_tag +from typing import Dict, Any, Optional +import unittest +from unittest.mock import patch, MagicMock +from datetime import datetime, timedelta +import re +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) +from search_providers import SearchProviderFactory + +class NewsSkill: + def __init__(self, provider_type: Optional[str] = None): + self.topic_indicators = { + 'about', 'regarding', 'on', 'related to', 'news about', + 'latest', 'recent', 'current', 'today', 'breaking', 'update' + } + # Expanded time indicators with more natural language expressions + self.time_indicators = { + 'today': 1, + 'yesterday': 2, + 'last week': 7, + 'past week': 7, + 'this week': 7, + 'recent': 3, + 'latest': 3, + 'last month': 30, + 'past month': 30, + 'this month': 30, + 'past year': 365, + 'last year': 365, + 'this year': 365, + 'past few days': 3, + 'last few days': 3, + 'past couple days': 2, + 'last couple days': 2, + 'past 24 hours': 1, + 'last 24 hours': 1, + 'past hour': 1, + 'last hour': 1, + 'past few weeks': 21, + 'last few weeks': 21, + 'past couple weeks': 14, + 'last couple weeks': 14, + 'past few months': 90, + 'last few months': 90, + 'past couple months': 60, + 'last couple months': 60 + } + # Regular expressions for relative time + self.relative_time_patterns = [ + (r'past (\d+) days?', lambda x: int(x)), + (r'last (\d+) days?', lambda x: int(x)), + (r'past (\d+) weeks?', 
lambda x: int(x) * 7), + (r'last (\d+) weeks?', lambda x: int(x) * 7), + (r'past (\d+) months?', lambda x: int(x) * 30), + (r'last (\d+) months?', lambda x: int(x) * 30), + (r'past (\d+) years?', lambda x: int(x) * 365), + (r'last (\d+) years?', lambda x: int(x) * 365) + ] + self.provider = SearchProviderFactory.get_provider(provider_type=provider_type) + + def extract_time_reference(self, text: str) -> int: + """ + Extract a time reference from text and convert it to number of days. + Handles both fixed and relative time expressions. + + Args: + text: The text to extract the time reference from. + + Returns: + Number of days to look back for news (default: 3 if no time reference found) + """ + text_lower = text.lower() + + # Check for exact matches first + for time_ref, days in self.time_indicators.items(): + if time_ref in text_lower: + return days + + # Check for relative time patterns + for pattern, converter in self.relative_time_patterns: + match = re.search(pattern, text_lower) + if match: + return converter(match.group(1)) + + # If no time reference found, default to 3 days + return 3 + + def extract_search_topic(self, text: str) -> str: + """ + Extract the main search topic from the query text. + Improved to handle compound topics better. + + Args: + text: The query text to extract the topic from. + + Returns: + The extracted search topic or the original text if no specific topic is found. + """ + # Tokenize and tag parts of speech + tokens = word_tokenize(text.lower()) + tagged = pos_tag(tokens) + + # Look for topic after common indicators + for i, (word, _) in enumerate(tagged): + if word in self.topic_indicators and i + 1 < len(tagged): + # Extract everything after the indicator as the topic + topic_words = [] + for word, pos in tagged[i+1:]: + # Include nouns, adjectives, verbs, and conjunctions for compound topics + if pos.startswith(('NN', 'JJ', 'VB')) or word in ['and', 'or']: + topic_words.append(word) + if topic_words: + return ' '.join(topic_words) + + # If no specific pattern found, use the original text with common words filtered out + stop_words = {'what', 'is', 'are', 'the', 'tell', 'me', 'search', 'find', 'get', 'news'} + topic_words = [] + for word, pos in tagged: + # Include conjunctions to better handle compound topics + if (word not in stop_words and pos.startswith(('NN', 'JJ', 'VB'))) or word in ['and', '&', 'or']: + topic_words.append(word) + + return ' '.join(topic_words) if topic_words else text + + def search_news(self, query: str, provider_type: Optional[str] = None) -> Dict[str, Any]: + """ + Search for news articles using the configured search provider. + + Args: + query: The search query string. + provider_type: Optional provider type to use for this specific search + + Returns: + Dict containing the search results or error information. 
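+
+        Example (illustrative):
+
+            skill = NewsSkill(provider_type='tavily')
+            results = skill.search_news("news about AI from the past 5 days")
+            # 'past 5 days' resolves to days=5; the topic extracted from the
+            # query is forwarded to the provider's news search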
+ """ + # Use a new provider just for this search if specified + provider = (SearchProviderFactory.get_provider(provider_type=provider_type) + if provider_type else self.provider) + + if not provider.is_configured(): + return {'error': 'Search provider not configured'} + + search_topic = self.extract_search_topic(query) + days_to_search = self.extract_time_reference(query) + + return provider.search( + search_topic, + search_depth="basic", + topic="news", + max_results=5, + include_answer=True, + include_images=False, + days=days_to_search + ) + +# Test code +class TestNewsSkill(unittest.TestCase): + def setUp(self): + self.news_skill = NewsSkill() + + def test_extract_search_topic(self): + test_cases = [ + ("What's the latest news about artificial intelligence?", "artificial intelligence"), + ("Tell me the current news regarding climate change", "climate change"), + ("Search for news about SpaceX launches", "spacex launches"), + ("What's happening in technology today?", "technology"), + ("Tell me about AI and machine learning", "ai and machine learning"), # Test compound topic + ] + + for query, expected in test_cases: + result = self.news_skill.extract_search_topic(query) + self.assertEqual(result.lower(), expected.lower()) + + def test_extract_time_reference(self): + test_cases = [ + ("What happened today in tech?", 1), + ("Show me last week's news about AI", 7), + ("What's the latest on climate change?", 3), + ("Tell me about space news from last month", 30), + ("What happened in politics this year?", 365), + ("Show me news about crypto", 3), # Default case + ("News from past 5 days", 5), # Test relative time + ("Updates from last 2 weeks", 14), # Test relative time + ] + + for query, expected_days in test_cases: + result = self.news_skill.extract_time_reference(query) + self.assertEqual(result, expected_days) + + @patch('search_providers.factory.SearchProviderFactory.get_provider') + def test_search_news_success(self, mock_factory): + # Mock successful provider response + mock_response = { + "answer": "Recent developments in AI...", + "articles": [ + { + "title": "AI Breakthrough", + "url": "https://example.com/ai-news", + "published_date": "2024-03-20", + "content": "Scientists have made significant progress...", + "score": 0.95 + } + ], + "time_period": "Past 3 days", + "topic": "AI" + } + mock_provider = MagicMock() + mock_provider.is_configured.return_value = True + mock_provider.search.return_value = mock_response + mock_factory.return_value = mock_provider + + # Create a new instance with the mocked provider + self.news_skill = NewsSkill() + + result = self.news_skill.search_news("What's new in AI?") + self.assertIn("answer", result) + self.assertIn("articles", result) + self.assertTrue(len(result["articles"]) > 0) + self.assertIn("score", result["articles"][0]) + + @patch('search_providers.factory.SearchProviderFactory.get_provider') + def test_search_news_provider_not_configured(self, mock_factory): + mock_provider = MagicMock() + mock_provider.is_configured.return_value = False + mock_factory.return_value = mock_provider + + # Create a new instance with the mocked provider + self.news_skill = NewsSkill() + + result = self.news_skill.search_news("What's new in AI?") + self.assertIn("error", result) + self.assertEqual(result["error"], "Search provider not configured") + +if __name__ == '__main__': + unittest.main() diff --git a/skills/search_skill.py b/skills/search_skill.py new file mode 100644 index 0000000..10be7f2 --- /dev/null +++ b/skills/search_skill.py @@ -0,0 +1,136 @@ 
+from nltk import word_tokenize, pos_tag +from typing import Dict, Any, Optional +import unittest +from unittest.mock import patch, MagicMock +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) +from search_providers import SearchProviderFactory + +class SearchSkill: + def __init__(self, provider_type: Optional[str] = None): + self.topic_indicators = { + 'about', 'for', 'on', 'related to', 'search', + 'find', 'look up', 'information about' + } + self.provider = SearchProviderFactory.get_provider(provider_type=provider_type) + + def extract_search_topic(self, text: str) -> str: + """ + Extract the main search topic from the query text. + + Args: + text: The query text to extract the topic from. + + Returns: + The extracted search topic or the original text if no specific topic is found. + """ + # Tokenize and tag parts of speech + tokens = word_tokenize(text.lower()) + tagged = pos_tag(tokens) + + # Look for topic after common indicators + for i, (word, _) in enumerate(tagged): + if word in self.topic_indicators and i + 1 < len(tagged): + # Filter out common words and get only content words + topic_words = [] + for word, pos in tagged[i+1:]: + if pos.startswith(('NN', 'JJ', 'VB')): # Only include nouns, adjectives, and verbs + topic_words.append(word) + if topic_words: + return ' '.join(topic_words) + + # If no specific pattern found, use the original text with common words filtered out + stop_words = {'what', 'is', 'are', 'the', 'tell', 'me', 'can', 'you', 'please', 'for', 'about'} + topic_words = [word for word, pos in tagged if word not in stop_words and pos.startswith(('NN', 'JJ', 'VB'))] + + return ' '.join(topic_words) if topic_words else text + + def search(self, query: str, provider_type: Optional[str] = None) -> Dict[str, Any]: + """ + Perform a general search using the configured search provider. + + Args: + query: The search query string. + provider_type: Optional provider type to use for this specific search + + Returns: + Dict containing the search results or error information. 
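+
+        Example (illustrative sketch, not part of the original tests; assumes
+        a configured provider and downloaded NLTK tokenizer/tagger data):
+
+            skill = SearchSkill()
+            result = skill.search("Look up information about electric cars")
+            # extract_search_topic() reduces the query to "electric cars"
+            # (per the unit tests below) before the provider is called; the
+            # provider's response dict, or {'error': ...}, is returned.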
+ """ + # Use a new provider just for this search if specified + provider = (SearchProviderFactory.get_provider(provider_type=provider_type) + if provider_type else self.provider) + + if not provider.is_configured(): + return {'error': 'Search provider not configured'} + + search_topic = self.extract_search_topic(query) + return provider.search( + search_topic, + search_depth="basic", + max_results=3, + include_answer=True, + include_images=False + ) + +# Test code +class TestSearchSkill(unittest.TestCase): + def setUp(self): + self.search_skill = SearchSkill() + + def test_extract_search_topic(self): + test_cases = [ + ("Search for quantum computing", "quantum computing"), + ("Tell me about the history of Rome", "history rome"), + ("Look up information about electric cars", "electric cars"), + ("What is machine learning?", "machine learning"), + ("Find recipes for chocolate cake", "recipes chocolate cake"), + ] + + for query, expected in test_cases: + result = self.search_skill.extract_search_topic(query) + self.assertEqual(result.lower(), expected.lower()) + + @patch('search_providers.factory.SearchProviderFactory.get_provider') + def test_search_success(self, mock_factory): + # Mock successful provider response + mock_response = { + "answer": "Here's what I found about quantum computing...", + "results": [ + { + "title": "Introduction to Quantum Computing", + "url": "https://example.com/quantum", + "content": "Quantum computing is an emerging technology...", + "score": 0.95 + } + ] + } + mock_provider = MagicMock() + mock_provider.is_configured.return_value = True + mock_provider.search.return_value = mock_response + mock_factory.return_value = mock_provider + + # Create a new instance with the mocked provider + self.search_skill = SearchSkill() + + result = self.search_skill.search("What is quantum computing?") + self.assertIn("answer", result) + self.assertIn("results", result) + self.assertTrue(len(result["results"]) > 0) + self.assertIn("score", result["results"][0]) + + @patch('search_providers.factory.SearchProviderFactory.get_provider') + def test_search_provider_not_configured(self, mock_factory): + mock_provider = MagicMock() + mock_provider.is_configured.return_value = False + mock_factory.return_value = mock_provider + + # Create a new instance with the mocked provider + self.search_skill = SearchSkill() + + result = self.search_skill.search("What is quantum computing?") + self.assertIn("error", result) + self.assertEqual(result["error"], "Search provider not configured") + +if __name__ == '__main__': + unittest.main() diff --git a/skills/weather_skill.py b/skills/weather_skill.py new file mode 100644 index 0000000..3649d06 --- /dev/null +++ b/skills/weather_skill.py @@ -0,0 +1,241 @@ +import requests +import os +from nltk import word_tokenize, pos_tag, ne_chunk +from nltk.tree import Tree +from typing import Dict, Any, Optional +from config_loader import config +import unittest +from unittest.mock import patch + +class WeatherSkill: + def __init__(self): + self.time_indicators = { + 'today', 'tomorrow', 'tonight', 'morning', 'afternoon', + 'evening', 'week', 'weekend', 'monday', 'tuesday', + 'wednesday', 'thursday', 'friday', 'saturday', 'sunday' + } + self.location_labels = {'GPE', 'LOC'} + + def extract_location(self, text: str) -> Optional[str]: + """ + Extract a location from a text using named entity recognition. + + Args: + text: The text to extract the location from. + + Returns: + The location found in the text, or None if no location was found. 
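+
+        Example (mirrors the unit tests below):
+
+            extract_location("What's the weather like in New York?")  # -> "New York"
+            extract_location("What's the weather like today?")        # -> None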
+        """
+        tokens = word_tokenize(text)
+        tagged = pos_tag(tokens)
+        chunks = ne_chunk(tagged)
+
+        locations = []
+        if isinstance(chunks, Tree):
+            for subtree in chunks:
+                if isinstance(subtree, Tree) and subtree.label() in self.location_labels:
+                    location = ' '.join([token for token, pos in subtree.leaves()])
+                    locations.append(location)
+
+        return locations[0] if locations else None
+
+    def extract_time_reference(self, text: str) -> Optional[str]:
+        """
+        Extract a time reference from a text using a set of predefined time indicators.
+
+        Args:
+            text: The text to extract the time reference from.
+
+        Returns:
+            The time reference found in the text, or None if no time reference was found.
+        """
+        tokens = word_tokenize(text.lower())
+        for token in tokens:
+            if token in self.time_indicators:
+                return token
+        return None
+
+    def condense_json(self, input_json):
+        """
+        Condense a JSON object by truncating all floating point numbers to integers
+        and removing specific weather elements (sea_level, grnd_level, pressure, visibility, temp_kf).
+
+        Args:
+            input_json: The JSON object to condense.
+
+        Returns:
+            The condensed JSON object with truncated numbers and the filtered elements removed.
+        """
+        def truncate_numbers_and_filter(obj):
+            if isinstance(obj, dict):
+                filtered_dict = {}
+                for k, v in obj.items():
+                    # Skip specific weather elements
+                    if k not in ['sea_level', 'grnd_level', 'pressure', 'visibility', 'temp_kf']:
+                        filtered_dict[k] = truncate_numbers_and_filter(v)
+                return filtered_dict
+            elif isinstance(obj, list):
+                return [truncate_numbers_and_filter(item) for item in obj]
+            elif isinstance(obj, float):
+                # int() truncates toward zero; the unit tests expect truncation, not rounding
+                return int(obj)
+            else:
+                return obj
+
+        # Build a filtered copy of the input JSON with truncated numbers
+        condensed_json = truncate_numbers_and_filter(input_json)
+
+        return condensed_json
+
+    def getLocationWeather(self, query: str) -> Dict[str, Any]:
+        """
+        Retrieve weather information for a given location from a query string.
+
+        This method uses the OpenWeatherMap API to fetch the current weather data
+        for a location extracted from the query. If the query includes a time
+        reference other than 'today', it fetches the weather forecast instead.
+
+        Args:
+            query (str): The query string containing the location and optionally a time reference.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the weather data. If an error occurs,
+                            the dictionary contains an 'error' key with the error message.
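+
+        Example (illustrative sketch; a real call needs network access and a
+        valid OpenWeatherMap API key):
+
+            skill = WeatherSkill()
+            skill.getLocationWeather("What's the weather in London?")
+            # -> condensed /weather payload (no time reference in the query)
+            skill.getLocationWeather("What's the forecast for London tomorrow?")
+            # -> condensed /forecast payload ('tomorrow' is a time reference)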
+        """
+        try:
+            location_name = self.extract_location(query) or str(config.DEFAULT_LOCATION)
+            if not location_name:
+                return {'error': 'No location found in query'}
+            time_ref = self.extract_time_reference(query)
+
+            # Read the OpenWeatherMap key from the environment rather than
+            # hard-coding a secret in source (the variable name is an assumption)
+            apiKey = os.getenv("OPENWEATHERMAP_API_KEY", "")
+
+            # If the query implies a time reference, get the forecast
+            if time_ref and time_ref != 'today':
+                forecastUrl = "https://api.openweathermap.org/data/2.5/forecast?q=" + location_name + "&units=" + str(config.DEFAULT_UNITS) + "&appid=" + apiKey
+                forecast_response = requests.get(forecastUrl)
+                forecast_response.raise_for_status()
+                weather_data = self.condense_json(forecast_response.json())
+            else:
+                # Get the current weather
+                completeUrl = "https://api.openweathermap.org/data/2.5/weather?q=" + location_name + "&units=" + str(config.DEFAULT_UNITS) + "&appid=" + apiKey
+                response = requests.get(completeUrl)
+                response.raise_for_status()
+                weather_data = self.condense_json(response.json())
+
+            return weather_data
+        except requests.exceptions.RequestException as e:
+            return {'error': f'Error fetching weather data: {e}'}
+        except Exception as e:
+            return {'error': f'An unexpected error occurred: {e}'}
+
+# Test code
+class TestWeatherSkill(unittest.TestCase):
+    def setUp(self):
+        self.weather_skill = WeatherSkill()
+
+    def test_extract_location(self):
+        self.assertEqual(self.weather_skill.extract_location("What's the weather like in New York?"), "New York")
+        self.assertIsNone(self.weather_skill.extract_location("What's the weather like today?"))
+
+    def test_extract_time_reference(self):
+        self.assertEqual(self.weather_skill.extract_time_reference("What's the weather like tomorrow?"), "tomorrow")
+        self.assertIsNone(self.weather_skill.extract_time_reference("What's the weather like in London?"))
+
+    def test_condense_json(self):
+        test_json = {
+            "temp": 20.5,
+            "humidity": 45.7,
+            "pressure": 1013,
+            "sea_level": 1014,
+            "grnd_level": 1012,
+            "forecast": [
+                {
+                    "temp": 22.3,
+                    "wind": 5.8,
+                    "pressure": 1015,
+                    "sea_level": 1016,
+                    "grnd_level": 1014
+                },
+                {
+                    "temp": 21.7,
+                    "wind": 6.2,
+                    "pressure": 1014,
+                    "sea_level": 1015,
+                    "grnd_level": 1013
+                }
+            ]
+        }
+        expected = {
+            "temp": 20,
+            "humidity": 45,
+            "forecast": [
+                {
+                    "temp": 22,
+                    "wind": 5
+                },
+                {
+                    "temp": 21,
+                    "wind": 6
+                }
+            ]
+        }
+        result = self.weather_skill.condense_json(test_json)
+        self.assertEqual(result, expected)
+
+    @patch('requests.get')
+    def test_getLocationWeather_current(self, mock_get):
+        # Mock response for current weather
+        mock_current = unittest.mock.Mock()
+        mock_current.json.return_value = {"weather": [{"description": "clear sky"}], "main": {"temp": 20}}
+        mock_current.raise_for_status = lambda: None
+        mock_get.return_value = mock_current
+
+        result = self.weather_skill.getLocationWeather("What's the weather like in London?")
+        self.assertIn("weather", result)
+        self.assertIn("main", result)
+
+    @patch('requests.get')
+    def test_getLocationWeather_forecast(self, mock_get):
+        # Mock responses for both current weather and forecast
+        mock_current = unittest.mock.Mock()
+        mock_current.json.return_value = {"weather": [{"description": "clear sky"}], "main": {"temp": 20}}
+        mock_current.raise_for_status = lambda: None
+
+        mock_forecast = unittest.mock.Mock()
+        mock_forecast.json.return_value = {
+            "list": [
+                {"main": {"temp": 22.5, "pressure": 1013}, "weather": [{"description": "sunny"}]},
+                {"main": {"temp": 21.3, "pressure": 1014}, "weather": [{"description": "cloudy"}]}
+            ]
+        }
+        mock_forecast.raise_for_status
= lambda: None + + # Configure mock_get to return different responses for different URLs + def side_effect(url): + if "forecast" in url: + return mock_forecast + return mock_current + + mock_get.side_effect = side_effect + + result = self.weather_skill.getLocationWeather("What's the weather like tomorrow in London?") + self.assertIn("list", result) + # Verify pressure was removed from the response + self.assertNotIn("pressure", result["list"][0]["main"]) + + def test_getLocationWeather_no_location(self): + result = self.weather_skill.getLocationWeather("What's the weather like?") + self.assertIn("error", result) + self.assertEqual(result["error"], "No location found in query") + + @patch('requests.get') + def test_getLocationWeather_api_error(self, mock_get): + mock_get.side_effect = requests.exceptions.RequestException("API Error") + result = self.weather_skill.getLocationWeather("What's the weather like in London?") + self.assertIn("error", result) + self.assertTrue(result["error"].startswith("Error fetching weather data")) + +if __name__ == '__main__': + unittest.main() From f5fc30347ebf4300582d285c57dfacb435fd1866 Mon Sep 17 00:00:00 2001 From: Graham V Date: Tue, 19 Nov 2024 17:15:24 -0500 Subject: [PATCH 3/4] Revert "feat: Add weather, news and search capabilities" This reverts commit a740826652087ce727e97f050dd7f29881652c61. --- .../always_reddy_voice_assistant_nlp/main.py | 124 ------- config_default.py | 6 - nlp_manager.py | 212 ------------ search_providers/__init__.py | 5 - search_providers/base_provider.py | 42 --- search_providers/bing_provider.py | 200 ------------ search_providers/brave_provider.py | 308 ------------------ search_providers/exa_provider.py | 231 ------------- search_providers/factory.py | 78 ----- search_providers/tavily_provider.py | 160 --------- search_providers/trusted_news_sources.json | 71 ---- skills/news_skill.py | 237 -------------- skills/search_skill.py | 136 -------- skills/weather_skill.py | 241 -------------- 14 files changed, 2051 deletions(-) delete mode 100644 actions/always_reddy_voice_assistant_nlp/main.py delete mode 100644 nlp_manager.py delete mode 100644 search_providers/__init__.py delete mode 100644 search_providers/base_provider.py delete mode 100644 search_providers/bing_provider.py delete mode 100644 search_providers/brave_provider.py delete mode 100644 search_providers/exa_provider.py delete mode 100644 search_providers/factory.py delete mode 100644 search_providers/tavily_provider.py delete mode 100644 search_providers/trusted_news_sources.json delete mode 100644 skills/news_skill.py delete mode 100644 skills/search_skill.py delete mode 100644 skills/weather_skill.py diff --git a/actions/always_reddy_voice_assistant_nlp/main.py b/actions/always_reddy_voice_assistant_nlp/main.py deleted file mode 100644 index d01b9c7..0000000 --- a/actions/always_reddy_voice_assistant_nlp/main.py +++ /dev/null @@ -1,124 +0,0 @@ -import time -from config_loader import config -from actions.base_action import BaseAction -from utils import to_clipboard, handle_clipboard_image, handle_clipboard_text, add_timestamp_to_message -import prompt -from nlp_manager import get_nlp_context -class AlwaysReddyVoiceAssistant_nlp(BaseAction): - """Action for handling voice assistant functionality.""" - def setup(self): - self.last_message_was_cut_off = False - - if config.RECORD_WITH_NLP_HOTKEY: - self.AR.add_action_hotkey( - config.RECORD_WITH_NLP_HOTKEY, - pressed=self.handle_default_assistant_response, - held_release=self.handle_default_assistant_response, - 
double_tap=self.AR.save_clipboard_text - ) - - print(f"'{config.RECORD_WITH_NLP_HOTKEY}': Start/stop talking to voice assistant (press to toggle on and off, or hold and release)") - if "+" in config.RECORD_WITH_NLP_HOTKEY: - hotkey_start, hotkey_end = config.RECORD__WITH_NLP_HOTKEY.rsplit("+", 1) - print(f"\tHold down '{hotkey_start}' and double tap '{hotkey_end}' to send clipboard content to AlwaysReddy") - else: - print(f"\tDouble tap '{config.RECORD_WITH_NLP_HOTKEY}' to send clipboard content to AlwaysReddy") - - if config.NEW_CHAT_HOTKEY: - self.AR.add_action_hotkey(config.NEW_CHAT_HOTKEY, pressed=self.new_chat) - print(f"'{config.NEW_CHAT_HOTKEY}': New chat for voice assistant") - - self.messages = prompt.build_initial_messages(config.ACTIVE_PROMPT) - - def handle_default_assistant_response(self): - """Handle the response from the transcription and generate a completion.""" - try: - recording_filename = self.AR.toggle_recording(self.handle_default_assistant_response) - if not recording_filename: - return - message = self.AR.transcription_manager.transcribe_audio(recording_filename) - - if not self.AR.stop_action and message: - print("\nTranscript:\n", message) - - if len(self.messages) > 0 and self.messages[0]["role"] == "system": - self.messages[0]["content"] = prompt.get_system_prompt_message(config.ACTIVE_PROMPT) - if self.last_message_was_cut_off: - message = "--> USER CUT THE ASSISTANT'S LAST MESSAGE SHORT <--\n" + message - - new_message = {"role": "user", "content": message} - start_time = time.time() - - # Handle potential NLP context requests for weather, news and search - nlp_context = get_nlp_context(self.AR,message) - if nlp_context: - new_message['content'] = nlp_context - - # Handle clipboard image - clipboard_image_content = handle_clipboard_image(self.AR, message) - if clipboard_image_content: - new_message['content'] = clipboard_image_content - else: - # Handle clipboard text - new_message['content'] = handle_clipboard_text(self.AR, new_message['content']) - - # Add timestamp if configured - if config.TIMESTAMP_MESSAGES: - new_message['content'] = add_timestamp_to_message(new_message['content']) - - self.messages.append(new_message) - - if self.AR.stop_action: - return - - # Ensure there's at least one message - if not self.messages: - print("Error: No messages to send to the API.") - return - - stream = self.AR.completion_client.get_completion_stream( - self.messages, - config.COMPLETION_MODEL, - **config.COMPLETION_PARAMS - ) - - end_time = time.time() - if self.AR.verbose: print(f"Execution time: {end_time - start_time:.1f} seconds") - response = self.AR.completion_client.process_text_stream( - stream, - marker_tuples=[(config.CLIPBOARD_TEXT_START_SEQ, config.CLIPBOARD_TEXT_END_SEQ, to_clipboard)], - sentence_callback=self.AR.tts.run_tts - ) - - while self.AR.tts.running_tts: - time.sleep(0.001) - - if not response: - if self.AR.verbose: - print("No response generated.") - self.messages = self.messages[:-1] - return - - self.last_message_was_cut_off = False - - if self.AR.stop_action: - index = response.rfind(self.AR.tts.last_sentence_spoken) - if index != -1: - response = response[:index + len(self.AR.tts.last_sentence_spoken)] - self.last_message_was_cut_off = True - - self.messages.append({"role": "assistant", "content": response}) - print("\nResponse:\n", response) - - except Exception as e: - print(f"An error occurred in handle_default_assistant_response: {e}") - if self.AR.verbose: - import traceback - traceback.print_exc() - - def new_chat(self): - """Clear 
the message history and start a new chat session.""" - self.messages = prompt.build_initial_messages(config.ACTIVE_PROMPT) - self.last_message_was_cut_off = False - self.AR.last_clipboard_text = None - print("New chat session started.") \ No newline at end of file diff --git a/config_default.py b/config_default.py index 4f2130c..ebc09d8 100644 --- a/config_default.py +++ b/config_default.py @@ -131,9 +131,3 @@ CANCEL_SOUND_VOLUME = 0.09 MAX_RECORDING_DURATION= 600 # If you record for more than 10 minutes, the recording will stop automatically -### WEATHER,NEWS AND SEARCH SETTINGS ### -DEFAULT_UNITS = 'metric' #metric or imperial (or standard if you want your weather in Kelvin) -DEFAULT_LOCATION = '' -#Search and News providers can be Bing, Brave, Exa or Tavily, and don't need to be the same -SEARCH_PROVIDER = '' -NEWS_PROVIDER = '' \ No newline at end of file diff --git a/nlp_manager.py b/nlp_manager.py deleted file mode 100644 index 98efe32..0000000 --- a/nlp_manager.py +++ /dev/null @@ -1,212 +0,0 @@ -import spacy -from config_loader import config -from typing import List, Dict, Any, Tuple, Optional -from skills.weather_skill import WeatherSkill -from skills.news_skill import NewsSkill -from skills.search_skill import SearchSkill - -class NLPManager: - def __init__(self): - # Load spaCy model - """ - Initialize the NLPManager. - - This method loads the spaCy model and stores it in the class instance. - It also sets up the intent keywords and initializes the entity extractor. - """ - try: - self.nlp = spacy.load('en_core_web_md') - except OSError: - # Download if not available - spacy.cli.download('en_core_web_md') - self.nlp = spacy.load('en_core_web_md') - - # Intent keywords - self.intents = { - 'weather': ['weather', 'temperature', 'forecast', 'rain', 'snow', 'sunny', 'cloudy', 'humidity', 'heat', 'cold', 'hot', - 'wind', 'storm', 'precipitation', 'celsius', 'fahrenheit', 'degrees', 'sunrise', 'sunset'], - 'news': ['news', 'update', 'latest', 'current', 'happening', 'event', 'story', 'article', 'report'], - 'search': ['search', 'find', 'look up', 'lookup', 'research', 'information about', 'tell me about', 'what is', 'who is'] - } - - self.entity_extractor = EntityExtractor(self.nlp) - - def preprocess_text(self, text: str) -> List[str]: - # Process text using spaCy - """ - Process text using spaCy. - - This method takes a string of text and preprocesses it using the spaCy library. - It first converts the text to lowercase and then processes it using the - English language model. - - The method then filters out the tokens using spaCy's built-in attributes: - - is_stop: Stop words (e.g. "the", "a", etc.) - - is_punct: Punctuation (e.g. periods, commas, etc.) - - len(token.text) > 1: Tokens with more than one character - - token.text.isalpha(): Tokens that are alphabetic - - The method returns a list of lemmatized tokens (root words). - - :param text: The text to preprocess - :type text: str - :return: A list of lemmatized tokens - :rtype: List[str] - """ - doc = self.nlp(text.lower()) - - # Filter tokens using spaCy's built-in attributes - processed_tokens = [ - token.lemma_ for token in doc - if not token.is_stop - and not token.is_punct - and len(token.text) > 1 - and token.text.isalpha() - ] - - return processed_tokens - - def classify_intent(self, text: str) -> Tuple[str, float, Optional[Dict[str, Any]]]: - """ - Classify the intent of a given text. 
- - This method takes a string of text and classifies its intent into one of the - following categories: information_query, weather, news, search. - - The method preprocesses the text using the preprocess_text method and then - checks if any of the intent keywords are present in the tokens. If a keyword - is found, it calculates the confidence of the classification by dividing the - number of keywords found by the total number of tokens. If no keyword is found, - it defaults to an information_query intent with a confidence of 0.3. - - :param text: The text to classify - :type text: str - :return: A tuple containing the intent, confidence, and optional entity dictionary - :rtype: Tuple[str, float, Optional[Dict[str, Any]]] - """ - tokens = self.preprocess_text(text) - - # Intent classification - for intent, keywords in self.intents.items(): - # Special handling for multi-word keywords - text_lower = text.lower() - keyword_matches = sum(1 for keyword in keywords if keyword in text_lower) - if keyword_matches > 0: - confidence = keyword_matches / len(tokens) - return intent, confidence, None - - return 'information_query', 0.3, None - - -class EntityExtractor: - def __init__(self, nlp): - """ - Initialize the EntityExtractor. - - This method sets up the entity extractor with the provided spaCy NLP model. - - :param nlp: The spaCy NLP model used for entity extraction. - :type nlp: spacy.language.Language - """ - self.nlp = nlp - - def extract_entities(self, text: str) -> Dict[str, List[str]]: - doc = self.nlp(text) - entities = {} - for ent in doc.ents: - if ent.label_ not in entities: - entities[ent.label_] = [] - entities[ent.label_].append(ent.text) - return entities - -def get_nlp_context(AR, message_content: str) -> str: - """ - Add contextual information to a message based on the user's intent. - - If the intent is 'weather', add information about the user's preferred unit of measurement - for weather information, and the user's location if it can be detected. - - If the intent is 'news', add news search results related to the query. - - If the intent is 'search', add general search results related to the query. - - :param AR: The AlwaysReddy object - :param message_content: The message content - :return: The message content with additional context - :rtype: str - """ - nlp_manager = NLPManager() - intent, confidence, additional_info = nlp_manager.classify_intent(message_content) - - if intent == 'weather': - weather_skill = WeatherSkill() - message_content += f"\n\nTHE USER APPEARS TO HAVE A QUESTION ABOUT THE WEATHER, USE THIS DATA TO HELP YOU ANSWER IT:\n```{weather_skill.getLocationWeather(message_content)}```" - unit = str(config.DEFAULT_UNITS) or 'METRIC' - message_content += f"\n\nTHE USER PREFERS TO RECEIVE WEATHER INFORMATION IN A {unit} FORMAT. DO NOT ABBREVIATE UNITS OF MEASURE eg. USE 'CELSIUS' INSTEAD OF 'C'. KEEP THE FORECAST SOUNDING NATURAL AND NOT ROBOTIC." - - elif intent == 'news': - news_skill = NewsSkill(provider_type=config.NEWS_PROVIDER) - news_results = news_skill.search_news(message_content) - if 'error' not in news_results: - message_content += f"\n\nTHE USER APPEARS TO HAVE A QUESTION ABOUT NEWS, USE THIS DATA TO HELP YOU ANSWER IT:\n```{news_results}```" - message_content += "\n\nPLEASE PROVIDE A NATURAL SUMMARY OF THE NEWS, INCLUDING THE MOST RELEVANT AND RECENT INFORMATION. PROVIDE 5 SENTENCES OF DETAIL." 
- - elif intent == 'search': - search_skill = SearchSkill(provider_type=config.SEARCH_PROVIDER) - search_results = search_skill.search(message_content) - if 'error' not in search_results: - message_content += f"\n\nTHE USER APPEARS TO BE SEARCHING FOR INFORMATION, USE THIS DATA TO HELP YOU ANSWER IT:\n```{search_results}```" - message_content += "\n\nPLEASE PROVIDE A CLEAR AND CONCISE ANSWER BASED ON THE SEARCH RESULTS AND YOUR EXISTING KNOWLEDGE, FOCUSING ON THE MOST RELEVANT INFORMATION TO ANSWER THEIR QUESTION." - - return message_content - - - -def main(): - classifier = NLPManager() - - test_queries = [ - "What is the best restaurant in town?", - "Book a flight to New York", - "Compare iPhone and Android phones", - "Recommend a good book to read", - "What's the weather like in Buffalo today?", - "Will it rain tomorrow in New York?", - "How's the temperature in Los Angeles this weekend?", - "How's the weather in Los Angeles this weekend?", - "What is the weather in Miami?", - "Describe the wind conditions in Los Angeles for the rest of the week.", - "Give me a weather update for New York", - "Search for quantum computing", - "Look up the history of Rome", - "Find information about electric cars" - ] - - import time - for query in test_queries: - #start_time = time.time() - intent, confidence, additional_info = classifier.classify_intent(query) - entities = classifier.entity_extractor.extract_entities(query) - print(f"Query: '{query}'") - print(f"Intent: {intent}") - print(f"Confidence: {confidence:.2f}") - print(f"Entities: {entities}") - if additional_info: - print("Additional Information:") - for key, value in additional_info.items(): - print(f" {key}: {value}") - print() - start_time = time.time() - if intent == 'weather': - weather_skill = WeatherSkill() - print(str(weather_skill.getLocationWeather(query))[:200]) - elif intent == 'search': - search_skill = SearchSkill() - print(str(search_skill.search(query))[:200]) - print() - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - print() - -if __name__ == "__main__": - main() diff --git a/search_providers/__init__.py b/search_providers/__init__.py deleted file mode 100644 index a9450b3..0000000 --- a/search_providers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base_provider import BaseSearchProvider -from .tavily_provider import TavilySearchProvider -from .factory import SearchProviderFactory - -__all__ = ['BaseSearchProvider', 'TavilySearchProvider', 'SearchProviderFactory'] diff --git a/search_providers/base_provider.py b/search_providers/base_provider.py deleted file mode 100644 index e98942e..0000000 --- a/search_providers/base_provider.py +++ /dev/null @@ -1,42 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional - -class BaseSearchProvider(ABC): - """ - Abstract base class for search providers. - All search providers must implement these methods. - """ - - @abstractmethod - def __init__(self, api_key: Optional[str] = None): - """ - Initialize the search provider. - - Args: - api_key: Optional API key for the search provider - """ - pass - - @abstractmethod - def search(self, query: str, **kwargs) -> Dict[str, Any]: - """ - Perform a search using the provider. 
- - Args: - query: The search query string - **kwargs: Additional search parameters specific to the provider - - Returns: - Dict containing the search results or error information - """ - pass - - @abstractmethod - def is_configured(self) -> bool: - """ - Check if the provider is properly configured (e.g., has valid API key). - - Returns: - bool indicating if the provider is ready to use - """ - pass diff --git a/search_providers/bing_provider.py b/search_providers/bing_provider.py deleted file mode 100644 index 2fe76ad..0000000 --- a/search_providers/bing_provider.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Dict, Any, Optional -import os -import sys -from pathlib import Path -import requests -from datetime import datetime, timedelta -import json - -# Add parent directory to path for imports when running as script -if __name__ == "__main__": - sys.path.append(str(Path(__file__).parent.parent)) - from search_providers.base_provider import BaseSearchProvider -else: - from .base_provider import BaseSearchProvider - -class BingSearchProvider(BaseSearchProvider): - """ - Bing implementation of the search provider interface. - Handles both web and news-specific searches using Bing's APIs. - """ - - WEB_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search" - NEWS_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/news/search" - - def __init__(self, api_key: Optional[str] = None): - """ - Initialize the Bing search provider. - - Args: - api_key: Optional Bing API key. If not provided, will try to get from environment. - """ - self.api_key = api_key or os.getenv("BING_API_KEY") - self.headers = { - 'Ocp-Apim-Subscription-Key': self.api_key, - 'Accept': 'application/json' - } if self.api_key else None - - # Load trusted news sources - self.trusted_sources = self._load_trusted_sources() - - def _load_trusted_sources(self) -> list: - """Load first 5 trusted news sources from JSON file.""" - try: - json_path = Path(__file__).parent / "trusted_news_sources.json" - with open(json_path) as f: - data = json.load(f) - # Only load the first 16 sources as per MSFT limits - return data.get("trusted_sources", [])[:16] - except Exception as e: - print(f"Warning: Could not load trusted news sources: {e}") - return [] - - def is_configured(self) -> bool: - """Check if Bing API is properly configured.""" - return self.headers is not None - - def search(self, query: str, **kwargs) -> Dict[str, Any]: - """ - Perform a search using Bing API. 
- - Args: - query: The search query string - **kwargs: Additional search parameters: - - topic: Optional search topic (e.g., "news") - - max_results: Maximum number of results (default: 10) - - market: Market code (default: "en-US") - - days: Number of days to look back (for news searches) - - Returns: - Dict containing search results or error information - """ - if not self.is_configured(): - return {'error': 'Bing API key not configured'} - - try: - # Set default search parameters - search_params = { - 'count': str(kwargs.get('max_results', 10)), # Changed default from 5 to 10 - 'mkt': kwargs.get('market', 'en-US'), - 'textFormat': 'Raw' - } - - # Determine if this is a news search - if kwargs.get('topic') == 'news': - # Add freshness parameter for news if days specified - if 'days' in kwargs: - # Bing API expects 'day', 'week', or 'month' - search_params['freshness'] = 'week' if kwargs['days'] >1 else 'day' - - # Add site: operators for trusted sources - if self.trusted_sources: - site_operators = " OR ".join(f'site:{source}' for source in self.trusted_sources) - search_params['q'] = f"({query}) ({site_operators})" - else: - search_params['q'] = f"latest headlines about the topic: {query}" - - response = requests.get( - self.NEWS_SEARCH_ENDPOINT, - headers=self.headers, - params=search_params - ) - else: - search_params['q'] = query - response = requests.get( - self.WEB_SEARCH_ENDPOINT, - headers=self.headers, - params=search_params - ) - - if response.status_code != 200: - return {'error': f'API request failed with status {response.status_code}: {response.text}'} - - response_data = response.json() - - # Process results based on search type - if kwargs.get('topic') == 'news': - return self._process_news_results( - response_data, - days=kwargs.get('days', 3), - topic=query - ) - else: - return self._process_general_results(response_data) - - except requests.exceptions.RequestException as e: - return {'error': f'API request failed: {str(e)}'} - except Exception as e: - return {'error': f'An unexpected error occurred: {str(e)}'} - - def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: - """Process results for general web searches.""" - webpages = response.get('webPages', {}).get('value', []) - return { - 'results': [{ - 'title': result.get('name', ''), - 'url': result.get('url', ''), - 'content': result.get('snippet', ''), - 'score': 1.0 # Bing doesn't provide relevance scores - } for result in webpages[:10]] # Changed from 3 to 10 - } - - def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: - """Process results for news-specific searches.""" - articles = response.get('value', []) - return { - 'articles': [{ - 'title': article.get('name', ''), - 'url': article.get('url', ''), - 'published_date': article.get('datePublished', ''), - 'content': article.get('description', ''), - 'score': 1.0 # Bing doesn't provide relevance scores - } for article in articles], - 'time_period': f"Past {days} days", - 'topic': topic - } - -if __name__ == "__main__": - # Test code using actual API - provider = BingSearchProvider() - if not provider.is_configured(): - print("Error: Bing API key not configured") - exit(1) - - # Print loaded trusted sources - print("\n=== Loaded Trusted Sources ===") - print(provider.trusted_sources) - - # Test general search - print("\n=== Testing General Search ===") - general_result = provider.search( - "What is artificial intelligence?", - max_results=10 # Changed from 3 to 10 - ) - - if 'error' in 
general_result: - print(f"Error in general search: {general_result['error']}") - else: - print("\nTop Results:") - for idx, result in enumerate(general_result['results'], 1): - print(f"\n{idx}. {result['title']}") - print(f" URL: {result['url']}") - print(f" Preview: {result['content'][:400]}...") - - # Test news search - print("\n\n=== Testing News Search ===") - news_result = provider.search( - "mike tyson fight", - topic="news", - days=3 - ) - - if 'error' in news_result: - print(f"Error in news search: {news_result['error']}") - else: - print("\nRecent Articles:") - for idx, article in enumerate(news_result['articles'], 1): - print(f"\n{idx}. {article['title']}") - print(f" Published: {article['published_date']}") - print(f" URL: {article['url']}") - print(f" Preview: {article['content'][:400]}...") diff --git a/search_providers/brave_provider.py b/search_providers/brave_provider.py deleted file mode 100644 index cca0a76..0000000 --- a/search_providers/brave_provider.py +++ /dev/null @@ -1,308 +0,0 @@ -from typing import Dict, Any, Optional -import os -import sys -from pathlib import Path -import requests -from datetime import datetime, timedelta -import json -from concurrent.futures import ThreadPoolExecutor - -# Add parent directory to path for imports when running as script -if __name__ == "__main__": - sys.path.append(str(Path(__file__).parent.parent)) - from search_providers.base_provider import BaseSearchProvider -else: - from .base_provider import BaseSearchProvider - -class BraveSearchProvider(BaseSearchProvider): - """ - Brave implementation of the search provider interface. - Handles both web and news-specific searches using Brave's APIs. - """ - - WEB_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/web/search" - NEWS_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/news/search" - SUMMARIZER_ENDPOINT = "https://api.search.brave.com/res/v1/summarizer/search" - - def __init__(self, api_key: Optional[str] = None): - """ - Initialize the Brave search provider. - - Args: - api_key: Optional Brave API key. If not provided, will try to get from environment. 
- """ - self.api_key = api_key or os.getenv("BRAVE_AI_API_KEY") - self.pro_api_key = os.getenv("BRAVE_AI_PRO_API_KEY") #Optional, used for AI summary requests - self.headers = { - 'X-Subscription-Token': self.api_key, - 'Accept': 'application/json' - } if self.api_key else None - self.proheaders = { - 'X-Subscription-Token': self.pro_api_key, - 'Accept': 'application/json' - } if self.pro_api_key else None - def is_configured(self) -> bool: - """Check if Brave API is properly configured.""" - return self.headers is not None - - def get_brave_summary(self, query): - # Query parameters - params = { - "q": query, - "summary": 1 - } - - # Make the initial web search request to get summarizer key - search_response = requests.get(self.WEB_SEARCH_ENDPOINT, headers=self.proheaders, params=params) - - if search_response.status_code == 200: - data = search_response.json() - - if "summarizer" in data and "key" in data["summarizer"]: - summarizer_key = data["summarizer"]["key"] - - # Make request to summarizer endpoint - summarizer_params = { - "key": summarizer_key, - "entity_info": 1 - } - - summary_response = requests.get( - self.SUMMARIZER_ENDPOINT, - headers=self.proheaders, - params=summarizer_params - ) - - if summary_response.status_code == 200: - summary_data = summary_response.json() - try: - return summary_data['summary'][0]['data'] - except (KeyError, IndexError): - return None - - return None - - def search(self, query: str, **kwargs) -> Dict[str, Any]: - """ - Perform a search using Brave API. - - Args: - query: The search query string - **kwargs: Additional search parameters: - - topic: Optional search topic (e.g., "news") - - max_results: Maximum number of results (default: 10) - - market: Market code (default: "en-US") - - days: Number of days to look back (for news searches) - - Returns: - Dict containing search results or error information - """ - if not self.is_configured(): - return {'error': 'Brave API key not configured'} - - try: - # Set default search parameters - search_params = { - 'count': str(kwargs.get('max_results', 10)), - 'country': kwargs.get('market', 'us'), # Brave uses country code - 'q': query - } - - # Determine if this is a news search - if kwargs.get('topic') == 'news': - # Add freshness parameter for news if days specified - if 'days' in kwargs: - days = kwargs['days'] - if days <= 1: - search_params['freshness'] = 'pd' # past day - elif days <= 7: - search_params['freshness'] = 'pw' # past week - else: - search_params['freshness'] = 'pm' # past month - - response = requests.get( - self.NEWS_SEARCH_ENDPOINT, - headers=self.headers, - params=search_params - ) - - response_data = response.json() - result = self._process_news_results(response_data, days=kwargs.get('days', 3), topic=query) - else: - response = requests.get( - self.WEB_SEARCH_ENDPOINT, - headers=self.headers, - params=search_params - ) - response_data = response.json() - result = self._process_general_results(response_data) - - # Include summarizer response if it exists - summary_response = self.get_brave_summary(query) - if summary_response: - result['summarizer'] = summary_response - - return result - - except requests.exceptions.RequestException as e: - return {'error': f'API request failed: {str(e)}'} - except Exception as e: - return {'error': f'An unexpected error occurred: {str(e)}'} - - def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: - """Process results for general web searches.""" - web_results = response.get('web', {}).get('results', []) - with 
ThreadPoolExecutor() as executor: - # Use index as key instead of the result dictionary - futures = {i: executor.submit(self.get_brave_summary, result.get('title', '')) - for i, result in enumerate(web_results[:2])} - - results = [] - for i, result in enumerate(web_results): - summary = None - if i < 2: - try: - summary = futures[i].result() - except Exception as e: - print(f"Error getting summary: {e}") - - processed_result = { - 'title': result.get('title', ''), - 'url': result.get('url', ''), - 'content': result.get('description', ''), - 'score': result.get('score', 1.0), - 'extra_snippets': None, - 'summary': None - } - if summary: - processed_result['summary'] = summary - else: - processed_result['extra_snippets'] = result.get('extra_snippets', []) - results.append(processed_result) - return {'results': results} - - def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: - """Process results for news-specific searches.""" - news_results = response.get('results', []) - def convert_age_to_minutes(age_str: str) -> int: - """ - Convert age string to minutes. - - Args: - age_str: Age string in the format of "X minutes", "X hours", "X days" - - Returns: - Age in minutes - """ - age_value = int(age_str.split()[0]) - age_unit = age_str.split()[1] - if age_unit == 'minutes': - return age_value - elif age_unit == 'hours': - return age_value * 60 - elif age_unit == 'days': - return age_value * 1440 # 24 hours * 60 minutes - else: - return 0 # Default to 0 if unknown unit - - # Sort news results based on the age field - news_results.sort(key=lambda x: convert_age_to_minutes(x.get('age', '0 minutes'))) - - with ThreadPoolExecutor() as executor: - # Use enumerate to create futures with index as key - futures = {i: executor.submit(self.get_brave_summary, article_data.get('title', '')) - for i, article_data in enumerate(news_results)} - - articles = [] - for i, article_data in enumerate(news_results): - try: - summary = futures[i].result() - except Exception as e: - print(f"Error getting summary: {e}") - summary = None - - article = { - 'title': article_data.get('title', ''), - 'url': article_data.get('url', ''), - 'published_date': article_data.get('age', ''), - 'breaking' : article_data.get('breaking', False), - 'content': article_data.get('description', ''), - 'extra_snippets': None, - 'summary': None, - 'score': article_data.get('score', 1.0) - } - if summary: - article['summary'] = summary - else: - article['extra_snippets'] = article_data.get('extra_snippets', []) - articles.append(article) - - return { - 'articles': articles, - 'time_period': f"Past {days} days", - 'topic': topic - } - -if __name__ == "__main__": - # Test code using actual API - provider = BraveSearchProvider() - if not provider.is_configured(): - print("Error: Brave API key not configured") - exit(1) - - # Test general search - print("\n=== Testing General Search ===") - general_result = provider.search( - "What is artificial intelligence?", - max_results=1 # Increased max_results to test summary limiting - ) - - if 'error' in general_result: - print(f"Error in general search: {general_result['error']}") - else: - print("\nTop Results:") - for idx, result in enumerate(general_result['results'], 1): - print(f"\n{idx}. 
{result['title']}") - print(f" URL: {result['url']}") - print(f" Preview: {result['content']}...") - print(f" Score: {result['score']}") - if result['extra_snippets']: - print(" Extra Snippets:") - for snippet in result['extra_snippets']: - print(f" - {snippet}") - if result['summary']: # Check if summary exists before printing - print(f" Summary: {result.get('summary', '')}...") - import time - time.sleep(1) - - # Test news search - print("\n\n=== Testing News Search ===") - import time - start_time = time.time() - news_result = provider.search( - "mike tyson fight", - topic="news", - days=3, - max_results=1 - ) - end_time = time.time() - - - if 'error' in news_result: - print(f"Error in news search: {news_result['error']}") - else: - print("\nRecent Articles:") - for idx, article in enumerate(news_result['articles'], 1): - print(f"\n{idx}. {article['title']}") - print(f" Published: {article['published_date']}") - print(f" Breaking: {article['breaking']}") - print(f" URL: {article['url']}") - print(f" Preview: {article['content'][:400]}...") - if article['extra_snippets']: - print(" Extra Snippets:") - for snippet in article['extra_snippets']: - print(f" - {snippet}") - if article['summary']: - print(f" Summary: {article.get('summary', '')}...") - - print(f"Execution time: {round(end_time - start_time, 1)} seconds") diff --git a/search_providers/exa_provider.py b/search_providers/exa_provider.py deleted file mode 100644 index a20404b..0000000 --- a/search_providers/exa_provider.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Dict, Any, Optional -import os -import sys -import json -from pathlib import Path -import requests -from datetime import datetime, timedelta - -# Add parent directory to path for imports when running as script -if __name__ == "__main__": - sys.path.append(str(Path(__file__).parent.parent)) - from search_providers.base_provider import BaseSearchProvider -else: - from .base_provider import BaseSearchProvider - -class ExaSearchProvider(BaseSearchProvider): - """ - Exa.ai implementation of the search provider interface. - Handles web searches with optional full page content retrieval. - """ - - def __init__(self, api_key: Optional[str] = None): - """ - Initialize the Exa search provider. - - Args: - api_key: Optional Exa API key. If not provided, will try to get from environment. - """ - self.api_key = api_key or os.getenv("EXA_API_KEY") - self.base_url = "https://api.exa.ai/search" - self.trusted_sources = self._load_trusted_sources() - - def _load_trusted_sources(self) -> list: - """Load trusted news sources from JSON file.""" - try: - json_path = Path(__file__).parent / 'trusted_news_sources.json' - with open(json_path) as f: - data = json.load(f) - return data.get('trusted_sources', []) - except Exception as e: - print(f"Warning: Could not load trusted sources: {e}") - return [] - - def is_configured(self) -> bool: - """Check if Exa client is properly configured.""" - return bool(self.api_key) - - def search(self, query: str, **kwargs) -> Dict[str, Any]: - """ - Perform a search using Exa API. 
- - Args: - query: The search query string - **kwargs: Additional search parameters: - - include_content: Whether to retrieve full page contents (default: False) - - max_results: Maximum number of results (default: 3) - - days: Number of days to look back (for news searches) - - Returns: - Dict containing search results or error information - """ - if not self.is_configured(): - return {'error': 'Exa API key not configured'} - - try: - # Set default search parameters - search_params = { - 'query': query, - 'type': 'neural', - 'useAutoprompt': True, - 'numResults': kwargs.get('max_results', 3), - } - - # Add optional parameters - if kwargs.get('include_content'): - search_params['contents'] = { - "highlights": True, - "summary": True - } - - if kwargs.get('days'): - # Convert days to timestamp for time-based filtering - date_limit = datetime.now() - timedelta(days=kwargs['days']) - search_params['startPublishedTime'] = date_limit.isoformat() - - # Add trusted domains for news searches - if kwargs.get('topic') == 'news' and self.trusted_sources: - search_params['includeDomains'] = self.trusted_sources - - # Make API request - headers = { - 'x-api-key': self.api_key, - 'Content-Type': 'application/json', - 'accept': 'application/json' - } - - response = requests.post( - self.base_url, - headers=headers, - json=search_params - ) - response.raise_for_status() - data = response.json() - - # Process results based on whether it's a news search - if kwargs.get('topic') == 'news': - return self._process_news_results( - data, - days=kwargs.get('days', 3), - topic=query - ) - else: - return self._process_general_results(data) - - except requests.exceptions.RequestException as e: - if e.response and e.response.status_code == 401: - return {'error': 'Invalid Exa API key'} - elif e.response and e.response.status_code == 429: - return {'error': 'Exa API rate limit exceeded'} - else: - return {'error': f'An error occurred while making the request: {str(e)}'} - except Exception as e: - return {'error': f'An unexpected error occurred: {str(e)}'} - - def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: - """Process results for general searches.""" - results = [] - for result in response.get('results', []): - processed_result = { - 'title': result.get('title', ''), - 'url': result.get('url', ''), - 'highlights': result.get('highlights', []), - 'summary': result.get('summary', ''), - 'score': result.get('score', 0.0) - } - results.append(processed_result) - - return { - 'results': results, - 'autoprompt': response.get('autopromptString', '') - } - - def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: - """Process results for news-specific searches.""" - articles = [] - for article in response.get('results', []): - processed_article = { - 'title': article.get('title', ''), - 'url': article.get('url', ''), - 'published_date': article.get('publishedDate', ''), - 'highlights': article.get('highlights', []), - 'summary': article.get('summary', ''), - 'score': article.get('score', 0.0) - } - articles.append(processed_article) - - return { - 'articles': articles, - 'time_period': f"Past {days} days", - 'topic': topic, - 'autoprompt': response.get('autopromptString', '') - } - -if __name__ == "__main__": - # Test code for the Exa provider - provider = ExaSearchProvider() - if not provider.is_configured(): - print("Error: Exa API key not configured") - exit(1) - - # Test general search - print("\n=== Testing General Search ===") - import time - 
start_time = time.time() - general_result = provider.search( - "What is artificial intelligence?", - max_results=3, - include_content=True - ) - end_time = time.time() - - if 'error' in general_result: - print("Error:", general_result['error']) - else: - print("\nTop Results:") - print(f"Autoprompt: {general_result.get('autoprompt', '')}") - for idx, result in enumerate(general_result['results'], 1): - print(f"\n{idx}. {result['title']}") - print(f" URL: {result['url']}") - print(f" Score: {result['score']}") - print(f" Summary: {result['summary']}") - if result['highlights']: - print(" Highlights:") - for highlight in result['highlights']: - print(f" - {highlight}") - print(f"\n\nTime taken for general search: {end_time - start_time} seconds") - - # Test news search - print("\n\n=== Testing News Search ===") - start_time = time.time() - news_result = provider.search( - "Latest developments in AI", - topic="news", - days=3, - max_results=3, - include_content=True - ) - end_time = time.time() - - if 'error' in news_result: - print("Error:", news_result['error']) - else: - print("\nRecent Articles:") - print(f"Autoprompt: {news_result.get('autoprompt', '')}") - for idx, article in enumerate(news_result['articles'], 1): - print(f"\n{idx}. {article['title']}") - print(f" Published: {article['published_date']}") - print(f" URL: {article['url']}") - print(f" Score: {article['score']}") - print(f" Summary: {article['summary']}") - if article['highlights']: - print(" Highlights:") - for highlight in article['highlights']: - print(f" - {highlight}") - print(f"\n\nTime taken for news search: {end_time - start_time} seconds") - - # Test error handling - print("\n\n=== Testing Error Handling ===") - bad_provider = ExaSearchProvider(api_key="invalid_key") - error_result = bad_provider.search("test query") - print("\nExpected error with invalid API key:", error_result['error']) diff --git a/search_providers/factory.py b/search_providers/factory.py deleted file mode 100644 index aaa2eab..0000000 --- a/search_providers/factory.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Optional, Dict, Type -from .base_provider import BaseSearchProvider -from .tavily_provider import TavilySearchProvider -from .bing_provider import BingSearchProvider -from .brave_provider import BraveSearchProvider -from .exa_provider import ExaSearchProvider - -class SearchProviderFactory: - """ - Factory class for creating search provider instances. - Supports multiple provider types and handles provider configuration. - """ - - # Registry of available providers - _providers: Dict[str, Type[BaseSearchProvider]] = { - 'tavily': TavilySearchProvider, - 'bing': BingSearchProvider, - 'brave': BraveSearchProvider, - 'exa': ExaSearchProvider - } - - @classmethod - def get_provider( - cls, - provider_type: Optional[str] = None, - api_key: Optional[str] = None - ) -> BaseSearchProvider: - """ - Get an instance of the specified search provider. 
- - Args: - provider_type: Type of provider to create (defaults to config.SEARCH_PROVIDER) - api_key: Optional API key for the provider - - Returns: - An instance of the specified provider - - Raises: - ValueError: If the specified provider type is not supported - """ - # Import config here to avoid circular imports - from config import SEARCH_PROVIDER - - # Use provided provider_type or fall back to config - provider_type = provider_type or SEARCH_PROVIDER - - provider_class = cls._providers.get(provider_type.lower()) - if not provider_class: - raise ValueError( - f"Unsupported provider type: {provider_type}. " - f"Available providers: {', '.join(cls._providers.keys())}" - ) - - return provider_class(api_key=api_key) - - @classmethod - def register_provider( - cls, - provider_type: str, - provider_class: Type[BaseSearchProvider] - ) -> None: - """ - Register a new provider type. - - Args: - provider_type: Name of the provider type - provider_class: Provider class that implements BaseSearchProvider - - Raises: - TypeError: If provider_class doesn't inherit from BaseSearchProvider - """ - if not issubclass(provider_class, BaseSearchProvider): - raise TypeError( - f"Provider class must inherit from BaseSearchProvider. " - f"Got {provider_class.__name__}" - ) - - cls._providers[provider_type.lower()] = provider_class diff --git a/search_providers/tavily_provider.py b/search_providers/tavily_provider.py deleted file mode 100644 index 043ef94..0000000 --- a/search_providers/tavily_provider.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Dict, Any, Optional -import os -import sys -from pathlib import Path - -# Add parent directory to path for imports when running as script -if __name__ == "__main__": - sys.path.append(str(Path(__file__).parent.parent)) - from search_providers.base_provider import BaseSearchProvider -else: - from .base_provider import BaseSearchProvider - -from tavily import TavilyClient, MissingAPIKeyError, InvalidAPIKeyError, UsageLimitExceededError - -class TavilySearchProvider(BaseSearchProvider): - """ - Tavily implementation of the search provider interface. - Handles both general and news-specific searches. - """ - - def __init__(self, api_key: Optional[str] = None): - """ - Initialize the Tavily search provider. - - Args: - api_key: Optional Tavily API key. If not provided, will try to get from environment. - """ - self.api_key = api_key or os.getenv("TAVILY_API_KEY") - try: - self.client = TavilyClient(api_key=self.api_key) if self.api_key else None - except MissingAPIKeyError: - self.client = None - - def is_configured(self) -> bool: - """Check if Tavily client is properly configured.""" - return self.client is not None - - def search(self, query: str, **kwargs) -> Dict[str, Any]: - """ - Perform a search using Tavily API. 
- - Args: - query: The search query string - **kwargs: Additional search parameters: - - search_depth: "basic" or "advanced" (default: "basic") - - topic: Optional search topic (e.g., "news") - - max_results: Maximum number of results (default: 5) - - include_answer: Whether to include AI-generated answer (default: True) - - include_images: Whether to include images (default: False) - - days: Number of days to look back (for news searches) - - Returns: - Dict containing search results or error information - """ - if not self.is_configured(): - return {'error': 'Tavily API key not configured'} - - try: - # Set default search parameters - search_params = { - 'search_depth': "basic", - 'max_results': 5, - 'include_answer': True, - 'include_images': False - } - - # Update with any provided parameters - search_params.update(kwargs) - - # Execute search - response = self.client.search(query, **search_params) - - # Process results based on whether it's a news search - if kwargs.get('topic') == 'news': - return self._process_news_results( - response, - days=kwargs.get('days', 3), - topic=query - ) - else: - return self._process_general_results(response) - - except InvalidAPIKeyError: - return {'error': 'Invalid Tavily API key'} - except UsageLimitExceededError: - return {'error': 'Tavily API usage limit exceeded'} - except Exception as e: - return {'error': f'An unexpected error occurred: {e}'} - - def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]: - """Process results for general searches.""" - return { - 'answer': response.get('answer', ''), - 'results': [{ - 'title': result.get('title', ''), - 'url': result.get('url', ''), - 'content': result.get('content', '')[:500] + '...' if result.get('content') else '', - 'score': result.get('score', 0.0) - } for result in response.get('results', [])[:3]] - } - - def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]: - """Process results for news-specific searches.""" - return { - 'answer': response.get('answer', ''), - 'articles': [{ - 'title': article.get('title', ''), - 'url': article.get('url', ''), - 'published_date': article.get('published_date', ''), - 'content': article.get('content', '')[:500] + '...' if article.get('content') else '', - 'score': article.get('score', 0.0) - } for article in response.get('results', [])], - 'time_period': f"Past {days} days", - 'topic': topic - } - -if __name__ == "__main__": - # Test code for the Tavily provider - provider = TavilySearchProvider() - if not provider.is_configured(): - print("Error: Tavily API key not configured") - exit(1) - - # Test general search - print("\n=== Testing General Search ===") - general_result = provider.search( - "What is artificial intelligence?", - search_depth="advanced", - max_results=3 - ) - print("\nQuery Answer:", general_result['answer']) - print("\nTop Results:") - for idx, result in enumerate(general_result['results'], 1): - print(f"\n{idx}. {result['title']}") - print(f" URL: {result['url']}") - print(f" Score: {result['score']}") - print(f" Preview: {result['content'][:200]}...") - - # Test news search - print("\n\n=== Testing News Search ===") - news_result = provider.search( - "Latest developments in AI", - topic="news", - days=3, - search_depth="advanced" - ) - print("\nNews Summary:", news_result['answer']) - print("\nRecent Articles:") - for idx, article in enumerate(news_result['articles'], 1): - print(f"\n{idx}. 
{article['title']}") - print(f" Published: {article['published_date']}") - print(f" URL: {article['url']}") - print(f" Score: {article['score']}") - print(f" Preview: {article['content'][:400]}...") - - # Test error handling - print("\n\n=== Testing Error Handling ===") - bad_provider = TavilySearchProvider(api_key="invalid_key") - error_result = bad_provider.search("test query") - print("\nExpected error with invalid API key:", error_result['error']) diff --git a/search_providers/trusted_news_sources.json b/search_providers/trusted_news_sources.json deleted file mode 100644 index b5e3c77..0000000 --- a/search_providers/trusted_news_sources.json +++ /dev/null @@ -1,71 +0,0 @@ -{ - "trusted_sources": [ - "apnews.com", - "reuters.com", - "bbc.com", - "wsj.com", - "nytimes.com", - "economist.com", - "bloomberg.com", - "ft.com", - "aljazeera.com", - "afp.com", - "techcrunch.com", - "wired.com", - "arstechnica.com", - "theverge.com", - "cnet.com", - "theguardian.com", - "businessinsider.com", - "dw.com", - "time.com", - "afp.com", - "pbs.org", - "npr.org", - "cnbc.com", - "forbes.com", - "thehill.com", - "politico.com", - "axios.com", - "euronews.com", - "japantimes.co.jp", - "scmp.com", - "straitstimes.com", - "themoscowtimes.com", - "haaretz.com", - "timesofindia.com", - "globeandmail.com", - "abc.net.au", - "rte.ie", - "swissinfo.ch", - "thelocal.fr", - "thelocal.de", - "thelocal.se", - "kyivpost.com", - "arabnews.com", - "koreatimes.co.kr", - "bangkokpost.com", - "zdnet.com", - "cnet.com", - "engadget.com", - "gizmodo.com", - "thenextweb.com", - "venturebeat.com", - "techradar.com", - "tomshardware.com", - "anandtech.com", - "slashdot.org", - "techspot.com", - "phoronix.com", - "404media.co", - "theregister.com", - "techdirt.com", - "techrepublic.com", - "mit.edu", - "protocol.com", - "theinformation.com", - "restofworld.org", - "news.ycombinator.com" - ] - } - \ No newline at end of file diff --git a/skills/news_skill.py b/skills/news_skill.py deleted file mode 100644 index b0e343d..0000000 --- a/skills/news_skill.py +++ /dev/null @@ -1,237 +0,0 @@ -from nltk import word_tokenize, pos_tag -from typing import Dict, Any, Optional -import unittest -from unittest.mock import patch, MagicMock -from datetime import datetime, timedelta -import re -import sys -from pathlib import Path -sys.path.append(str(Path(__file__).parent.parent)) -from search_providers import SearchProviderFactory - -class NewsSkill: - def __init__(self, provider_type: Optional[str] = None): - self.topic_indicators = { - 'about', 'regarding', 'on', 'related to', 'news about', - 'latest', 'recent', 'current', 'today', 'breaking', 'update' - } - # Expanded time indicators with more natural language expressions - self.time_indicators = { - 'today': 1, - 'yesterday': 2, - 'last week': 7, - 'past week': 7, - 'this week': 7, - 'recent': 3, - 'latest': 3, - 'last month': 30, - 'past month': 30, - 'this month': 30, - 'past year': 365, - 'last year': 365, - 'this year': 365, - 'past few days': 3, - 'last few days': 3, - 'past couple days': 2, - 'last couple days': 2, - 'past 24 hours': 1, - 'last 24 hours': 1, - 'past hour': 1, - 'last hour': 1, - 'past few weeks': 21, - 'last few weeks': 21, - 'past couple weeks': 14, - 'last couple weeks': 14, - 'past few months': 90, - 'last few months': 90, - 'past couple months': 60, - 'last couple months': 60 - } - # Regular expressions for relative time - self.relative_time_patterns = [ - (r'past (\d+) days?', lambda x: int(x)), - (r'last (\d+) days?', lambda x: int(x)), - (r'past (\d+) 
weeks?', lambda x: int(x) * 7), - (r'last (\d+) weeks?', lambda x: int(x) * 7), - (r'past (\d+) months?', lambda x: int(x) * 30), - (r'last (\d+) months?', lambda x: int(x) * 30), - (r'past (\d+) years?', lambda x: int(x) * 365), - (r'last (\d+) years?', lambda x: int(x) * 365) - ] - self.provider = SearchProviderFactory.get_provider(provider_type=provider_type) - - def extract_time_reference(self, text: str) -> int: - """ - Extract a time reference from text and convert it to number of days. - Handles both fixed and relative time expressions. - - Args: - text: The text to extract the time reference from. - - Returns: - Number of days to look back for news (default: 3 if no time reference found) - """ - text_lower = text.lower() - - # Check for exact matches first - for time_ref, days in self.time_indicators.items(): - if time_ref in text_lower: - return days - - # Check for relative time patterns - for pattern, converter in self.relative_time_patterns: - match = re.search(pattern, text_lower) - if match: - return converter(match.group(1)) - - # If no time reference found, default to 3 days - return 3 - - def extract_search_topic(self, text: str) -> str: - """ - Extract the main search topic from the query text. - Improved to handle compound topics better. - - Args: - text: The query text to extract the topic from. - - Returns: - The extracted search topic or the original text if no specific topic is found. - """ - # Tokenize and tag parts of speech - tokens = word_tokenize(text.lower()) - tagged = pos_tag(tokens) - - # Look for topic after common indicators - for i, (word, _) in enumerate(tagged): - if word in self.topic_indicators and i + 1 < len(tagged): - # Extract everything after the indicator as the topic - topic_words = [] - for word, pos in tagged[i+1:]: - # Include nouns, adjectives, verbs, and conjunctions for compound topics - if pos.startswith(('NN', 'JJ', 'VB')) or word in ['and', 'or']: - topic_words.append(word) - if topic_words: - return ' '.join(topic_words) - - # If no specific pattern found, use the original text with common words filtered out - stop_words = {'what', 'is', 'are', 'the', 'tell', 'me', 'search', 'find', 'get', 'news'} - topic_words = [] - for word, pos in tagged: - # Include conjunctions to better handle compound topics - if (word not in stop_words and pos.startswith(('NN', 'JJ', 'VB'))) or word in ['and', '&', 'or']: - topic_words.append(word) - - return ' '.join(topic_words) if topic_words else text - - def search_news(self, query: str, provider_type: Optional[str] = None) -> Dict[str, Any]: - """ - Search for news articles using the configured search provider. - - Args: - query: The search query string. - provider_type: Optional provider type to use for this specific search - - Returns: - Dict containing the search results or error information. 
- """ - # Use a new provider just for this search if specified - provider = (SearchProviderFactory.get_provider(provider_type=provider_type) - if provider_type else self.provider) - - if not provider.is_configured(): - return {'error': 'Search provider not configured'} - - search_topic = self.extract_search_topic(query) - days_to_search = self.extract_time_reference(query) - - return provider.search( - search_topic, - search_depth="basic", - topic="news", - max_results=5, - include_answer=True, - include_images=False, - days=days_to_search - ) - -# Test code -class TestNewsSkill(unittest.TestCase): - def setUp(self): - self.news_skill = NewsSkill() - - def test_extract_search_topic(self): - test_cases = [ - ("What's the latest news about artificial intelligence?", "artificial intelligence"), - ("Tell me the current news regarding climate change", "climate change"), - ("Search for news about SpaceX launches", "spacex launches"), - ("What's happening in technology today?", "technology"), - ("Tell me about AI and machine learning", "ai and machine learning"), # Test compound topic - ] - - for query, expected in test_cases: - result = self.news_skill.extract_search_topic(query) - self.assertEqual(result.lower(), expected.lower()) - - def test_extract_time_reference(self): - test_cases = [ - ("What happened today in tech?", 1), - ("Show me last week's news about AI", 7), - ("What's the latest on climate change?", 3), - ("Tell me about space news from last month", 30), - ("What happened in politics this year?", 365), - ("Show me news about crypto", 3), # Default case - ("News from past 5 days", 5), # Test relative time - ("Updates from last 2 weeks", 14), # Test relative time - ] - - for query, expected_days in test_cases: - result = self.news_skill.extract_time_reference(query) - self.assertEqual(result, expected_days) - - @patch('search_providers.factory.SearchProviderFactory.get_provider') - def test_search_news_success(self, mock_factory): - # Mock successful provider response - mock_response = { - "answer": "Recent developments in AI...", - "articles": [ - { - "title": "AI Breakthrough", - "url": "https://example.com/ai-news", - "published_date": "2024-03-20", - "content": "Scientists have made significant progress...", - "score": 0.95 - } - ], - "time_period": "Past 3 days", - "topic": "AI" - } - mock_provider = MagicMock() - mock_provider.is_configured.return_value = True - mock_provider.search.return_value = mock_response - mock_factory.return_value = mock_provider - - # Create a new instance with the mocked provider - self.news_skill = NewsSkill() - - result = self.news_skill.search_news("What's new in AI?") - self.assertIn("answer", result) - self.assertIn("articles", result) - self.assertTrue(len(result["articles"]) > 0) - self.assertIn("score", result["articles"][0]) - - @patch('search_providers.factory.SearchProviderFactory.get_provider') - def test_search_news_provider_not_configured(self, mock_factory): - mock_provider = MagicMock() - mock_provider.is_configured.return_value = False - mock_factory.return_value = mock_provider - - # Create a new instance with the mocked provider - self.news_skill = NewsSkill() - - result = self.news_skill.search_news("What's new in AI?") - self.assertIn("error", result) - self.assertEqual(result["error"], "Search provider not configured") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/search_skill.py b/skills/search_skill.py deleted file mode 100644 index 10be7f2..0000000 --- a/skills/search_skill.py +++ /dev/null @@ -1,136 
+0,0 @@ -from nltk import word_tokenize, pos_tag -from typing import Dict, Any, Optional -import unittest -from unittest.mock import patch, MagicMock -import sys -from pathlib import Path -sys.path.append(str(Path(__file__).parent.parent)) -from search_providers import SearchProviderFactory - -class SearchSkill: - def __init__(self, provider_type: Optional[str] = None): - self.topic_indicators = { - 'about', 'for', 'on', 'related to', 'search', - 'find', 'look up', 'information about' - } - self.provider = SearchProviderFactory.get_provider(provider_type=provider_type) - - def extract_search_topic(self, text: str) -> str: - """ - Extract the main search topic from the query text. - - Args: - text: The query text to extract the topic from. - - Returns: - The extracted search topic or the original text if no specific topic is found. - """ - # Tokenize and tag parts of speech - tokens = word_tokenize(text.lower()) - tagged = pos_tag(tokens) - - # Look for topic after common indicators - for i, (word, _) in enumerate(tagged): - if word in self.topic_indicators and i + 1 < len(tagged): - # Filter out common words and get only content words - topic_words = [] - for word, pos in tagged[i+1:]: - if pos.startswith(('NN', 'JJ', 'VB')): # Only include nouns, adjectives, and verbs - topic_words.append(word) - if topic_words: - return ' '.join(topic_words) - - # If no specific pattern found, use the original text with common words filtered out - stop_words = {'what', 'is', 'are', 'the', 'tell', 'me', 'can', 'you', 'please', 'for', 'about'} - topic_words = [word for word, pos in tagged if word not in stop_words and pos.startswith(('NN', 'JJ', 'VB'))] - - return ' '.join(topic_words) if topic_words else text - - def search(self, query: str, provider_type: Optional[str] = None) -> Dict[str, Any]: - """ - Perform a general search using the configured search provider. - - Args: - query: The search query string. - provider_type: Optional provider type to use for this specific search - - Returns: - Dict containing the search results or error information. 
- """ - # Use a new provider just for this search if specified - provider = (SearchProviderFactory.get_provider(provider_type=provider_type) - if provider_type else self.provider) - - if not provider.is_configured(): - return {'error': 'Search provider not configured'} - - search_topic = self.extract_search_topic(query) - return provider.search( - search_topic, - search_depth="basic", - max_results=3, - include_answer=True, - include_images=False - ) - -# Test code -class TestSearchSkill(unittest.TestCase): - def setUp(self): - self.search_skill = SearchSkill() - - def test_extract_search_topic(self): - test_cases = [ - ("Search for quantum computing", "quantum computing"), - ("Tell me about the history of Rome", "history rome"), - ("Look up information about electric cars", "electric cars"), - ("What is machine learning?", "machine learning"), - ("Find recipes for chocolate cake", "recipes chocolate cake"), - ] - - for query, expected in test_cases: - result = self.search_skill.extract_search_topic(query) - self.assertEqual(result.lower(), expected.lower()) - - @patch('search_providers.factory.SearchProviderFactory.get_provider') - def test_search_success(self, mock_factory): - # Mock successful provider response - mock_response = { - "answer": "Here's what I found about quantum computing...", - "results": [ - { - "title": "Introduction to Quantum Computing", - "url": "https://example.com/quantum", - "content": "Quantum computing is an emerging technology...", - "score": 0.95 - } - ] - } - mock_provider = MagicMock() - mock_provider.is_configured.return_value = True - mock_provider.search.return_value = mock_response - mock_factory.return_value = mock_provider - - # Create a new instance with the mocked provider - self.search_skill = SearchSkill() - - result = self.search_skill.search("What is quantum computing?") - self.assertIn("answer", result) - self.assertIn("results", result) - self.assertTrue(len(result["results"]) > 0) - self.assertIn("score", result["results"][0]) - - @patch('search_providers.factory.SearchProviderFactory.get_provider') - def test_search_provider_not_configured(self, mock_factory): - mock_provider = MagicMock() - mock_provider.is_configured.return_value = False - mock_factory.return_value = mock_provider - - # Create a new instance with the mocked provider - self.search_skill = SearchSkill() - - result = self.search_skill.search("What is quantum computing?") - self.assertIn("error", result) - self.assertEqual(result["error"], "Search provider not configured") - -if __name__ == '__main__': - unittest.main() diff --git a/skills/weather_skill.py b/skills/weather_skill.py deleted file mode 100644 index 3649d06..0000000 --- a/skills/weather_skill.py +++ /dev/null @@ -1,241 +0,0 @@ -import requests -import os -from nltk import word_tokenize, pos_tag, ne_chunk -from nltk.tree import Tree -from typing import Dict, Any, Optional -from config_loader import config -import unittest -from unittest.mock import patch - -class WeatherSkill: - def __init__(self): - self.time_indicators = { - 'today', 'tomorrow', 'tonight', 'morning', 'afternoon', - 'evening', 'week', 'weekend', 'monday', 'tuesday', - 'wednesday', 'thursday', 'friday', 'saturday', 'sunday' - } - self.location_labels = {'GPE', 'LOC'} - - def extract_location(self, text: str) -> Optional[str]: - """ - Extract a location from a text using named entity recognition. - - Args: - text: The text to extract the location from. - - Returns: - The location found in the text, or None if no location was found. 
- """ - tokens = word_tokenize(text) - tagged = pos_tag(tokens) - chunks = ne_chunk(tagged) - - locations = [] - if isinstance(chunks, Tree): - for subtree in chunks: - if isinstance(subtree, Tree) and subtree.label() in self.location_labels: - location = ' '.join([token for token, pos in subtree.leaves()]) - locations.append(location) - - return locations[0] if locations else None - - def extract_time_reference(self, text: str) -> Optional[str]: - """ - Extract a time reference from a text using a set of predefined time indicators. - - Args: - text: The text to extract the time reference from. - - Returns: - The time reference found in the text, or None if no time reference was found. - """ - tokens = word_tokenize(text.lower()) - for token in tokens: - if token in self.time_indicators: - return token - return None - - def condense_json(self, input_json): - """ - Condense a JSON object by rounding all floating point numbers to integers - and removing specific weather elements (sea_level, grnd_level, pressure, visibility, temp_kf). - - Args: - input_json: The JSON object to condense. - - Returns: - The condensed JSON object with rounded numbers and removed elements. - """ - def round_numbers_and_filter(obj): - if isinstance(obj, dict): - filtered_dict = {} - for k, v in obj.items(): - # Skip specific weather elements - if k not in ['sea_level', 'grnd_level', 'pressure', 'visibility','temp_kf']: - filtered_dict[k] = round_numbers_and_filter(v) - return filtered_dict - elif isinstance(obj, list): - return [round_numbers_and_filter(item) for item in obj] - elif isinstance(obj, float): - return int(obj) - else: - return obj - - # Create a deep copy of the input JSON, round the numbers, and filter elements - condensed_json = round_numbers_and_filter(input_json) - - return condensed_json - - def getLocationWeather(self, query: str) -> Dict[str, Any]: - """ - Retrieve weather information for a given location from a query string. - - This function uses OpenWeatherMap API to fetch the current weather data - for a location extracted from the query. If the query includes a time - reference other than 'today', it also fetches the weather forecast. - - Args: - query (str): The query string containing the location and optionally a time reference. - - Returns: - Dict[str, Any]: A dictionary containing the weather data. If an error occurs, - the dictionary contains an 'error' key with the error message. 
-        """
-        try:
-            # Fall back to DEFAULT_LOCATION only when it is actually set;
-            # str() of an unset value would otherwise be truthy and skip the check below
-            default_location = str(config.DEFAULT_LOCATION) if config.DEFAULT_LOCATION else None
-            location_name = self.extract_location(query) or default_location
-            if not location_name:
-                return {'error': 'No location found in query and DEFAULT_LOCATION is not set'}
-            time_ref = self.extract_time_reference(query)
-
-            # Read the API key from the environment (variable name assumed here)
-            # rather than hard-coding a secret in the source
-            apiKey = os.getenv("OPENWEATHERMAP_API_KEY", "")
-
-            # If the query implies a time reference, get the forecast
-            if time_ref and time_ref != 'today':
-                forecastUrl = "https://api.openweathermap.org/data/2.5/forecast?q=" + location_name + "&units=" + str(config.DEFAULT_UNITS) + "&appid=" + apiKey
-                forecast_response = requests.get(forecastUrl)
-                forecast_response.raise_for_status()
-                weather_data = self.condense_json(forecast_response.json())
-
-            else:
-                # Get the current weather
-                completeUrl = "https://api.openweathermap.org/data/2.5/weather?q=" + location_name + "&units=" + str(config.DEFAULT_UNITS) + "&appid=" + apiKey
-                response = requests.get(completeUrl)
-                response.raise_for_status()
-                weather_data = self.condense_json(response.json())
-
-            return weather_data
-        except requests.exceptions.RequestException as e:
-            return {'error': f'Error fetching weather data: {e}'}
-        except Exception as e:
-            return {'error': f'An unexpected error occurred: {e}'}
-
-# Test code
-class TestWeatherSkill(unittest.TestCase):
-    def setUp(self):
-        self.weather_skill = WeatherSkill()
-
-    def test_extract_location(self):
-        self.assertEqual(self.weather_skill.extract_location("What's the weather like in New York?"), "New York")
-        self.assertIsNone(self.weather_skill.extract_location("What's the weather like today?"))
-
-    def test_extract_time_reference(self):
-        self.assertEqual(self.weather_skill.extract_time_reference("What's the weather like tomorrow?"), "tomorrow")
-        self.assertIsNone(self.weather_skill.extract_time_reference("What's the weather like in London?"))
-
-    def test_condense_json(self):
-        test_json = {
-            "temp": 20.5,
-            "humidity": 45.7,
-            "pressure": 1013,
-            "sea_level": 1014,
-            "grnd_level": 1012,
-            "forecast": [
-                {
-                    "temp": 22.3,
-                    "wind": 5.8,
-                    "pressure": 1015,
-                    "sea_level": 1016,
-                    "grnd_level": 1014
-                },
-                {
-                    "temp": 21.7,
-                    "wind": 6.2,
-                    "pressure": 1014,
-                    "sea_level": 1015,
-                    "grnd_level": 1013
-                }
-            ]
-        }
-        expected = {
-            "temp": 20,
-            "humidity": 45,
-            "forecast": [
-                {
-                    "temp": 22,
-                    "wind": 5
-                },
-                {
-                    "temp": 21,
-                    "wind": 6
-                }
-            ]
-        }
-        result = self.weather_skill.condense_json(test_json)
-        self.assertEqual(result, expected)
-
-    @patch('requests.get')
-    def test_getLocationWeather_current(self, mock_get):
-        # Mock response for current weather
-        mock_current = unittest.mock.Mock()
-        mock_current.json.return_value = {"weather": [{"description": "clear sky"}], "main": {"temp": 20}}
-        mock_current.raise_for_status = lambda: None
-        mock_get.return_value = mock_current
-
-        result = self.weather_skill.getLocationWeather("What's the weather like in London?")
-        self.assertIn("weather", result)
-        self.assertIn("main", result)
-
-    @patch('requests.get')
-    def test_getLocationWeather_forecast(self, mock_get):
-        # Mock responses for both current weather and forecast
-        mock_current = unittest.mock.Mock()
-        mock_current.json.return_value = {"weather": [{"description": "clear sky"}], "main": {"temp": 20}}
-        mock_current.raise_for_status = lambda: None
-
-        mock_forecast = unittest.mock.Mock()
-        mock_forecast.json.return_value = {
-            "list": [
-                {"main": {"temp": 22.5, "pressure": 1013}, "weather": [{"description": "sunny"}]},
-                {"main": {"temp": 21.3, "pressure": 1014}, "weather": [{"description": "cloudy"}]}
-            ]
-        }
-        mock_forecast.raise_for_status = lambda: None
-
-        # Configure mock_get to return different responses for different URLs
-        def side_effect(url):
-            if "forecast" in url:
-                return mock_forecast
-            return mock_current
-
-        mock_get.side_effect = side_effect
-
-        result = self.weather_skill.getLocationWeather("What's the weather like tomorrow in London?")
-        self.assertIn("list", result)
-        # Verify pressure was removed from the response
-        self.assertNotIn("pressure", result["list"][0]["main"])
-
-    def test_getLocationWeather_no_location(self):
-        result = self.weather_skill.getLocationWeather("What's the weather like?")
-        self.assertIn("error", result)
-        self.assertEqual(result["error"], "No location found in query and DEFAULT_LOCATION is not set")
-
-    @patch('requests.get')
-    def test_getLocationWeather_api_error(self, mock_get):
-        mock_get.side_effect = requests.exceptions.RequestException("API Error")
-        result = self.weather_skill.getLocationWeather("What's the weather like in London?")
-        self.assertIn("error", result)
-        self.assertTrue(result["error"].startswith("Error fetching weather data"))
-
-if __name__ == '__main__':
-    unittest.main()

From 15ac6c140af67a8fa28480a348183cc275d23249 Mon Sep 17 00:00:00 2001
From: Graham V
Date: Tue, 19 Nov 2024 17:16:05 -0500
Subject: [PATCH 4/4] pick 6f3c16f Two changes to utils.py

- when fetching url_text_contents, if clipboard contains a url, then fetch
  the text contents of that url using jina.ai
- update to token limit maintenance logic to ensure that the most recent
  message is never trimmed from context regardless of its token count

---
 .DS_Store | Bin 0 -> 10244 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 .DS_Store

diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..72096f33efd35d345a62155face92dc3841312f2
GIT binary patch
literal 10244
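
For context on the provider architecture removed above, here is a minimal,
hypothetical usage sketch of the deleted SearchProviderFactory API,
reconstructed only from the signatures in the deleted factory.py. Nothing
below is part of the patch itself: EchoProvider and the "echo" name are
illustrative, and BaseSearchProvider is assumed to declare the
is_configured() and search() methods that every deleted provider implements.

    from search_providers.base_provider import BaseSearchProvider
    from search_providers.factory import SearchProviderFactory

    # Resolve a provider by name; called with no argument, get_provider()
    # falls back to config.SEARCH_PROVIDER.
    provider = SearchProviderFactory.get_provider("tavily")
    if provider.is_configured():
        results = provider.search("quantum computing", max_results=3)

    # Runtime registration: register_provider() accepts any subclass of
    # BaseSearchProvider and raises TypeError for anything else.
    class EchoProvider(BaseSearchProvider):
        def __init__(self, api_key=None):
            self.api_key = api_key  # hypothetical stub, not a real provider

        def is_configured(self):
            return True

        def search(self, query, **kwargs):
            return {'results': [], 'query': query}

    SearchProviderFactory.register_provider("echo", EchoProvider)
    echo_provider = SearchProviderFactory.get_provider("echo")

Because get_provider() instantiates providers as provider_class(api_key=api_key),
any registered class needs an __init__ that accepts an api_key keyword, which
is why the stub defines one.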