Skip to content

Commit 52b1dfe

Browse files
sararobcopybara-github
authored andcommitted
chore: Add abstract type annotation support to AFC
PiperOrigin-RevId: 831462920
1 parent 22c6dbe commit 52b1dfe

File tree

3 files changed

+215
-21
lines changed

3 files changed

+215
-21
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -296,20 +296,55 @@ def from_function_with_options(
296296
) -> 'types.FunctionDeclaration':
297297

298298
parameters_properties = {}
299-
for name, param in inspect.signature(func).parameters.items():
300-
if param.kind in (
301-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
302-
inspect.Parameter.KEYWORD_ONLY,
303-
inspect.Parameter.POSITIONAL_ONLY,
304-
):
305-
# This snippet catches the case when type hints are stored as strings
306-
if isinstance(param.annotation, str):
307-
param = param.replace(annotation=typing.get_type_hints(func)[name])
308-
309-
schema = _function_parameter_parse_util._parse_schema_from_parameter(
310-
variant, param, func.__name__
311-
)
312-
parameters_properties[name] = schema
299+
parameters_json_schema = {}
300+
annotation_under_future = typing.get_type_hints(func)
301+
try:
302+
for name, param in inspect.signature(func).parameters.items():
303+
if param.kind in (
304+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
305+
inspect.Parameter.KEYWORD_ONLY,
306+
inspect.Parameter.POSITIONAL_ONLY,
307+
):
308+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
309+
param, annotation_under_future, name
310+
)
311+
312+
schema = _function_parameter_parse_util._parse_schema_from_parameter(
313+
variant, param, func.__name__
314+
)
315+
parameters_properties[name] = schema
316+
except ValueError:
317+
# If the function has complex parameter types that fail in _parse_schema_from_parameter,
318+
# we try to generate a json schema for the parameter using pydantic.TypeAdapter.
319+
parameters_properties = {}
320+
for name, param in inspect.signature(func).parameters.items():
321+
if param.kind in (
322+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
323+
inspect.Parameter.KEYWORD_ONLY,
324+
inspect.Parameter.POSITIONAL_ONLY,
325+
):
326+
try:
327+
if param.annotation == inspect.Parameter.empty:
328+
param = param.replace(annotation=Any)
329+
330+
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
331+
param, annotation_under_future, name
332+
)
333+
334+
_function_parameter_parse_util._raise_for_invalid_enum_value(param)
335+
336+
json_schema_dict = _function_parameter_parse_util._generate_json_schema_for_parameter(
337+
param
338+
)
339+
340+
parameters_json_schema[name] = types.Schema.model_validate(
341+
json_schema_dict
342+
)
343+
except Exception as e:
344+
_function_parameter_parse_util._raise_for_unsupported_param(
345+
param, func.__name__, e
346+
)
347+
313348
declaration = types.FunctionDeclaration(
314349
name=func.__name__,
315350
description=func.__doc__,
@@ -324,6 +359,12 @@ def from_function_with_options(
324359
declaration.parameters
325360
)
326361
)
362+
elif parameters_json_schema:
363+
declaration.parameters = types.Schema(
364+
type='OBJECT',
365+
properties=parameters_json_schema,
366+
)
367+
327368
if variant == GoogleLLMVariant.GEMINI_API:
328369
return declaration
329370

@@ -372,17 +413,35 @@ def from_function_with_options(
372413
inspect.Parameter.POSITIONAL_OR_KEYWORD,
373414
annotation=return_annotation,
374415
)
375-
# This snippet catches the case when type hints are stored as strings
376416
if isinstance(return_value.annotation, str):
377417
return_value = return_value.replace(
378418
annotation=typing.get_type_hints(func)['return']
379419
)
380420

