Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a78f60b

Browse files
committedMay 20, 2025
Support tool and toolcall in input and output field types
1 parent 25e918a commit a78f60b

File tree

7 files changed

+139
-26
lines changed

7 files changed

+139
-26
lines changed
 

‎dspy/adapters/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,4 +352,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
352352
Returns:
353353
A dictionary of the output fields.
354354
"""
355-
raise NotImplementedError
355+
raise NotImplementedError

‎dspy/adapters/types/base_type.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import re
3-
from typing import Any
3+
from typing import Any, Union, get_args, get_origin
44

55
import json_repair
66
import pydantic
@@ -26,12 +26,42 @@ def format(self) -> list[dict[str, Any]]:
2626
```
2727
"""
2828

29-
def format(self) -> list[dict[str, Any]]:
29+
def format(self) -> Union[list[dict[str, Any]], str]:
3030
raise NotImplementedError
3131

32+
@classmethod
33+
def description(cls) -> str:
34+
"""Description of the custom type"""
35+
return ""
36+
37+
@classmethod
38+
def extract_custom_type_from_annotation(cls, annotation):
39+
"""Extract all custom types from the annotation.
40+
41+
This is used to extract all custom types from the annotation of a field, while the annotation can
42+
have arbitrary level of nesting. For example, we detect `Tool` is in `list[dict[str, Tool]]`.
43+
"""
44+
# Direct match
45+
if isinstance(annotation, type) and issubclass(annotation, cls):
46+
return [annotation]
47+
48+
origin = get_origin(annotation)
49+
if origin is None:
50+
return []
51+
52+
result = []
53+
# Recurse into all type args
54+
for arg in get_args(annotation):
55+
result.extend(cls.extract_custom_type_from_annotation(arg))
56+
57+
return result
58+
3259
@pydantic.model_serializer()
3360
def serialize_model(self):
34-
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
61+
formatted = self.format()
62+
if isinstance(formatted, list):
63+
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
64+
return formatted
3565

3666

3767
def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:

‎dspy/adapters/types/tool.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ class Tool(BaseType):
1919
functions for now.
2020
"""
2121

22+
func: Callable
23+
name: Optional[str] = None
24+
desc: Optional[str] = None
25+
args: Optional[dict[str, Any]] = None
26+
arg_types: Optional[dict[str, Any]] = None
27+
arg_desc: Optional[dict[str, str]] = None
28+
has_kwargs: bool = False
29+
2230
def __init__(
2331
self,
2432
func: Callable,
@@ -56,15 +64,7 @@ def foo(x: int, y: str = "hello"):
5664
# Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}}
5765
```
5866
"""
59-
super().__init__() # Initialize the Pydantic BaseModel
60-
self.func = func
61-
self.name = name
62-
self.desc = desc
63-
self.args = args
64-
self.arg_types = arg_types
65-
self.arg_desc = arg_desc
66-
self.has_kwargs = False
67-
67+
super().__init__(func=func, name=name, desc=desc, args=args, arg_types=arg_types, arg_desc=arg_desc)
6868
self._parse_function(func, arg_desc)
6969

7070
def _parse_function(self, func: Callable, arg_desc: Optional[dict[str, str]] = None):
@@ -141,17 +141,7 @@ def _validate_and_parse_args(self, **kwargs):
141141
return parsed_kwargs
142142

143143
def format(self):
144-
return [
145-
{
146-
"type": "function",
147-
"function": {
148-
"name": self.name,
149-
"description": self.desc,
150-
"parameters": self.args,
151-
"requirements": "Arguments must be provided in JSON format.",
152-
},
153-
}
154-
]
144+
return str(self)
155145

156146
@with_callbacks
157147
def __call__(self, **kwargs):
@@ -200,6 +190,13 @@ class ToolCall(BaseType):
200190
name: str
201191
args: dict[str, Any]
202192

193+
@classmethod
194+
def description(cls) -> str:
195+
return (
196+
"Tool call information, including the name of the tool and the arguments to be passed to it. "
197+
"Arguments must be provided in JSON format."
198+
)
199+
203200

204201
def resolve_json_schema_reference(schema: dict) -> dict:
205202
"""Recursively resolve json model schema, expanding all references."""

‎dspy/adapters/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic import TypeAdapter
1111
from pydantic.fields import FieldInfo
1212

13+
from dspy.adapters.types.base_type import BaseType
1314
from dspy.signatures.utils import get_dspy_field_type
1415

1516

