Skip to content

Commit 2e19fac

Browse files
fix tool handling in OpenAIServingChat and add tests for Jinja tool schema behavior
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
1 parent 443b1a8 commit 2e19fac

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

python/sglang/srt/entrypoints/openai/serving_chat.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,12 @@ def _process_messages(
338338
request.skip_special_tokens = False
339339
if not isinstance(request.tool_choice, str):
340340
tools = [
341-
item.function.model_dump()
341+
item.model_dump()
342342
for item in request.tools
343343
if item.function.name == request.tool_choice.function.name
344344
]
345345
else:
346-
tools = [item.function.model_dump() for item in request.tools]
346+
tools = [item.model_dump() for item in request.tools]
347347
if self.tool_call_parser:
348348
parser = FunctionCallParser(request.tools, self.tool_call_parser)
349349
tool_call_constraint = parser.get_structure_constraint(
@@ -481,11 +481,10 @@ def _apply_jinja_template(
481481
return_dict=False,
482482
)
483483
except Exception as e:
484-
# If the first attempt fails, try transforming the tools format
485-
# This handles models like Mistral that have a different tools input format
486-
# that is not compatible with OpenAI's apply_chat_template tool_call format
484+
# If the first attempt fails, try with flat function-only format.
485+
# Some templates (e.g. Mistral) expect tools without the OpenAI wrapper.
487486
tools = (
488-
[t if "function" in t else {"function": t} for t in tools]
487+
[t["function"] if "function" in t else t for t in tools]
489488
if tools
490489
else None
491490
)

test/registered/openai_server/basic/test_serving_chat.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,84 @@ def test_convert_to_internal_request_single(self):
133133
self.assertFalse(adapted.stream)
134134
self.assertEqual(processed, self.basic_req)
135135

136+
def test_jinja_uses_openai_tool_schema_first(self):
    """Ensure Jinja chat templates receive OpenAI-shaped tools by default."""
    # Force the Jinja (tokenizer-driven) template path instead of a named
    # chat template.
    self.template_manager.chat_template_name = None
    self.template_manager.jinja_template_content_format = "string"

    # One OpenAI-style tool definition, including the {"type", "function"}
    # wrapper that the template is expected to see unchanged.
    add_tool = {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                },
                "required": ["a", "b"],
            },
        },
    }
    req = ChatCompletionRequest(
        model="x",
        messages=[{"role": "user", "content": "What is 2+2?"}],
        tools=[add_tool],
    )

    self.chat._process_messages(req, is_multimodal=False)

    # The tokenizer must have been handed the full OpenAI wrapper, not the
    # flat function-only payload.
    passed_kwargs = self.tm.tokenizer.apply_chat_template.call_args.kwargs
    self.assertEqual(
        passed_kwargs["tools"], [tool.model_dump() for tool in req.tools]
    )
def test_jinja_tool_schema_fallback_to_flat_function(self):
    """Fallback to function-only schema when template rejects OpenAI wrapper."""
    # Force the Jinja (tokenizer-driven) template path instead of a named
    # chat template.
    self.template_manager.chat_template_name = None
    self.template_manager.jinja_template_content_format = "string"

    add_tool = {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                },
                "required": ["a", "b"],
            },
        },
    }
    req = ChatCompletionRequest(
        model="x",
        messages=[{"role": "user", "content": "What is 2+2?"}],
        tools=[add_tool],
    )

    # First render attempt blows up (simulating a Mistral-style template
    # that cannot consume the OpenAI wrapper); the retry succeeds.
    self.tm.tokenizer.apply_chat_template.side_effect = [
        RuntimeError("template expects flat tools format"),
        [1, 2, 3],
    ]

    self.chat._process_messages(req, is_multimodal=False)

    calls = self.tm.tokenizer.apply_chat_template.call_args_list
    first_tools = calls[0].kwargs["tools"]
    second_tools = calls[1].kwargs["tools"]
    # Attempt 1: full OpenAI schema. Attempt 2: flat function-only schema.
    self.assertEqual(first_tools, [tool.model_dump() for tool in req.tools])
    self.assertEqual(
        second_tools, [tool.function.model_dump() for tool in req.tools]
    )
136214
def test_stop_str_isolation_between_requests(self):
137215
"""Test that stop strings from one request don't affect subsequent requests.
138216

0 commit comments

Comments
 (0)