381-
declaration.response = (
382-
_function_parameter_parse_util._parse_schema_from_parameter(
383-
variant,
384-
return_value,
385-
func.__name__,
421+
response_schema: Optional[types.Schema] = None
422+
response_json_schema: Optional[Union[Dict[str, Any], types.Schema]] = None
423+
try:
424+
response_schema = (
425+
_function_parameter_parse_util._parse_schema_from_parameter(
426+
variant,
427+
return_value,
428+
func.__name__,
429+
)
430+
)
431+
except ValueError:
432+
try:
433+
response_json_schema = (
434+
_function_parameter_parse_util._generate_json_schema_for_parameter(
435+
return_value
436+
)
386437
)
387-
)
438+
response_json_schema = types.Schema.model_validate(response_json_schema)
439+
except Exception as e:
440+
_function_parameter_parse_util._raise_for_unsupported_param(
441+
return_value, func.__name__, e
442+
)
443+
if response_schema:
444+
declaration.response = response_schema
445+
elif response_json_schema:
446+
declaration.response = response_json_schema
388447
return declaration

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,91 @@
4949
logger = logging.getLogger('google_adk.' + __name__)
5050

5151

52+
def _handle_params_as_deferred_annotations(
53+
param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str
54+
) -> inspect.Parameter:
55+
"""Catches the case when type hints are stored as strings."""
56+
if isinstance(param.annotation, str):
57+
param = param.replace(annotation=annotation_under_future[name])
58+
return param
59+
60+
61+
def _add_unevaluated_items_to_fixed_len_tuple_schema(
62+
json_schema: dict[str, Any],
63+
) -> dict[str, Any]:
64+
"""Adds 'unevaluatedItems': False to schemas for fixed-length tuples.
65+
66+
For example, the schema for a parameter of type `tuple[float, float]` would
67+
be:
68+
{
69+
"type": "array",
70+
"prefixItems": [
71+
{
72+
"type": "number"
73+
},
74+
{
75+
"type": "number"
76+
},
77+
],
78+
"minItems": 2,
79+
"maxItems": 2,
80+
"unevaluatedItems": False
81+
}
82+
83+
"""
84+
if (
85+
json_schema.get('maxItems')
86+
and (
87+
json_schema.get('prefixItems')
88+
and len(json_schema['prefixItems']) == json_schema['maxItems']
89+
)
90+
and json_schema.get('type') == 'array'
91+
):
92+
json_schema['unevaluatedItems'] = False
93+
return json_schema
94+
95+
96+
def _raise_for_unsupported_param(
97+
param: inspect.Parameter,
98+
func_name: str,
99+
exception: Exception,
100+
) -> None:
101+
raise ValueError(
102+
f'Failed to parse the parameter {param} of function {func_name} for'
103+
' automatic function calling.Automatic function calling works best with'
104+
' simpler function signature schema, consider manually parsing your'
105+
f' function declaration for function {func_name}.'
106+
) from exception
107+
108+
109+
def _raise_for_invalid_enum_value(param: inspect.Parameter):
110+
"""Raises an error if the default value is not a valid enum value."""
111+
if inspect.isclass(param.annotation) and issubclass(param.annotation, Enum):
112+
if param.default is not inspect.Parameter.empty and param.default not in [
113+
e.value for e in param.annotation
114+
]:
115+
raise ValueError(
116+
f'Default value {param.default} is not a valid enum value for'
117+
f' {param.annotation}.'
118+
)
119+
120+
121+
def _generate_json_schema_for_parameter(
122+
param: inspect.Parameter,
123+
) -> dict[str, Any]:
124+
"""Generates a JSON schema for a parameter using pydantic.TypeAdapter."""
125+
126+
param_schema_adapter = pydantic.TypeAdapter(
127+
param.annotation,
128+
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
129+
)
130+
json_schema_dict = param_schema_adapter.json_schema()
131+
json_schema_dict = _add_unevaluated_items_to_fixed_len_tuple_schema(
132+
json_schema_dict
133+
)
134+
return json_schema_dict
135+
136+
52137
def _is_builtin_primitive_or_compound(
53138
annotation: inspect.Parameter.annotation,
54139
) -> bool:

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Sequence
1516
from typing import Any
1617
from typing import Dict
1718

@@ -192,3 +193,52 @@ def test_function() -> None:
192193
# VERTEX_AI should have response schema for None return
193194
assert declaration.response is not None
194195
assert declaration.response.type == types.Type.NULL
196+
197+
198+
def test_from_function_with_collections_type_parameter():
199+
"""Test from_function_with_options with collections type parameter."""
200+
201+
def test_function(
202+
artifact_key: str,
203+
input_edit_ids: Sequence[str],
204+
) -> str:
205+
"""Saves a sequence of edit IDs."""
206+
return f'Saved {len(input_edit_ids)} edit IDs for artifact {artifact_key}'
207+
208+
declaration = _automatic_function_calling_util.from_function_with_options(
209+
test_function, GoogleLLMVariant.VERTEX_AI
210+
)
211+
212+
assert declaration.name == 'test_function'
213+
assert declaration.parameters.type == types.Type.OBJECT
214+
assert (
215+
declaration.parameters.properties['artifact_key'].type
216+
== types.Type.STRING
217+
)
218+
assert (
219+
declaration.parameters.properties['input_edit_ids'].type
220+
== types.Type.ARRAY
221+
)
222+
assert (
223+
declaration.parameters.properties['input_edit_ids'].items.type
224+
== types.Type.STRING
225+
)
226+
assert declaration.response.type == types.Type.STRING
227+
228+
229+
def test_from_function_with_collections_return_type():
230+
"""Test from_function_with_options with collections return type."""
231+
232+
def test_function(
233+
names: list[str],
234+
) -> Sequence[str]:
235+
"""Returns a sequence of names."""
236+
return names
237+
238+
declaration = _automatic_function_calling_util.from_function_with_options(
239+
test_function, GoogleLLMVariant.VERTEX_AI
240+
)
241+
242+
assert declaration.name == 'test_function'
243+
assert declaration.response.type == types.Type.ARRAY
244+
assert declaration.response.items.type == types.Type.STRING

0 commit comments

Comments
 (0)