diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index 563116fe..7d4c6a7c 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -69,7 +69,8 @@ def forward( if "hits" in results: collected_results.extend(authoritative_results[: self.k]) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") return collected_results @@ -151,7 +152,8 @@ def forward( for query in queries: try: results = requests.get( - self.endpoint, headers=headers, params={**self.params, "q": query} + self.endpoint, headers=headers, params={ + **self.params, "q": query} ).json() for d in results["webPages"]["value"]: @@ -162,7 +164,8 @@ def forward( "description": d["snippet"], } except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") valid_url_to_snippets = self.webpage_helper.urls_to_snippets( list(url_to_results.keys()) @@ -268,7 +271,8 @@ def init_online_vector_db(self, url: str, api_key: str): self.client = QdrantClient(url=url, api_key=api_key) self._check_collection() except Exception as e: - raise ValueError(f"Error occurs when connecting to the server: {e}") + raise ValueError( + f"Error occurs when connecting to the server: {e}") def init_offline_vector_db(self, vector_store_path: str): from qdrant_client import QdrantClient @@ -286,7 +290,8 @@ def init_offline_vector_db(self, vector_store_path: str): self.client = QdrantClient(path=vector_store_path) self._check_collection() except Exception as e: - raise ValueError(f"Error occurs when loading the vector store: {e}") + raise ValueError( + f"Error occurs when loading the vector store: {e}") def get_usage_and_reset(self): usage = self.usage @@ -322,7 +327,8 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st self.usage += len(queries) collected_results = [] for query in queries: - related_docs = self.qdrant.similarity_search_with_score(query, k=self.k) + related_docs = self.qdrant.similarity_search_with_score( + query, k=self.k) for i in range(len(related_docs)): doc = related_docs[i][0] collected_results.append( @@ -356,7 +362,8 @@ def _retrieve(self, query: str): payload = {"query": query, "num_blocks": self.k, "rerank": self.rerank} response = requests.post( - self.endpoint, json=payload, headers={"Content-Type": "application/json"} + self.endpoint, json=payload, headers={ + "Content-Type": "application/json"} ) # Check if the request was successful @@ -399,7 +406,8 @@ def forward( results = self._retrieve(query) collected_results.extend(results) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") return collected_results @@ -547,7 +555,8 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st snippets = [organic.get("snippet")] if self.ENABLE_EXTRA_SNIPPET_EXTRACTION: snippets.extend( - valid_url_to_snippets.get(url, {}).get("snippets", []) + valid_url_to_snippets.get( + url, {}).get("snippets", []) ) collected_results.append( { @@ -636,7 +645,8 @@ def forward( } ) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") return collected_results @@ -720,7 +730,8 @@ def forward( } ) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") return collected_results @@ -880,7 +891,8 @@ def __init__( try: from tavily import TavilyClient except ImportError as err: - raise ImportError("Tavily requires `pip install tavily-python`.") from err + raise ImportError( + "Tavily requires `pip install tavily-python`.") from err if not tavily_search_api_key and not os.environ.get("TAVILY_API_KEY"): raise RuntimeError( @@ -927,11 +939,24 @@ def forward( Returns: a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url' """ - queries = ( - [query_or_queries] - if isinstance(query_or_queries, str) - else query_or_queries - ) + # Validate input queries + if not query_or_queries: + print("Warning: Empty query_or_queries provided to TavilySearchRM") + return [] + + if isinstance(query_or_queries, str): + if not query_or_queries.strip(): + print("Warning: Empty string query provided to TavilySearchRM") + return [] + queries = [query_or_queries] + else: + # Filter out empty queries from list + queries = [q for q in query_or_queries if q and isinstance( + q, str) and q.strip()] + if not queries: + print("Warning: All queries in list are empty or invalid") + return [] + self.usage += len(queries) collected_results = [] @@ -942,8 +967,12 @@ def forward( "include_raw_contents": self.include_raw_content, } # list of dicts that will be parsed to return - responseData = self.tavily_client.search(query) - results = responseData.get("results") + try: + responseData = self.tavily_client.search(query) + results = responseData.get("results") + except Exception as e: + print(f"Error searching with Tavily for query '{query}': {e}") + continue for d in results: # assert d is dict if not isinstance(d, dict): @@ -1091,7 +1120,8 @@ def forward( } except Exception as e: - logging.error(f"Error occurred while searching query {query}: {e}") + logging.error( + f"Error occurred while searching query {query}: {e}") valid_url_to_snippets = self.webpage_helper.urls_to_snippets( list(url_to_results.keys()) @@ -1233,6 +1263,7 @@ def forward( } collected_results.append(document) except Exception as e: - logging.error(f"Error occurs when searching query {query}: {e}") + logging.error( + f"Error occurs when searching query {query}: {e}") return collected_results