@@ -100,6 +100,72 @@ def get_tool(self, tool_name: str) -> Optional[Any]:
100100 return tool
101101 return None
102102
103+ def _normalize_arguments (self , tool_instance : Any , arguments : Dict [str , Any ]) -> Dict [str , Any ]:
104+ """
105+ 根据工具的参数schema规范化参数类型
106+
107+ Args:
108+ tool_instance: 工具实例
109+ arguments: 原始参数
110+
111+ Returns:
112+ 规范化后的参数
113+ """
114+ # 获取工具的参数schema
115+ args_schema = getattr (tool_instance , 'args_schema' , None )
116+ if not args_schema :
117+ return arguments
118+
119+ # 获取schema中的字段定义
120+ try :
121+ schema = args_schema .model_json_schema ()
122+ properties = schema .get ("properties" , {})
123+ except Exception as e :
124+ logger .warning (f"获取工具schema失败: { e } " )
125+ return arguments
126+
127+ # 规范化参数
128+ normalized = {}
129+ for key , value in arguments .items ():
130+ if key not in properties :
131+ # 参数不在schema中,保持原样
132+ normalized [key ] = value
133+ continue
134+
135+ field_info = properties [key ]
136+ field_type = field_info .get ("type" )
137+
138+ # 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
139+ any_of = field_info .get ("anyOf" )
140+ if any_of and not field_type :
141+ # 从 anyOf 中提取实际类型
142+ for type_option in any_of :
143+ if "type" in type_option and type_option ["type" ] != "null" :
144+ field_type = type_option ["type" ]
145+ break
146+
147+ # 根据类型进行转换
148+ if field_type == "integer" and isinstance (value , str ):
149+ try :
150+ normalized [key ] = int (value )
151+ except (ValueError , TypeError ):
152+ logger .warning (f"无法将参数 { key } ='{ value } ' 转换为整数,保持原值" )
153+ normalized [key ] = value
154+ elif field_type == "number" and isinstance (value , str ):
155+ try :
156+ normalized [key ] = float (value )
157+ except (ValueError , TypeError ):
158+ logger .warning (f"无法将参数 { key } ='{ value } ' 转换为浮点数,保持原值" )
159+ normalized [key ] = value
160+ elif field_type == "boolean" and isinstance (value , str ):
161+ # 转换字符串为布尔值
162+ normalized [key ] = value .lower () in ("true" , "1" , "yes" , "on" )
163+ else :
164+ # 其他类型保持原样
165+ normalized [key ] = value
166+
167+ return normalized
168+
103169 async def call_tool (self , tool_name : str , arguments : Dict [str , Any ]) -> str :
104170 """
105171 调用工具
@@ -120,8 +186,11 @@ async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
120186 return error_msg
121187
122188 try :
189+ # 规范化参数类型
190+ normalized_arguments = self ._normalize_arguments (tool_instance , arguments )
191+
123192 # 调用工具的run方法
124- result = await tool_instance .run (** arguments )
193+ result = await tool_instance .run (** normalized_arguments )
125194
126195 # 确保返回字符串
127196 if isinstance (result , str ):
0 commit comments