-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy path_common.py
More file actions
586 lines (476 loc) · 17 KB
/
_common.py
File metadata and controls
586 lines (476 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
# -*- coding: utf-8 -*-
"""The common utilities for agentscope library."""
import asyncio
import base64
import functools
import ipaddress
import inspect
import json
import os
import socket
import tempfile
import types
import typing
import uuid
from datetime import datetime
from typing import Any, Callable, Type, Dict
from urllib.parse import urljoin, urlparse
import numpy as np
import requests
from docstring_parser import parse
from json_repair import repair_json
from pydantic import BaseModel, Field, create_model, ConfigDict
from .._logging import logger
from ..types import ToolFunction
if typing.TYPE_CHECKING:
from mcp.types import Tool
else:
Tool = "mcp.types.Tool"
def _json_loads_with_repair(
    json_str: str,
) -> dict:
    """Parse a possibly incomplete JSON string (e.g. '{"key') into a dict.

    The string is first passed through ``repair_json`` so that truncated
    streaming fragments can still be loaded.

    .. note::
        This function is currently only used for parsing the streaming
        output of the argument field in `tool_use`, so the parsed result
        must be a dict.

    Args:
        json_str (`str`):
            The JSON string to parse, which may be incomplete or malformed.

    Returns:
        `dict`:
            A dictionary parsed from the JSON string after repair attempts.
            Returns an empty dict if all repair attempts fail.
    """
    try:
        loaded = json.loads(repair_json(json_str, stream_stable=True))
        if isinstance(loaded, dict):
            return loaded
    except Exception:
        # Truncate very long payloads so the warning stays readable.
        preview = json_str if len(json_str) <= 100 else json_str[:100] + "..."
        logger.warning(
            "Failed to load JSON dict from string: %s. Returning empty dict "
            "instead.",
            preview,
        )
    # Either repair/parse failed, or the result was not a dict.
    return {}
def _parse_streaming_json_dict(
    json_str: str,
    last_input: dict | None = None,
) -> dict:
    """Parse a streaming JSON dict without regressing on incomplete chunks.

    If the current chunk already forms a valid JSON dict, prefer it directly.
    Otherwise, fall back to repaired JSON and keep the previous parsed value
    only when repair would shrink the intermediate structure.
    """
    text = json_str or "{}"
    # Fast path: the chunk is already complete, valid JSON.
    try:
        parsed = json.loads(text)
    except Exception:
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    # Slow path: repair the fragment, then guard against shrinkage by
    # comparing serialized sizes with the previously parsed value.
    repaired = _json_loads_with_repair(text)
    previous = last_input or {}
    if len(json.dumps(previous)) > len(json.dumps(repaired)):
        return previous
    return repaired
def _is_accessible_local_file(url: str) -> bool:
"""Check if the given URL is a local URL."""
# First identify if it's an uri with 'file://' schema,
if url.startswith("file://"):
local_path = url.removeprefix("file://")
return os.path.isfile(local_path)
return os.path.isfile(url)
def _get_timestamp(add_random_suffix: bool = False) -> str:
"""Get the current timestamp in the format YYYY-MM-DD HH:MM:SS.sss."""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
if add_random_suffix:
# Add a random suffix to the timestamp
timestamp += f"_{os.urandom(3).hex()}"
return timestamp
async def _is_async_func(func: Callable) -> bool:
"""Check if the given function is an async function, including
coroutine functions, async generators, and coroutine objects.
"""
return (
inspect.iscoroutinefunction(func)
or inspect.isasyncgenfunction(func)
or isinstance(func, types.CoroutineType)
or isinstance(func, types.GeneratorType)
and asyncio.iscoroutine(func)
or isinstance(func, functools.partial)
and await _is_async_func(func.func)
)
async def _execute_async_or_sync_func(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute an async or sync function based on its type.

    Args:
        func (`Callable`):
            The function to be executed, which can be either async or sync.
        *args (`Any`):
            Positional arguments to be passed to the function.
        **kwargs (`Any`):
            Keyword arguments to be passed to the function.

    Returns:
        `Any`:
            The result of the function execution.
    """
    needs_await = await _is_async_func(func)
    if needs_await:
        return await func(*args, **kwargs)
    return func(*args, **kwargs)
def _get_bytes_from_web_url(
    url: str,
    max_retries: int = 3,
) -> str:
    """Get the bytes from a given URL.

    Args:
        url (`str`):
            The URL to fetch the bytes from.
        max_retries (`int`, defaults to `3`):
            The maximum number of retries.
    """
    attempts_left = max_retries
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = _request_url_with_validated_redirects(url)
            response.raise_for_status()
            return response.content.decode("utf-8")
        except UnicodeDecodeError:
            # Binary payloads cannot be decoded as UTF-8, so ship them as
            # base64 text instead.
            return base64.b64encode(response.content).decode("ascii")
        except Exception as e:
            logger.info(
                "Failed to fetch bytes from URL %s. Error %s. Retrying...",
                url,
                str(e),
            )
    raise RuntimeError(
        f"Failed to fetch bytes from URL `{url}` after {max_retries} retries.",
    )
def _is_public_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return whether the given IP address is globally routable."""
return ip.is_global
def _validate_external_url(url: str) -> None:
    """Validate URL to prevent fetching local or private network resources."""
    parsed = urlparse(url)
    scheme = parsed.scheme
    if scheme not in {"http", "https"}:
        raise ValueError(
            f"Unsupported URL scheme: {scheme}. "
            "Only http/https are allowed.",
        )
    host = parsed.hostname
    if not host:
        raise ValueError(f"Invalid URL without hostname: {url}")
    # Literal IP addresses can be checked without any DNS lookup.
    try:
        literal_ip = ipaddress.ip_address(host)
    except ValueError:
        literal_ip = None
    if literal_ip is not None:
        if not _is_public_ip(literal_ip):
            raise ValueError(
                f"Blocked non-public URL host: {host}",
            )
        return
    if host.lower() == "localhost":
        raise ValueError("Blocked localhost URL host.")
    # Resolve the hostname and reject it if any resolved address is
    # non-public (best-effort SSRF guard; resolution happens again at
    # request time).
    try:
        addresses = socket.getaddrinfo(host, None)
    except socket.gaierror as e:
        raise ValueError(f"Failed to resolve URL host: {host}") from e
    if not addresses:
        raise ValueError(f"Failed to resolve URL host: {host}")
    for info in addresses:
        candidate = ipaddress.ip_address(info[4][0])
        if not _is_public_ip(candidate):
            raise ValueError(
                f"Blocked non-public URL host {host} "
                f"(resolved to {candidate}).",
            )
def _request_url_with_validated_redirects(
    url: str,
    max_redirects: int = 5,
) -> requests.Response:
    """Request URL while validating each redirect target."""
    redirect_statuses = {301, 302, 303, 307, 308}
    current_url = url
    # One initial request plus up to ``max_redirects`` follow-ups, each
    # validated before it is fetched.
    for _hop in range(max_redirects + 1):
        _validate_external_url(current_url)
        response = requests.get(
            current_url,
            allow_redirects=False,
            timeout=(5, 10),
        )
        location = response.headers.get("Location")
        if response.status_code in redirect_statuses and location is not None:
            # Resolve relative Location headers against the current URL.
            current_url = urljoin(current_url, location)
            continue
        return response
    raise RuntimeError(
        f"Exceeded maximum redirects ({max_redirects}) for URL `{url}`.",
    )
def _save_base64_data(
media_type: str,
base64_data: str,
) -> str:
"""Save the base64 data to a temp file and return the file path. The
extension is guessed from the MIME type.
Args:
media_type (`str`):
The MIME type of the data, e.g. "image/png", "audio/mpeg".
base64_data (`str`):
The base64 data to be saved.
"""
extension = "." + media_type.split("/")[-1]
with tempfile.NamedTemporaryFile(
suffix=extension,
delete=False,
) as temp_file:
decoded_data = base64.b64decode(base64_data)
temp_file.write(decoded_data)
return temp_file.name
def _extract_json_schema_from_mcp_tool(tool: Tool) -> dict[str, Any]:
    """Extract JSON schema from MCP tool."""
    # The MCP input schema may omit these keys; default to empty values.
    input_schema = tool.inputSchema
    properties = input_schema.get("properties", {})
    required = input_schema.get("required", [])
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }
def _remove_title_field(schema: dict) -> None:
"""Remove the title field from the JSON schema to avoid
misleading the LLM."""
# The top level title field
if "title" in schema:
schema.pop("title")
# properties
if "properties" in schema:
for prop in schema["properties"].values():
if isinstance(prop, dict):
_remove_title_field(prop)
# items
if "items" in schema and isinstance(schema["items"], dict):
_remove_title_field(schema["items"])
# additionalProperties
if "additionalProperties" in schema and isinstance(
schema["additionalProperties"],
dict,
):
_remove_title_field(
schema["additionalProperties"],
)
def _create_tool_from_base_model(
    structured_model: Type[BaseModel],
    tool_name: str = "generate_structured_output",
) -> Dict[str, Any]:
    """Build a function-calling tool definition from a Pydantic BaseModel.

    The model's JSON schema becomes the tool's ``parameters``, which lets a
    function-calling API force the LLM to emit structured data by invoking
    this tool with properly formatted arguments.

    Args:
        structured_model (`Type[BaseModel]`):
            A Pydantic BaseModel class that defines the expected structure
            for the tool's output.
        tool_name (`str`, default `"generate_structured_output"`):
            The tool name that used to force the LLM to generate structured
            output by calling this function.

    Returns:
        `Dict[str, Any]`:
            A tool definition dictionary compatible with function calling
            API, containing type ("function") and a function dictionary
            with name, description, and parameters (JSON schema).

    .. code-block:: python
        :caption: Example usage

        from pydantic import BaseModel

        class PersonInfo(BaseModel):
            name: str
            age: int

        tool = _create_tool_from_base_model(PersonInfo, "extract_person")
        print(tool["function"]["name"])  # extract_person
        print(tool["type"])  # function

    .. note:: The 'title' fields are stripped from the JSON schema via the
        internal ``_remove_title_field()`` helper to keep the schema
        compatible with the function-calling format.
    """
    parameters_schema = structured_model.model_json_schema()
    # Titles would only mislead the LLM, so drop them everywhere.
    _remove_title_field(parameters_schema)
    return {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": "Generate the required structured output with "
            "this function",
            "parameters": parameters_schema,
        },
    }
