-
Notifications
You must be signed in to change notification settings - Fork 497
Expand file tree
/
Copy pathsearch_base.py
More file actions
239 lines (191 loc) · 6.75 KB
/
search_base.py
File metadata and controls
239 lines (191 loc) · 6.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# flake8: noqa
import enum
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar
import json
if TYPE_CHECKING:
from ms_agent.llm.utils import Tool
# Type variable for the item type stored in SearchResponse.results.
T = TypeVar('T')
@enum.unique
class SearchEngineType(enum.Enum):
    """Identifiers of the supported search backends.

    Each member's value is the lowercase engine key that
    ``ENGINE_TOOL_NAMES`` uses to look up the engine's tool name.
    """

    EXA = 'exa'
    SERPAPI = 'serpapi'
    ARXIV = 'arxiv'
    TAVILY = 'tavily'
# Maps a SearchEngineType value (e.g. 'exa') to the name under which that
# engine's search tool is exposed. Engines whose value has no entry here fall
# back to 'web_search' in SearchEngine.get_tool_name().
ENGINE_TOOL_NAMES: Dict[str, str] = dict((
    ('exa', 'exa_search'),
    ('serpapi', 'serpapi_search'),
    ('arxiv', 'arxiv_search'),
    ('tavily', 'tavily_search'),
))
@dataclass
class BaseResult:
    """A class representing the base fields of a search result.

    Every field is optional and defaults to ``None``, so a result can be
    created with only the fields that are available.

    Attributes:
        url (Optional[str]): The URL of the search result.
        id (Optional[str]): The temporary ID for the document.
        title (Optional[str]): The title of the search result.
        highlights (Optional[List[str]]): Highlights from the search result.
        highlight_scores (Optional[List[float]]): Scores for the highlights.
        summary (Optional[str]): A summary of the search result.
        markdown (Optional[str]): Markdown content of the search result.
    """
    url: Optional[str] = None
    id: Optional[str] = None
    title: Optional[str] = None
    highlights: Optional[List[str]] = None
    highlight_scores: Optional[List[float]] = None
    summary: Optional[str] = None
    markdown: Optional[str] = None
@dataclass
class SearchResponse(Generic[T]):
    """Base container for processed search responses.

    Attributes:
        results: The list of search results, each of item type ``T``.
    """

    results: List[T]
class SearchRequest(ABC):
    """Engine-agnostic description of a search to run.

    Concrete requests define how their parameters are serialized by
    implementing :meth:`to_dict`; :meth:`to_json` is derived from it.
    """

    def __init__(self,
                 query: str,
                 num_results: Optional[int] = 10,
                 **kwargs: Any):
        """Capture the common search parameters.

        Args:
            query: Text to search for.
            num_results: How many results to ask for; 10 unless given.
            **kwargs: Any extra keyword options, stored unchanged in
                ``self._kwargs`` for subclasses to consume.
        """
        self.query, self.num_results, self._kwargs = query, num_results, kwargs

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Return this request's parameters as a dictionary."""
        ...

    def to_json(self) -> str:
        """Serialize the output of :meth:`to_dict` as JSON.

        Non-ASCII characters are emitted literally (``ensure_ascii=False``).

        Returns:
            str: The parameters as a JSON string.
        """
        params = self.to_dict()
        return json.dumps(params, ensure_ascii=False)
