Skip to content

Commit c3a5106

Browse files
committed
feat(manager): 添加工具调用参数格式自动转换功能
1 parent ec375a1 commit c3a5106

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

app/agent/tools/manager.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,72 @@ def get_tool(self, tool_name: str) -> Optional[Any]:
100100
return tool
101101
return None
102102

103+
def _normalize_arguments(self, tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
104+
"""
105+
根据工具的参数schema规范化参数类型
106+
107+
Args:
108+
tool_instance: 工具实例
109+
arguments: 原始参数
110+
111+
Returns:
112+
规范化后的参数
113+
"""
114+
# 获取工具的参数schema
115+
args_schema = getattr(tool_instance, 'args_schema', None)
116+
if not args_schema:
117+
return arguments
118+
119+
# 获取schema中的字段定义
120+
try:
121+
schema = args_schema.model_json_schema()
122+
properties = schema.get("properties", {})
123+
except Exception as e:
124+
logger.warning(f"获取工具schema失败: {e}")
125+
return arguments
126+
127+
# 规范化参数
128+
normalized = {}
129+
for key, value in arguments.items():
130+
if key not in properties:
131+
# 参数不在schema中,保持原样
132+
normalized[key] = value
133+
continue
134+
135+
field_info = properties[key]
136+
field_type = field_info.get("type")
137+
138+
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
139+
any_of = field_info.get("anyOf")
140+
if any_of and not field_type:
141+
# 从 anyOf 中提取实际类型
142+
for type_option in any_of:
143+
if "type" in type_option and type_option["type"] != "null":
144+
field_type = type_option["type"]
145+
break
146+
147+
# 根据类型进行转换
148+
if field_type == "integer" and isinstance(value, str):
149+
try:
150+
normalized[key] = int(value)
151+
except (ValueError, TypeError):
152+
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
153+
normalized[key] = value
154+
elif field_type == "number" and isinstance(value, str):
155+
try:
156+
normalized[key] = float(value)
157+
except (ValueError, TypeError):
158+
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
159+
normalized[key] = value
160+
elif field_type == "boolean" and isinstance(value, str):
161+
# 转换字符串为布尔值
162+
normalized[key] = value.lower() in ("true", "1", "yes", "on")
163+
else:
164+
# 其他类型保持原样
165+
normalized[key] = value
166+
167+
return normalized
168+
103169
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
104170
"""
105171
调用工具
@@ -120,8 +186,11 @@ async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
120186
return error_msg
121187

122188
try:
189+
# 规范化参数类型
190+
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
191+
123192
# 调用工具的run方法
124-
result = await tool_instance.run(**arguments)
193+
result = await tool_instance.run(**normalized_arguments)
125194

126195
# 确保返回字符串
127196
if isinstance(result, str):

0 commit comments

Comments
 (0)