def _map_text_to_uuid(text: str) -> str:
"""Map the given text to a deterministic UUID string.
Args:
text (`str`):
The input text to be mapped to a UUID.
Returns:
`str`:
A deterministic UUID string derived from the input text.
"""
return str(uuid.uuid3(uuid.NAMESPACE_DNS, text))
def _parse_tool_function(
    tool_func: ToolFunction,
    include_long_description: bool,
    include_var_positional: bool,
    include_var_keyword: bool,
) -> dict:
    """Extract JSON schema from the tool function's docstring

    Args:
        tool_func (`ToolFunction`):
            The tool function to extract the JSON schema from.
        include_long_description (`bool`):
            Whether to include the long description in the JSON schema.
        include_var_positional (`bool`):
            Whether to include variable positional arguments in the JSON
            schema.
        include_var_keyword (`bool`):
            Whether to include variable keyword arguments in the JSON schema.

    Returns:
        `dict`:
            The extracted JSON schema.
    """
    docstring = parse(tool_func.__doc__)
    # Map each documented argument name to its description text.
    params_docstring = {_.arg_name: _.description for _ in docstring.params}
    # Function description
    descriptions = []
    if docstring.short_description is not None:
        descriptions.append(docstring.short_description)
    if include_long_description and docstring.long_description is not None:
        descriptions.append(docstring.long_description)
    func_description = "\n".join(descriptions)
    # Create a dynamic model with the function signature
    fields = {}
    for name, param in inspect.signature(tool_func).parameters.items():
        # Skip the `self` and `cls` parameters
        if name in ["self", "cls"]:
            continue
        # Handle `**kwargs`
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            if not include_var_keyword:
                continue
            # `**kwargs` is modeled as a dict field; docstrings may document
            # it either as "**name" or as plain "name", so try both.
            fields[name] = (
                Dict[str, Any]
                if param.annotation == inspect.Parameter.empty
                else Dict[str, param.annotation],  # type: ignore
                Field(
                    description=params_docstring.get(
                        f"**{name}",
                        params_docstring.get(name, None),
                    ),
                    # No default means "optional, empty dict" in the schema.
                    default={}
                    if param.default is param.empty
                    else param.default,
                ),
            )
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            if not include_var_positional:
                continue
            # `*args` is modeled as a list field; docstrings may document it
            # either as "*name" or as plain "name", so try both.
            fields[name] = (
                list[Any]
                if param.annotation == inspect.Parameter.empty
                else list[param.annotation],  # type: ignore
                Field(
                    description=params_docstring.get(
                        f"*{name}",
                        params_docstring.get(name, None),
                    ),
                    # No default means "optional, empty list" in the schema.
                    default=[]
                    if param.default is param.empty
                    else param.default,
                ),
            )
        else:
            # Ordinary parameter: keep its annotation (or fall back to Any)
            # and mark it required (Ellipsis) when no default is given.
            fields[name] = (
                Any
                if param.annotation == inspect.Parameter.empty
                else param.annotation,
                Field(
                    description=params_docstring.get(name, None),
                    default=...
                    if param.default is param.empty
                    else param.default,
                ),
            )
    # Build a throwaway pydantic model so it can emit the JSON schema.
    base_model = create_model(
        "_StructuredOutputDynamicClass",
        __config__=ConfigDict(arbitrary_types_allowed=True),
        **fields,
    )
    params_json_schema = base_model.model_json_schema()
    # Remove the title from the json schema
    _remove_title_field(params_json_schema)
    func_json_schema: dict = {
        "type": "function",
        "function": {
            "name": tool_func.__name__,
            "parameters": params_json_schema,
        },
    }
    # Only attach a description when the docstring actually provided one.
    if func_description not in [None, ""]:
        func_json_schema["function"]["description"] = func_description
    return func_json_schema
def _resample_pcm_delta(
pcm_base64: str,
sample_rate: int,
target_rate: int,
) -> str:
"""Resampling the input pcm base64 data into the target rate.
Args:
pcm_base64 (`str`):
The input base64 audio data in pcm format.
sample_rate (`int`):
The sampling rate of the input data.
target_rate (`int`):
The target rate of the input data.
Returns:
`str`:
The resampling base64 audio data in the required sampling
rate.
"""
pcm_data = base64.b64decode(pcm_base64)
# Into numpy array first
audio_array = np.frombuffer(pcm_data, dtype=np.int16)
# return directly if the same
if sample_rate == target_rate:
return pcm_base64
# compute the number of samples
num_samples = int(len(audio_array) * target_rate / sample_rate)
from scipy import signal
# Use scipy to resample
resampled_audio = signal.resample(audio_array, num_samples)
# Turn it back into bytes
resampled_audio = np.clip(resampled_audio, -32768, 32767).astype(np.int16)
# into base64
resampled_bytes = resampled_audio.tobytes()
resampled_base64 = base64.b64encode(resampled_bytes).decode("utf-8")
return resampled_base64