Skip to content

Commit cca9365

Browse files
[Feat] Allow custom argument scope in tool descriptions (#292)
* update `tool_api` * update `tool_api` * update `get_plugin_prompt`
1 parent 0cd5df8 commit cca9365

File tree

2 files changed

+43
-63
lines changed

2 files changed

+43
-63
lines changed

lagent/actions/base_action.py

Lines changed: 42 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABCMeta
55
from copy import deepcopy
66
from functools import wraps
7-
from typing import Callable, Optional, Type, get_args, get_origin
7+
from typing import Callable, Iterable, Optional, Type, get_args, get_origin
88

99
try:
1010
from typing import Annotated
@@ -24,11 +24,15 @@
2424
logging.getLogger('griffe').setLevel(logging.ERROR)
2525

2626

27-
def tool_api(func: Optional[Callable] = None,
28-
*,
29-
explode_return: bool = False,
30-
returns_named_value: bool = False,
31-
**kwargs):
27+
def tool_api(
28+
func: Optional[Callable] = None,
29+
*,
30+
explode_return: bool = False,
31+
returns_named_value: bool = False,
32+
include_arguments: Optional[Iterable[str]] = None,
33+
exclude_arguments: Optional[Iterable[str]] = None,
34+
**kwargs,
35+
):
3236
"""Turn functions into tools. It will parse typehints as well as docstrings
3337
to build the tool description and attach it to functions via an attribute
3438
``api_description``.
@@ -90,6 +94,16 @@ def foo(a, b):
9094
``return_data`` field will be added to ``api_description`` only
9195
when ``explode_return`` or ``returns_named_value`` is enabled.
9296
"""
97+
if include_arguments is None:
98+
exclude_arguments = exclude_arguments or set()
99+
if isinstance(exclude_arguments, str):
100+
exclude_arguments = {exclude_arguments}
101+
elif not isinstance(exclude_arguments, set):
102+
exclude_arguments = set(exclude_arguments)
103+
if 'self' not in exclude_arguments:
104+
exclude_arguments.add('self')
105+
else:
106+
include_arguments = {include_arguments} if isinstance(include_arguments, str) else set(include_arguments)
93107

94108
def _detect_type(string):
95109
field_type = 'STRING'
@@ -106,10 +120,9 @@ def _detect_type(string):
106120

107121
def _explode(desc):
108122
kvs = []
109-
desc = '\nArgs:\n' + '\n'.join([
110-
' ' + item.lstrip(' -+*#.')
111-
for item in desc.split('\n')[1:] if item.strip()
112-
])
123+
desc = '\nArgs:\n' + '\n'.join(
124+
[' ' + item.lstrip(' -+*#.') for item in desc.split('\n')[1:] if item.strip()]
125+
)
113126
docs = Docstring(desc).parse('google')
114127
if not docs:
115128
return kvs
@@ -125,13 +138,12 @@ def _explode(desc):
125138

126139
def _parse_tool(function):
127140
# remove rst syntax
128-
docs = Docstring(
129-
re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
130-
'google', returns_named_value=returns_named_value, **kwargs)
141+
docs = Docstring(re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
142+
'google', returns_named_value=returns_named_value, **kwargs
143+
)
131144
desc = dict(
132145
name=function.__name__,
133-
description=docs[0].value
134-
if docs[0].kind is DocstringSectionKind.text else '',
146+
description=docs[0].value if docs[0].kind is DocstringSectionKind.text else '',
135147
parameters=[],
136148
required=[],
137149
)
@@ -155,17 +167,14 @@ def _parse_tool(function):
155167

156168
sig = inspect.signature(function)
157169
for name, param in sig.parameters.items():
158-
if name == 'self':
170+
if name in exclude_arguments if include_arguments is None else name not in include_arguments:
159171
continue
160172
parameter = dict(
161-
name=param.name,
162-
type='STRING',
163-
description=args_doc.get(param.name,
164-
{}).get('description', ''))
173+
name=param.name, type='STRING', description=args_doc.get(param.name, {}).get('description', '')
174+
)
165175
annotation = param.annotation
166176
if annotation is inspect.Signature.empty:
167-
parameter['type'] = args_doc.get(param.name,
168-
{}).get('type', 'STRING')
177+
parameter['type'] = args_doc.get(param.name, {}).get('type', 'STRING')
169178
else:
170179
if get_origin(annotation) is Annotated:
171180
annotation, info = get_args(annotation)
@@ -229,9 +238,8 @@ class ToolMeta(ABCMeta):
229238

