Skip to content

Commit 19555e7

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Support Generator and Async Generator tool declaration in JSON schema
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 856713741
1 parent ed2c3eb commit 19555e7

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

src/google/adk/tools/_function_tool_declarations.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424

2525
from __future__ import annotations
2626

27+
import collections.abc
2728
import inspect
2829
import logging
2930
from typing import Any
3031
from typing import Callable
32+
from typing import get_args
33+
from typing import get_origin
3134
from typing import get_type_hints
3235
from typing import Optional
3336
from typing import Type
@@ -145,6 +148,19 @@ def _build_response_json_schema(
145148
except TypeError:
146149
pass
147150

151+
# Handle AsyncGenerator and Generator return types (streaming tools)
152+
# AsyncGenerator[YieldType, SendType] -> use YieldType as response schema
153+
# Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema
154+
origin = get_origin(return_annotation)
155+
if origin is not None and (
156+
origin is collections.abc.AsyncGenerator
157+
or origin is collections.abc.Generator
158+
):
159+
type_args = get_args(return_annotation)
160+
if type_args:
161+
# First type argument is the yield type
162+
return_annotation = type_args[0]
163+
148164
try:
149165
adapter = pydantic.TypeAdapter(
150166
return_annotation,

tests/unittests/tools/test_function_tool_declarations.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from collections.abc import Sequence
2424
from enum import Enum
2525
from typing import Any
26+
from typing import AsyncGenerator
27+
from typing import Generator
2628
from typing import Literal
2729
from typing import Optional
2830

@@ -840,3 +842,81 @@ class CreateUserRequest(BaseModel):
840842
# When passing a BaseModel, there is no function return, so response schema
841843
# is None
842844
self.assertIsNone(decl.response_json_schema)
845+
846+
847+
class TestStreamingReturnTypes(parameterized.TestCase):
848+
"""Tests for AsyncGenerator and Generator return types (streaming tools)."""
849+
850+
def test_async_generator_string_yield(self):
851+
"""Test AsyncGenerator[str, None] return type extracts str as response."""
852+
853+
async def streaming_tool(param: str) -> AsyncGenerator[str, None]:
854+
"""A streaming tool that yields strings."""
855+
yield param
856+
857+
decl = build_function_declaration_with_json_schema(streaming_tool)
858+
859+
self.assertEqual(decl.name, "streaming_tool")
860+
self.assertIsNotNone(decl.parameters_json_schema)
861+
self.assertEqual(
862+
decl.parameters_json_schema["properties"]["param"]["type"], "string"
863+
)
864+
# Should extract str from AsyncGenerator[str, None]
865+
self.assertEqual(decl.response_json_schema, {"type": "string"})
866+
867+
def test_async_generator_int_yield(self):
868+
"""Test AsyncGenerator[int, None] return type extracts int as response."""
869+
870+
async def counter(start: int) -> AsyncGenerator[int, None]:
871+
"""A streaming counter."""
872+
yield start
873+
874+
decl = build_function_declaration_with_json_schema(counter)
875+
876+
self.assertEqual(decl.name, "counter")
877+
# Should extract int from AsyncGenerator[int, None]
878+
self.assertEqual(decl.response_json_schema, {"type": "integer"})
879+
880+
def test_async_generator_dict_yield(self):
881+
"""Test AsyncGenerator[dict[str, str], None] return type."""
882+
883+
async def streaming_dict(
884+
param: str,
885+
) -> AsyncGenerator[dict[str, str], None]:
886+
"""A streaming tool that yields dicts."""
887+
yield {"result": param}
888+
889+
decl = build_function_declaration_with_json_schema(streaming_dict)
890+
891+
self.assertEqual(decl.name, "streaming_dict")
892+
# Should extract dict[str, str] from AsyncGenerator
893+
self.assertEqual(
894+
decl.response_json_schema,
895+
{"additionalProperties": {"type": "string"}, "type": "object"},
896+
)
897+
898+
def test_generator_string_yield(self):
899+
"""Test Generator[str, None, None] return type extracts str as response."""
900+
901+
def sync_streaming_tool(param: str) -> Generator[str, None, None]:
902+
"""A sync streaming tool that yields strings."""
903+
yield param
904+
905+
decl = build_function_declaration_with_json_schema(sync_streaming_tool)
906+
907+
self.assertEqual(decl.name, "sync_streaming_tool")
908+
# Should extract str from Generator[str, None, None]
909+
self.assertEqual(decl.response_json_schema, {"type": "string"})
910+
911+
def test_generator_int_yield(self):
912+
"""Test Generator[int, None, None] return type extracts int as response."""
913+
914+
def sync_counter(start: int) -> Generator[int, None, None]:
915+
"""A sync streaming counter."""
916+
yield start
917+
918+
decl = build_function_declaration_with_json_schema(sync_counter)
919+
920+
self.assertEqual(decl.name, "sync_counter")
921+
# Should extract int from Generator[int, None, None]
922+
self.assertEqual(decl.response_json_schema, {"type": "integer"})

0 commit comments

Comments
 (0)