Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def forward(
if "hits" in results:
collected_results.extend(authoritative_results[: self.k])
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")

return collected_results

Expand Down Expand Up @@ -151,7 +152,8 @@ def forward(
for query in queries:
try:
results = requests.get(
self.endpoint, headers=headers, params={**self.params, "q": query}
self.endpoint, headers=headers, params={
**self.params, "q": query}
).json()

for d in results["webPages"]["value"]:
Expand All @@ -162,7 +164,8 @@ def forward(
"description": d["snippet"],
}
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")

valid_url_to_snippets = self.webpage_helper.urls_to_snippets(
list(url_to_results.keys())
Expand Down Expand Up @@ -268,7 +271,8 @@ def init_online_vector_db(self, url: str, api_key: str):
self.client = QdrantClient(url=url, api_key=api_key)
self._check_collection()
except Exception as e:
raise ValueError(f"Error occurs when connecting to the server: {e}")
raise ValueError(
f"Error occurs when connecting to the server: {e}")

def init_offline_vector_db(self, vector_store_path: str):
from qdrant_client import QdrantClient
Expand All @@ -286,7 +290,8 @@ def init_offline_vector_db(self, vector_store_path: str):
self.client = QdrantClient(path=vector_store_path)
self._check_collection()
except Exception as e:
raise ValueError(f"Error occurs when loading the vector store: {e}")
raise ValueError(
f"Error occurs when loading the vector store: {e}")

def get_usage_and_reset(self):
usage = self.usage
Expand Down Expand Up @@ -322,7 +327,8 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
self.usage += len(queries)
collected_results = []
for query in queries:
related_docs = self.qdrant.similarity_search_with_score(query, k=self.k)
related_docs = self.qdrant.similarity_search_with_score(
query, k=self.k)
for i in range(len(related_docs)):
doc = related_docs[i][0]
collected_results.append(
Expand Down Expand Up @@ -356,7 +362,8 @@ def _retrieve(self, query: str):
payload = {"query": query, "num_blocks": self.k, "rerank": self.rerank}

response = requests.post(
self.endpoint, json=payload, headers={"Content-Type": "application/json"}
self.endpoint, json=payload, headers={
"Content-Type": "application/json"}
)

# Check if the request was successful
Expand Down Expand Up @@ -399,7 +406,8 @@ def forward(
results = self._retrieve(query)
collected_results.extend(results)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")
return collected_results


Expand Down Expand Up @@ -547,7 +555,8 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
snippets = [organic.get("snippet")]
if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:
snippets.extend(
valid_url_to_snippets.get(url, {}).get("snippets", [])
valid_url_to_snippets.get(
url, {}).get("snippets", [])
)
collected_results.append(
{
Expand Down Expand Up @@ -636,7 +645,8 @@ def forward(
}
)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")

return collected_results

Expand Down Expand Up @@ -720,7 +730,8 @@ def forward(
}
)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")

return collected_results

Expand Down Expand Up @@ -880,7 +891,8 @@ def __init__(
try:
from tavily import TavilyClient
except ImportError as err:
raise ImportError("Tavily requires `pip install tavily-python`.") from err
raise ImportError(
"Tavily requires `pip install tavily-python`.") from err

if not tavily_search_api_key and not os.environ.get("TAVILY_API_KEY"):
raise RuntimeError(
Expand Down Expand Up @@ -927,11 +939,24 @@ def forward(
Returns:
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
# Validate input queries
if not query_or_queries:
print("Warning: Empty query_or_queries provided to TavilySearchRM")
return []

if isinstance(query_or_queries, str):
if not query_or_queries.strip():
print("Warning: Empty string query provided to TavilySearchRM")
return []
queries = [query_or_queries]
else:
# Filter out empty queries from list
queries = [q for q in query_or_queries if q and isinstance(
q, str) and q.strip()]
if not queries:
print("Warning: All queries in list are empty or invalid")
return []

self.usage += len(queries)

collected_results = []
Expand All @@ -942,8 +967,12 @@ def forward(
"include_raw_contents": self.include_raw_content,
}
# list of dicts that will be parsed to return
responseData = self.tavily_client.search(query)
results = responseData.get("results")
try:
responseData = self.tavily_client.search(query)
results = responseData.get("results")
except Exception as e:
print(f"Error searching with Tavily for query '{query}': {e}")
continue
for d in results:
# assert d is dict
if not isinstance(d, dict):
Expand Down Expand Up @@ -1091,7 +1120,8 @@ def forward(
}

except Exception as e:
logging.error(f"Error occurred while searching query {query}: {e}")
logging.error(
f"Error occurred while searching query {query}: {e}")

valid_url_to_snippets = self.webpage_helper.urls_to_snippets(
list(url_to_results.keys())
Expand Down Expand Up @@ -1233,6 +1263,7 @@ def forward(
}
collected_results.append(document)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
logging.error(
f"Error occurs when searching query {query}: {e}")

return collected_results