44from abc import ABCMeta
55from copy import deepcopy
66from functools import wraps
7- from typing import Callable , Optional , Type , get_args , get_origin
7+ from typing import Callable , Iterable , Optional , Type , get_args , get_origin
88
99try :
1010 from typing import Annotated
2424logging .getLogger ('griffe' ).setLevel (logging .ERROR )
2525
2626
27- def tool_api (func : Optional [Callable ] = None ,
28- * ,
29- explode_return : bool = False ,
30- returns_named_value : bool = False ,
31- ** kwargs ):
27+ def tool_api (
28+ func : Optional [Callable ] = None ,
29+ * ,
30+ explode_return : bool = False ,
31+ returns_named_value : bool = False ,
32+ include_arguments : Optional [Iterable [str ]] = None ,
33+ exclude_arguments : Optional [Iterable [str ]] = None ,
34+ ** kwargs ,
35+ ):
3236 """Turn functions into tools. It will parse typehints as well as docstrings
3337 to build the tool description and attach it to functions via an attribute
3438 ``api_description``.
@@ -90,6 +94,16 @@ def foo(a, b):
9094 ``return_data`` field will be added to ``api_description`` only
9195 when ``explode_return`` or ``returns_named_value`` is enabled.
9296 """
97+ if include_arguments is None :
98+ exclude_arguments = exclude_arguments or set ()
99+ if isinstance (exclude_arguments , str ):
100+ exclude_arguments = {exclude_arguments }
101+ elif not isinstance (exclude_arguments , set ):
102+ exclude_arguments = set (exclude_arguments )
103+ if 'self' not in exclude_arguments :
104+ exclude_arguments .add ('self' )
105+ else :
106+ include_arguments = {include_arguments } if isinstance (include_arguments , str ) else set (include_arguments )
93107
94108 def _detect_type (string ):
95109 field_type = 'STRING'
@@ -106,10 +120,9 @@ def _detect_type(string):
106120
107121 def _explode (desc ):
108122 kvs = []
109- desc = '\n Args:\n ' + '\n ' .join ([
110- ' ' + item .lstrip (' -+*#.' )
111- for item in desc .split ('\n ' )[1 :] if item .strip ()
112- ])
123+ desc = '\n Args:\n ' + '\n ' .join (
124+ [' ' + item .lstrip (' -+*#.' ) for item in desc .split ('\n ' )[1 :] if item .strip ()]
125+ )
113126 docs = Docstring (desc ).parse ('google' )
114127 if not docs :
115128 return kvs
@@ -125,13 +138,12 @@ def _explode(desc):
125138
126139 def _parse_tool (function ):
127140 # remove rst syntax
128- docs = Docstring (
129- re . sub ( ':(.+?):`(.+?)` ' , ' \\ 2' , function . __doc__ or '' )). parse (
130- 'google' , returns_named_value = returns_named_value , ** kwargs )
141+ docs = Docstring (re . sub ( ':(.+?):`(.+?)`' , ' \\ 2' , function . __doc__ or '' )). parse (
142+ 'google ' , returns_named_value = returns_named_value , ** kwargs
143+ )
131144 desc = dict (
132145 name = function .__name__ ,
133- description = docs [0 ].value
134- if docs [0 ].kind is DocstringSectionKind .text else '' ,
146+ description = docs [0 ].value if docs [0 ].kind is DocstringSectionKind .text else '' ,
135147 parameters = [],
136148 required = [],
137149 )
@@ -155,17 +167,14 @@ def _parse_tool(function):
155167
156168 sig = inspect .signature (function )
157169 for name , param in sig .parameters .items ():
158- if name == 'self' :
170+ if name in exclude_arguments if include_arguments is None else name not in include_arguments :
159171 continue
160172 parameter = dict (
161- name = param .name ,
162- type = 'STRING' ,
163- description = args_doc .get (param .name ,
164- {}).get ('description' , '' ))
173+ name = param .name , type = 'STRING' , description = args_doc .get (param .name , {}).get ('description' , '' )
174+ )
165175 annotation = param .annotation
166176 if annotation is inspect .Signature .empty :
167- parameter ['type' ] = args_doc .get (param .name ,
168- {}).get ('type' , 'STRING' )
177+ parameter ['type' ] = args_doc .get (param .name , {}).get ('type' , 'STRING' )
169178 else :
170179 if get_origin (annotation ) is Annotated :
171180 annotation , info = get_args (annotation )
@@ -229,9 +238,8 @@ class ToolMeta(ABCMeta):
229238
230239 def __new__ (mcs , name , base , attrs ):
231240 is_toolkit , tool_desc = True , dict (
232- name = name ,
233- description = Docstring (attrs .get ('__doc__' ,
234- '' )).parse ('google' )[0 ].value )
241+ name = name , description = Docstring (attrs .get ('__doc__' , '' )).parse ('google' )[0 ].value
242+ )
235243 for key , value in attrs .items ():
236244 if callable (value ) and hasattr (value , 'api_description' ):
237245 api_desc = getattr (value , 'api_description' )
@@ -246,8 +254,7 @@ def __new__(mcs, name, base, attrs):
246254 else :
247255 tool_desc .setdefault ('api_list' , []).append (api_desc )
248256 if not is_toolkit and 'api_list' in tool_desc :
249- raise KeyError ('`run` and other tool APIs can not be implemented '
250- 'at the same time' )
257+ raise KeyError ('`run` and other tool APIs can not be implemented ' 'at the same time' )
251258 if is_toolkit and 'api_list' not in tool_desc :
252259 is_toolkit = False
253260 if callable (attrs .get ('run' )):
@@ -346,26 +353,16 @@ def __call__(self, inputs: str, name='run') -> ActionReturn:
346353 fallback_args = {'inputs' : inputs , 'name' : name }
347354 if not hasattr (self , name ):
348355 return ActionReturn (
349- fallback_args ,
350- type = self .name ,
351- errmsg = f'invalid API: { name } ' ,
352- state = ActionStatusCode .API_ERROR )
356+ fallback_args , type = self .name , errmsg = f'invalid API: { name } ' , state = ActionStatusCode .API_ERROR
357+ )
353358 try :
354359 inputs = self ._parser .parse_inputs (inputs , name )
355360 except ParseError as exc :
356- return ActionReturn (
357- fallback_args ,
358- type = self .name ,
359- errmsg = exc .err_msg ,
360- state = ActionStatusCode .ARGS_ERROR )
361+ return ActionReturn (fallback_args , type = self .name , errmsg = exc .err_msg , state = ActionStatusCode .ARGS_ERROR )
361362 try :
362363 outputs = getattr (self , name )(** inputs )
363364 except Exception as exc :
364- return ActionReturn (
365- inputs ,
366- type = self .name ,
367- errmsg = str (exc ),
368- state = ActionStatusCode .API_ERROR )
365+ return ActionReturn (inputs , type = self .name , errmsg = str (exc ), state = ActionStatusCode .API_ERROR )
369366 if isinstance (outputs , ActionReturn ):
370367 action_return = outputs
371368 if not action_return .args :
@@ -402,26 +399,16 @@ async def __call__(self, inputs: str, name='run') -> ActionReturn:
402399 fallback_args = {'inputs' : inputs , 'name' : name }
403400 if not hasattr (self , name ):
404401 return ActionReturn (
405- fallback_args ,
406- type = self .name ,
407- errmsg = f'invalid API: { name } ' ,
408- state = ActionStatusCode .API_ERROR )
402+ fallback_args , type = self .name , errmsg = f'invalid API: { name } ' , state = ActionStatusCode .API_ERROR
403+ )
409404 try :
410405 inputs = self ._parser .parse_inputs (inputs , name )
411406 except ParseError as exc :
412- return ActionReturn (
413- fallback_args ,
414- type = self .name ,
415- errmsg = exc .err_msg ,
416- state = ActionStatusCode .ARGS_ERROR )
407+ return ActionReturn (fallback_args , type = self .name , errmsg = exc .err_msg , state = ActionStatusCode .ARGS_ERROR )
417408 try :
418409 outputs = await getattr (self , name )(** inputs )
419410 except Exception as exc :
420- return ActionReturn (
421- inputs ,
422- type = self .name ,
423- errmsg = str (exc ),
424- state = ActionStatusCode .API_ERROR )
411+ return ActionReturn (inputs , type = self .name , errmsg = str (exc ), state = ActionStatusCode .API_ERROR )
425412 if isinstance (outputs , ActionReturn ):
426413 action_return = outputs
427414 if not action_return .args :
0 commit comments