Skip to content

Commit ef91d8d

Browse files
qbc2016DavdGao
andauthored
fix(Gemini): fix the bug that Gemini LLMs doesn't support nested JSON schema in its tools API (#1050)
--------- Co-authored-by: DavdGao <[email protected]>
1 parent a592492 commit ef91d8d

File tree

2 files changed

+160
-14
lines changed

2 files changed

+160
-14
lines changed

src/agentscope/model/_gemini_model.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
# mypy: disable-error-code="dict-item"
33
"""The Google Gemini model in agentscope."""
4+
import copy
45
import warnings
56
from datetime import datetime
67
from typing import (
@@ -30,6 +31,85 @@
3031
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
3132

3233

34+
def _flatten_json_schema(schema: dict) -> dict:
35+
"""Flatten a JSON schema by resolving all $ref references.
36+
37+
.. note::
38+
Gemini API does not support `$defs` and `$ref` in JSON schemas.
39+
This function resolves all `$ref` references by inlining the
40+
referenced definitions, producing a self-contained schema without
41+
any references.
42+
43+
Args:
44+
schema (`dict`):
45+
The JSON schema that may contain `$defs` and `$ref` references.
46+
47+
Returns:
48+
`dict`:
49+
A flattened JSON schema with all references resolved inline.
50+
"""
51+
# Deep copy to avoid modifying the original schema
52+
schema = copy.deepcopy(schema)
53+
54+
# Extract $defs if present
55+
defs = schema.pop("$defs", {})
56+
57+
def _resolve_ref(obj: Any, visited: set | None = None) -> Any:
58+
"""Recursively resolve $ref references in the schema."""
59+
if visited is None:
60+
visited = set()
61+
62+
if not isinstance(obj, dict):
63+
if isinstance(obj, list):
64+
return [_resolve_ref(item, visited.copy()) for item in obj]
65+
return obj
66+
67+
# Handle $ref
68+
if "$ref" in obj:
69+
ref_path = obj["$ref"]
70+
# Extract definition name from "#/$defs/DefinitionName"
71+
if ref_path.startswith("#/$defs/"):
72+
def_name = ref_path[len("#/$defs/") :]
73+
74+
# Prevent infinite recursion for circular references
75+
if def_name in visited:
76+
logger.warning(
77+
"Circular reference detected for '%s' in tool schema",
78+
def_name,
79+
)
80+
return {
81+
"type": "object",
82+
"description": f"(circular: {def_name})",
83+
}
84+
85+
visited.add(def_name)
86+
87+
if def_name in defs:
88+
# Recursively resolve any nested refs in the definition
89+
resolved = _resolve_ref(
90+
defs[def_name],
91+
visited.copy(),
92+
)
93+
# Merge any additional properties from the original object
94+
# (excluding $ref itself)
95+
for key, value in obj.items():
96+
if key != "$ref":
97+
resolved[key] = _resolve_ref(value, visited.copy())
98+
return resolved
99+
100+
# If we can't resolve the ref, return as-is (shouldn't happen)
101+
return obj
102+
103+
# Recursively process all nested objects
104+
result = {}
105+
for key, value in obj.items():
106+
result[key] = _resolve_ref(value, visited.copy())
107+
108+
return result
109+
110+
return _resolve_ref(schema)
111+
112+
33113
class GeminiChatModel(ChatModelBase):
34114
"""The Google Gemini chat model class in agentscope."""
35115

@@ -310,11 +390,7 @@ async def _parse_gemini_stream_generation_response(
310390
),
311391
)
312392

313-
content_block.extend(
314-
[
315-
*tool_calls,
316-
],
317-
)
393+
content_block.extend(tool_calls)
318394