@@ -200,7 +201,14 @@ def get_field_description_string(fields: dict) -> str:
200201
for idx, (k, v) in enumerate(fields.items()):
201202
field_message = f"{idx + 1}. `{k}`"
202203
field_message += f" ({get_annotation_name(v.annotation)})"
203-
field_message += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
204+
desc = v.json_schema_extra["desc"] if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
205+
206+
custom_types = BaseType.extract_custom_type_from_annotation(v.annotation)
207+
for custom_type in custom_types:
208+
if len(custom_type.description()) > 0:
209+
desc += f"\n Type description of {get_annotation_name(custom_type)}: {custom_type.description()}"
210+
211+
field_message += f": {desc}"
204212
field_message += (
205213
f"\nConstraints: {v.json_schema_extra['constraints']}" if v.json_schema_extra.get("constraints") else ""
206214
)

‎tests/adapters/test_chat_adapter.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import dspy
77
import pydantic
8+
import json
89

910

1011
@pytest.mark.parametrize(
@@ -213,3 +214,40 @@ class MySignature(dspy.Signature):
213214

214215
# The query image is formatted in the last user message
215216
assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"]
217+
218+
219+
def test_chat_adapter_with_tool():
220+
class MySignature(dspy.Signature):
221+
"""Answer question with the help of the tools"""
222+
223+
question: str = dspy.InputField()
224+
tools: list[dspy.Tool] = dspy.InputField()
225+
answer: str = dspy.OutputField()
226+
tool_calls: list[dspy.ToolCall] = dspy.OutputField()
227+
228+
def get_weather(city: str) -> str:
229+
"""Get the weather for a city"""
230+
return f"The weather in {city} is sunny"
231+
232+
def get_population(country: str, year: int) -> str:
233+
"""Get the population for a country"""
234+
return f"The population of {country} in {year} is 1000000"
235+
236+
tools = [dspy.Tool(get_weather), dspy.Tool(get_population)]
237+
238+
adapter = dspy.ChatAdapter()
239+
messages = adapter.format(MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools})
240+
241+
assert len(messages) == 2
242+
243+
# The output field type description should be included in the system message even if the output field is nested
244+
assert dspy.ToolCall.description() in messages[0]["content"]
245+
246+
# The user message should include the question and the tools
247+
assert "What is the weather in Tokyo?" in messages[1]["content"]
248+
assert "get_weather" in messages[1]["content"]
249+
assert "get_population" in messages[1]["content"]
250+
251+
# Tool arguments format should be included in the user message
252+
assert "{'city': {'type': 'string'}}" in messages[1]["content"]
253+
assert "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}" in messages[1]["content"]

‎tests/adapters/test_json_adapter.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from litellm.utils import Choices, Message, ModelResponse
66

77
import dspy
8+
import json
89

910

1011
def test_json_adapter_passes_structured_output_when_supported_by_model():
@@ -337,3 +338,40 @@ class MySignature(dspy.Signature):
337338

338339
# The query image is formatted in the last user message
339340
assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"]
341+
342+
343+
def test_json_adapter_with_tool():
344+
class MySignature(dspy.Signature):
345+
"""Answer question with the help of the tools"""
346+
347+
question: str = dspy.InputField()
348+
tools: list[dspy.Tool] = dspy.InputField()
349+
answer: str = dspy.OutputField()
350+
tool_calls: list[dspy.ToolCall] = dspy.OutputField()
351+
352+
def get_weather(city: str) -> str:
353+
"""Get the weather for a city"""
354+
return f"The weather in {city} is sunny"
355+
356+
def get_population(country: str, year: int) -> str:
357+
"""Get the population for a country"""
358+
return f"The population of {country} in {year} is 1000000"
359+
360+
tools = [dspy.Tool(get_weather), dspy.Tool(get_population)]
361+
362+
adapter = dspy.JSONAdapter()
363+
messages = adapter.format(MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools})
364+
365+
assert len(messages) == 2
366+
367+
# The output field type description should be included in the system message even if the output field is nested
368+
assert dspy.ToolCall.description() in messages[0]["content"]
369+
370+
# The user message should include the question and the tools
371+
assert "What is the weather in Tokyo?" in messages[1]["content"]
372+
assert "get_weather" in messages[1]["content"]
373+
assert "get_population" in messages[1]["content"]
374+
375+
# Tool arguments format should be included in the user message
376+
assert "{'city': {'type': 'string'}}" in messages[1]["content"]
377+
assert "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}" in messages[1]["content"]

‎tests/primitives/test_tool.py renamed to ‎tests/adapters/test_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
22
from typing import Any, Optional
33

4+
import dspy
45
import pytest
56
from pydantic import BaseModel
67

7-
from dspy.primitives.tool import Tool
8+
from dspy.adapters.types.tool import Tool
9+
from unittest import mock
810

911

1012
# Test fixtures

0 commit comments

Comments
 (0)
Please sign in to comment.