class SearchResult(ABC):
    """Base class for the outcome of a search.

    Subclasses implement :meth:`_process_results` to turn the engine's raw
    response into a :class:`SearchResponse`. :meth:`to_list` reads
    ``self.response.results`` and the :class:`BaseResult` attributes of each
    item. NOTE(review): the constructor stores the *raw* response in
    ``self.response``, so presumably subclasses replace it with the processed
    ``SearchResponse`` before ``to_list`` is called — confirm in subclasses.
    """

    def __init__(self,
                 query: str,
                 arguments: Optional[Dict[str, Any]] = None,
                 response: Any = None):
        """
        Initialize SearchResult.

        Args:
            query: The original search query string.
            arguments: The arguments used for the search.
            response: The raw results returned by the search.
        """
        self.query = query
        self.arguments = arguments
        self.response = response

    @abstractmethod
    def _process_results(self) -> SearchResponse:
        """
        Process the raw results into a standardized format.

        Returns:
            SearchResponse: Processed search results.
        """
        pass

    def to_list(self) -> List[Dict[str, Any]]:
        """
        Convert the search results to a list of dictionaries.

        Each dict holds the ``url``, ``id``, ``title``, ``highlights``,
        ``highlight_scores``, ``summary`` and ``markdown`` attributes of one
        result, in the order of ``self.response.results``.

        Returns:
            The list of result dicts. An empty list is returned (after a
            warning is printed) when there are no results or no query.
        """
        if not self.response or not self.response.results:
            print('***Warning: No search results found.')
            return []

        if not self.query:
            print('***Warning: No query provided for search results.')
            return []

        return [{
            'url': res.url,
            'id': res.id,
            'title': res.title,
            'highlights': res.highlights,
            'highlight_scores': res.highlight_scores,
            'summary': res.summary,
            'markdown': res.markdown,
        } for res in self.response.results]

    @staticmethod
    def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
        """Load search results from a JSON file.

        Args:
            file_path: Path to a UTF-8 encoded JSON file.

        Returns:
            The decoded JSON content (expected to be a list of result dicts),
            or an empty list when ``file_path`` is not an existing regular
            file.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        # isfile rather than exists: exists() is also True for a directory,
        # which would then make open() raise instead of returning [].
        if not os.path.isfile(file_path):
            return []
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f'Search results loaded from {file_path}')
        return data
class SearchEngine(ABC):
    """Abstract base class for search engines.

    Subclasses should:
        - set ``engine_type`` to their :class:`SearchEngineType` member;
        - implement ``search()`` to perform the actual search;
        - override ``get_tool_definition()`` to return an engine-specific
          tool definition for agent use;
        - implement ``build_request_from_args()`` to build a request from
          tool call arguments.
    """

    # Must be set by each subclass. Annotated Optional because the base-class
    # default is None; get_tool_name() raises until a subclass sets it.
    engine_type: Optional[SearchEngineType] = None

    @abstractmethod
    def search(self, search_request: SearchRequest) -> SearchResult:
        """Perform a search and return results.

        Args:
            search_request: The request describing the search to run.

        Returns:
            SearchResult: The results of the search.
        """
        pass

    @classmethod
    def get_tool_name(cls) -> str:
        """Get the tool name for this engine.

        Returns:
            str: ``ENGINE_TOOL_NAMES[cls.engine_type.value]``, or
            ``'web_search'`` when that value has no entry in the mapping.

        Raises:
            NotImplementedError: If the subclass did not set ``engine_type``.
        """
        if cls.engine_type is None:
            raise NotImplementedError('engine_type must be set by subclass')
        return ENGINE_TOOL_NAMES.get(cls.engine_type.value, 'web_search')

    @classmethod
    def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
        """
        Return the tool definition for this search engine.

        Subclasses should override this to provide engine-specific
        descriptions and parameters. The default accepts a required
        ``query`` string and an optional ``num_results`` integer in [1, 10].

        Args:
            server_name: The server name for the tool.

        Returns:
            The ``Tool`` definition built from the tool name, server name,
            description and JSON-schema parameters.
        """
        # Imported here because Tool is only imported under TYPE_CHECKING at
        # module level (presumably to avoid an import cycle — confirm).
        from ms_agent.llm.utils import Tool
        return Tool(
            tool_name=cls.get_tool_name(),
            server_name=server_name,
            description='Search the web for information.',
            parameters={
                'type': 'object',
                'properties': {
                    'query': {
                        'type': 'string',
                        'description': 'The search query.',
                    },
                    'num_results': {
                        'type': 'integer',
                        'minimum': 1,
                        'maximum': 10,
                        'description': 'Number of results to return.',
                    },
                },
                'required': ['query'],
            },
        )

    @classmethod
    def build_request_from_args(cls, **kwargs) -> SearchRequest:
        """
        Build a search request from tool call arguments.

        Subclasses should override this to handle engine-specific parameters.

        Args:
            **kwargs: Tool call arguments.

        Returns:
            SearchRequest instance.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f'{cls.__name__} must implement build_request_from_args')