230239
def __new__(mcs, name, base, attrs):
231240
is_toolkit, tool_desc = True, dict(
232-
name=name,
233-
description=Docstring(attrs.get('__doc__',
234-
'')).parse('google')[0].value)
241+
name=name, description=Docstring(attrs.get('__doc__', '')).parse('google')[0].value
242+
)
235243
for key, value in attrs.items():
236244
if callable(value) and hasattr(value, 'api_description'):
237245
api_desc = getattr(value, 'api_description')
@@ -246,8 +254,7 @@ def __new__(mcs, name, base, attrs):
246254
else:
247255
tool_desc.setdefault('api_list', []).append(api_desc)
248256
if not is_toolkit and 'api_list' in tool_desc:
249-
raise KeyError('`run` and other tool APIs can not be implemented '
250-
'at the same time')
257+
raise KeyError('`run` and other tool APIs can not be implemented ' 'at the same time')
251258
if is_toolkit and 'api_list' not in tool_desc:
252259
is_toolkit = False
253260
if callable(attrs.get('run')):
@@ -346,26 +353,16 @@ def __call__(self, inputs: str, name='run') -> ActionReturn:
346353
fallback_args = {'inputs': inputs, 'name': name}
347354
if not hasattr(self, name):
348355
return ActionReturn(
349-
fallback_args,
350-
type=self.name,
351-
errmsg=f'invalid API: {name}',
352-
state=ActionStatusCode.API_ERROR)
356+
fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
357+
)
353358
try:
354359
inputs = self._parser.parse_inputs(inputs, name)
355360
except ParseError as exc:
356-
return ActionReturn(
357-
fallback_args,
358-
type=self.name,
359-
errmsg=exc.err_msg,
360-
state=ActionStatusCode.ARGS_ERROR)
361+
return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
361362
try:
362363
outputs = getattr(self, name)(**inputs)
363364
except Exception as exc:
364-
return ActionReturn(
365-
inputs,
366-
type=self.name,
367-
errmsg=str(exc),
368-
state=ActionStatusCode.API_ERROR)
365+
return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
369366
if isinstance(outputs, ActionReturn):
370367
action_return = outputs
371368
if not action_return.args:
@@ -402,26 +399,16 @@ async def __call__(self, inputs: str, name='run') -> ActionReturn:
402399
fallback_args = {'inputs': inputs, 'name': name}
403400
if not hasattr(self, name):
404401
return ActionReturn(
405-
fallback_args,
406-
type=self.name,
407-
errmsg=f'invalid API: {name}',
408-
state=ActionStatusCode.API_ERROR)
402+
fallback_args, type=self.name, errmsg=f'invalid API: {name}', state=ActionStatusCode.API_ERROR
403+
)
409404
try:
410405
inputs = self._parser.parse_inputs(inputs, name)
411406
except ParseError as exc:
412-
return ActionReturn(
413-
fallback_args,
414-
type=self.name,
415-
errmsg=exc.err_msg,
416-
state=ActionStatusCode.ARGS_ERROR)
407+
return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
417408
try:
418409
outputs = await getattr(self, name)(**inputs)
419410
except Exception as exc:
420-
return ActionReturn(
421-
inputs,
422-
type=self.name,
423-
errmsg=str(exc),
424-
state=ActionStatusCode.API_ERROR)
411+
return ActionReturn(inputs, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
425412
if isinstance(outputs, ActionReturn):
426413
action_return = outputs
427414
if not action_return.args:

lagent/agents/stream.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939

4040

41-
def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
41+
def get_plugin_prompt(actions, api_desc_template='{description}'):
4242
plugin_descriptions = []
4343
for action in actions if isinstance(actions, list) else [actions]:
4444
action = create_object(action)
@@ -47,15 +47,8 @@ def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
4747
for api in action_desc['api_list']:
4848
api['name'] = f"{action.name}.{api['name']}"
4949
api['description'] = api_desc_template.format(tool_name=action.name, description=api['description'])
50-
api['parameters'] = [param for param in api['parameters'] if param['name'] in api['required']]
5150
plugin_descriptions.append(api)
5251
else:
53-
action_desc['description'] = api_desc_template.format(
54-
tool_name=action.name, description=action_desc['description']
55-
)
56-
action_desc['parameters'] = [
57-
param for param in action_desc['parameters'] if param['name'] in action_desc['required']
58-
]
5952
plugin_descriptions.append(action_desc)
6053
return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
6154

0 commit comments

Comments
 (0)