Skip to content

Commit 7c28297

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Support Generator and AsyncGenerator tool declaration
use yield type as return type Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 856459995
1 parent d4da1bb commit 7c28297

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414

1515
from __future__ import annotations
1616

17+
import collections.abc
1718
import inspect
1819
from types import FunctionType
1920
import typing
2021
from typing import Any
2122
from typing import Callable
2223
from typing import Dict
24+
from typing import get_args
25+
from typing import get_origin
2326
from typing import Optional
2427
from typing import Union
2528

@@ -391,6 +394,20 @@ def from_function_with_options(
391394

392395
return_annotation = inspect.signature(func).return_annotation
393396

397+
# Handle AsyncGenerator and Generator return types (streaming tools)
398+
# AsyncGenerator[YieldType, SendType] -> use YieldType as response schema
399+
# Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema
400+
origin = get_origin(return_annotation)
401+
if origin is not None and (
402+
origin is collections.abc.AsyncGenerator
403+
or origin is collections.abc.Generator
404+
):
405+
type_args = get_args(return_annotation)
406+
if type_args:
407+
# First type argument is the yield type
408+
yield_type = type_args[0]
409+
return_annotation = yield_type
410+
394411
# Handle functions with no return annotation
395412
if return_annotation is inspect._empty:
396413
# Functions with no return annotation can return any type

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from collections.abc import Sequence
1616
from typing import Any
17+
from typing import AsyncGenerator
1718
from typing import Dict
19+
from typing import Generator
1820

1921
from google.adk.tools import _automatic_function_calling_util
2022
from google.adk.utils.variant_utils import GoogleLLMVariant
@@ -242,3 +244,78 @@ def test_function(
242244
assert declaration.name == 'test_function'
243245
assert declaration.response.type == types.Type.ARRAY
244246
assert declaration.response.items.type == types.Type.STRING
247+
248+
249+
def test_from_function_with_async_generator_return_vertex():
250+
"""Test from_function_with_options with AsyncGenerator return for VERTEX_AI."""
251+
252+
async def test_function(param: str) -> AsyncGenerator[str, None]:
253+
"""A streaming function that yields strings."""
254+
yield param
255+
256+
declaration = _automatic_function_calling_util.from_function_with_options(
257+
test_function, GoogleLLMVariant.VERTEX_AI
258+
)
259+
260+
assert declaration.name == 'test_function'
261+
assert declaration.parameters.type == 'OBJECT'
262+
assert declaration.parameters.properties['param'].type == 'STRING'
263+
# VERTEX_AI should extract yield type (str) from AsyncGenerator[str, None]
264+
assert declaration.response is not None
265+
assert declaration.response.type == types.Type.STRING
266+
267+
268+
def test_from_function_with_async_generator_return_gemini():
269+
"""Test from_function_with_options with AsyncGenerator return for GEMINI_API."""
270+
271+
async def test_function(param: str) -> AsyncGenerator[str, None]:
272+
"""A streaming function that yields strings."""
273+
yield param
274+
275+
declaration = _automatic_function_calling_util.from_function_with_options(
276+
test_function, GoogleLLMVariant.GEMINI_API
277+
)
278+
279+
assert declaration.name == 'test_function'
280+
assert declaration.parameters.type == 'OBJECT'
281+
assert declaration.parameters.properties['param'].type == 'STRING'
282+
# GEMINI_API should not have response schema
283+
assert declaration.response is None
284+
285+
286+
def test_from_function_with_generator_return_vertex():
287+
"""Test from_function_with_options with Generator return for VERTEX_AI."""
288+
289+
def test_function(param: str) -> Generator[int, None, None]:
290+
"""A streaming function that yields integers."""
291+
yield 42
292+
293+
declaration = _automatic_function_calling_util.from_function_with_options(
294+
test_function, GoogleLLMVariant.VERTEX_AI
295+
)
296+
297+
assert declaration.name == 'test_function'
298+
assert declaration.parameters.type == 'OBJECT'
299+
assert declaration.parameters.properties['param'].type == 'STRING'
300+
# VERTEX_AI should extract yield type (int) from Generator[int, None, None]
301+
assert declaration.response is not None
302+
assert declaration.response.type == types.Type.INTEGER
303+
304+
305+
def test_from_function_with_async_generator_complex_yield_type_vertex():
306+
"""Test from_function_with_options with AsyncGenerator yielding dict."""
307+
308+
async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]:
309+
"""A streaming function that yields dicts."""
310+
yield {'result': param}
311+
312+
declaration = _automatic_function_calling_util.from_function_with_options(
313+
test_function, GoogleLLMVariant.VERTEX_AI
314+
)
315+
316+
assert declaration.name == 'test_function'
317+
assert declaration.parameters.type == 'OBJECT'
318+
assert declaration.parameters.properties['param'].type == 'STRING'
319+
# VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator
320+
assert declaration.response is not None
321+
assert declaration.response.type == types.Type.OBJECT

0 commit comments

Comments
 (0)