319395
parsed_chunk = ChatResponse(
320396
content=content_block,
@@ -335,8 +411,8 @@ def _parse_gemini_generation_response(
335411
Args:
336412
start_datetime (`datetime`):
337413
The start datetime of the response generation.
338-
response (`ChatCompletion`):
339-
The OpenAI chat completion response object to parse.
414+
response (`GenerateContentResponse`):
415+
The Gemini generation response object to parse.
340416
structured_model (`Type[BaseModel] | None`, default `None`):
341417
A Pydantic BaseModel class that defines the expected structure
342418
for the model's output.
@@ -410,6 +486,11 @@ def _format_tools_json_schemas(
410486
) -> list[dict[str, Any]]:
411487
"""Format the tools JSON schema into required format for Gemini API.
412488
489+
.. note:: Gemini API does not support `$defs` and `$ref` in JSON
490+
schemas. This function resolves all `$ref` references by inlining the
491+
referenced definitions, producing a self-contained schema without
492+
any references.
493+
413494
Args:
414495
schemas (`dict[str, Any]`):
415496
The tools JSON schemas.
@@ -474,14 +555,19 @@ def _format_tools_json_schemas(
474555
]
475556
}
476557
]
558+
477559
"""
478-
return [
479-
{
480-
"function_declarations": [
481-
_["function"] for _ in schemas if "function" in _
482-
],
483-
},
484-
]
560+
function_declarations = []
561+
for schema in schemas:
562+
if "function" not in schema:
563+
continue
564+
func = schema["function"].copy()
565+
# Flatten the parameters schema to resolve $ref references
566+
if "parameters" in func:
567+
func["parameters"] = _flatten_json_schema(func["parameters"])
568+
function_declarations.append(func)
569+
570+
return [{"function_declarations": function_declarations}]
485571

486572
def _format_tool_choice(
487573
self,
@@ -496,6 +582,7 @@ def _format_tool_choice(
496582
Can be "auto", "none", "required", or specific tool name.
497583
For more details, please refer to
498584
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
585+
499586
Returns:
500587
`dict | None`:
501588
The formatted tool choice configuration dict, or None if

tests/model_gemini_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,65 @@ async def test_generate_kwargs_integration(self) -> None:
356356
self.assertEqual(call_args["config"]["top_p"], 0.9)
357357
self.assertEqual(call_args["config"]["top_k"], 40)
358358

359+
def test_format_tools_with_nested_schema(self) -> None:
360+
"""Test formatting tools with nested JSON schema ($defs and $ref)."""
361+
model = GeminiChatModel(
362+
model_name="gemini-2.5-flash",
363+
api_key="test_key",
364+
)
365+
366+
nested_schema = {
367+
"type": "object",
368+
"properties": {
369+
"person": {"$ref": "#/$defs/Person"},
370+
"location": {"type": "string"},
371+
},
372+
"required": ["person"],
373+
"$defs": {
374+
"Person": {
375+
"type": "object",
376+
"properties": {
377+
"name": {"type": "string"},
378+
"age": {"type": "integer"},
379+
},
380+
"required": ["name"],
381+
},
382+
},
383+
}
384+
385+
tools = [
386+
{
387+
"type": "function",
388+
"function": {
389+
"name": "process_person",
390+
"description": "Process person info",
391+
"parameters": nested_schema,
392+
},
393+
},
394+
]
395+
396+
# pylint: disable=protected-access
397+
formatted_tools = model._format_tools_json_schemas(tools)
398+
399+
# Check if $ref is resolved
400+
params = formatted_tools[0]["function_declarations"][0]["parameters"]
401+
expected_params = {
402+
"type": "object",
403+
"properties": {
404+
"person": {
405+
"type": "object",
406+
"properties": {
407+
"name": {"type": "string"},
408+
"age": {"type": "integer"},
409+
},
410+
"required": ["name"],
411+
},
412+
"location": {"type": "string"},
413+
},
414+
"required": ["person"],
415+
}
416+
self.assertEqual(params, expected_params)
417+
359418
# Auxiliary methods
360419
def _create_mock_response(
361420
self,

0 commit comments

Comments
 (0)