Skip to content

Support dspy.Tool as input field type and dspy.ToolCall as output field type #8242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 2, 2025
Merged
2 changes: 1 addition & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType # isort: skip
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
Expand Down
4 changes: 3 additions & 1 deletion dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.two_step_adapter import TwoStepAdapter
from dspy.adapters.types import History, Image, Audio, BaseType
from dspy.adapters.types import History, Image, Audio, BaseType, Tool, ToolCalls

__all__ = [
"Adapter",
Expand All @@ -13,4 +13,6 @@
"Audio",
"JSONAdapter",
"TwoStepAdapter",
"Tool",
"ToolCalls",
]
109 changes: 97 additions & 12 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type
import logging
from typing import TYPE_CHECKING, Any, Optional, Type, get_origin

import json_repair
import litellm

from dspy.adapters.types import History
from dspy.adapters.types.base_type import split_message_content_for_custom_types
from dspy.adapters.types.tool import Tool, ToolCalls
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback, with_callbacks

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from dspy.clients.lm import LM

Expand All @@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
def _call_preprocess(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar question for the preprocess step: could this instead be wrapped in Tool's `format` and make use of `split_message_content_for_custom_types`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for tool calling. Here we need to handle the special case of native function calling (https://platform.openai.com/docs/guides/function-calling?api-mode=chat), so we need to modify the LM call args in addition to the messages.

self,
lm: "LM",
lm_kwargs: dict[str, Any],
signature: Type[Signature],
inputs: dict[str, Any],
use_native_function_calling: bool = False,
) -> dict[str, Any]:
if use_native_function_calling:
tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)

if tool_call_output_field_name and tool_call_input_field_name is None:
raise ValueError(
f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, "
"but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
"input field with type `list[dspy.Tool]`."
)

if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model):
tools = inputs[tool_call_input_field_name]
tools = tools if isinstance(tools, list) else [tools]

litellm_tools = []
for tool in tools:
litellm_tools.append(tool.format_as_litellm_function_call())

lm_kwargs["tools"] = litellm_tools

signature_for_native_function_calling = signature.delete(tool_call_output_field_name)

return signature_for_native_function_calling

return signature

def _call_postprocess(
self,
signature: Type[Signature],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
values = []

tool_call_output_field_name = self._get_tool_call_output_field_name(signature)

for output in outputs:
output_logprobs = None
tool_calls = None
text = output

if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]

value = self.parse(signature, output)

if output_logprobs is not None:
text = output["text"]
output_logprobs = output.get("logprobs")
tool_calls = output.get("tool_calls")

if text:
value = self.parse(signature, text)
else:
value = {}
for field_name in signature.output_fields.keys():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenmoneygithub quick question: do we need to handle this tool-call-specific logic here? Since it inherits from BaseType, I'm wondering whether we can add a parse function to Tool (and to the BaseType interface) and port all the logic over there (similar to the format we use for custom types). This could keep post-processing generalizable, and we'd just add a check for whether the signature includes a BaseType output field and parse accordingly. Curious about your thoughts here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output handling is a bit different from input handling, which we can possibly generalize.

However, I would be cautious about doing it because we don't know yet whether generalization makes sense here. For the ToolCalls case, when native function calling is used we read from the LM response and write back to the output field that is of type dspy.ToolCalls, which may be completely different from the second output field we introduce.

value[field_name] = None

if tool_calls and tool_call_output_field_name:
tool_calls = [
{
"name": v["function"]["name"],
"args": json_repair.loads(v["function"]["arguments"]),
}
for v in tool_calls
]
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)

if output_logprobs:
value["logprobs"] = output_logprobs

values.append(value)
Expand All @@ -46,10 +113,11 @@ def __call__(
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
inputs = self.format(signature, demos, inputs)
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
inputs = self.format(processed_signature, demos, inputs)

outputs = lm(messages=inputs, **lm_kwargs)
return self._call_post_process(outputs, signature)
return self._call_postprocess(signature, outputs)

async def acall(
self,
Expand All @@ -59,10 +127,11 @@ async def acall(
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
inputs = self.format(signature, demos, inputs)
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
inputs = self.format(processed_signature, demos, inputs)

outputs = await lm.acall(messages=inputs, **lm_kwargs)
return self._call_post_process(outputs, signature)
return self._call_postprocess(signature, outputs)

def format(
self,
Expand Down Expand Up @@ -297,6 +366,22 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool:
return name
return None

def _get_tool_call_input_field_name(self, signature: Type[Signature]) -> Optional[str]:
    """Return the name of the first input field annotated as a tool or a list of tools.

    Args:
        signature: The signature whose `input_fields` are scanned in order.

    Returns:
        The name of the first input field whose annotation is `Tool` or `list[Tool]`, or `None` when no
        input field carries such an annotation.
    """
    for name, field in signature.input_fields.items():
        annotation = field.annotation
        if annotation == Tool:
            return name
        # An unparameterized `typing.List` reports `list` as its origin but may expose no `__args__`, so
        # read the type arguments defensively instead of indexing `__args__` directly.
        type_args = getattr(annotation, "__args__", ())
        if get_origin(annotation) is list and type_args and type_args[0] == Tool:
            return name
    return None

def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool:
for name, field in signature.output_fields.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when there are multiple fields with type ToolCalls?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one field will be populated with a value. We can raise a warning when there are multiple ToolCalls fields, though I doubt users will do that. We named it ToolCalls to indicate that it holds multiple tool calls.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same question for multimodal. In a single call (different from using `n`), I don't believe the models can produce multiple multimodal outputs with 2+ dspy.OutputFields (it will raise the output exception we sometimes see).

It could be different for ToolCalls, though, but it might be safer to give that warning globally (maybe for selected ones in types?).

if field.annotation == ToolCalls:
return name
return None

def format_conversation_history(
self,
signature: Type[Signature],
Expand Down Expand Up @@ -352,4 +437,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
Returns:
A dictionary of the output fields.
"""
raise NotImplementedError
raise NotImplementedError
10 changes: 10 additions & 0 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def __call__(
f"`response_format` argument. Original error: {e}"
) from e

def _call_preprocess(
    self,
    lm: "LM",
    lm_kwargs: dict[str, Any],
    signature: Type[Signature],
    inputs: dict[str, Any],
    use_native_function_calling: bool = True,
) -> dict[str, Any]:
    """Delegate pre-call processing to the parent adapter, enabling native function calling by default.

    The only difference from calling the parent method directly is the default value of
    `use_native_function_calling`, which is `True` here.

    Args:
        lm: The language model that is about to be called.
        lm_kwargs: Keyword arguments for the LM call, forwarded unchanged to the parent.
        signature: The signature of the current call.
        inputs: The input values of the current call.
        use_native_function_calling: Forwarded to the parent implementation.

    Returns:
        Whatever the parent implementation returns.
    """
    return super()._call_preprocess(
        lm=lm,
        lm_kwargs=lm_kwargs,
        signature=signature,
        inputs=inputs,
        use_native_function_calling=use_native_function_calling,
    )

def format_field_structure(self, signature: Type[Signature]) -> str:
parts = []
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
Expand Down
22 changes: 20 additions & 2 deletions dspy/adapters/two_step_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Optional, Type

import json_repair

from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.types import ToolCalls
from dspy.adapters.utils import get_field_description_string
from dspy.clients import LM
from dspy.signatures.field import InputField
Expand Down Expand Up @@ -115,11 +118,16 @@ async def acall(

values = []

tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
for output in outputs:
output_logprobs = None
tool_calls = None
text = output

if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]
text = output["text"]
output_logprobs = output.get("logprobs")
tool_calls = output.get("tool_calls")

try:
# Call the smaller LM to extract structured data from the raw completion text with ChatAdapter
Expand All @@ -128,13 +136,23 @@ async def acall(
lm_kwargs={},
signature=extractor_signature,
demos=[],
inputs={"text": output},
inputs={"text": text},
)
value = value[0]

except Exception as e:
raise ValueError(f"Failed to parse response from the original completion: {output}") from e

if tool_calls and tool_call_output_field_name:
tool_calls = [
{
"name": v["function"]["name"],
"args": json_repair.loads(v["function"]["arguments"]),
}
for v in tool_calls
]
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)

if output_logprobs is not None:
value["logprobs"] = output_logprobs

Expand Down
3 changes: 2 additions & 1 deletion dspy/adapters/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from dspy.adapters.types.image import Image
from dspy.adapters.types.audio import Audio
from dspy.adapters.types.base_type import BaseType
from dspy.adapters.types.tool import Tool, ToolCalls

__all__ = ["History", "Image", "Audio", "BaseType"]
__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCalls"]
36 changes: 33 additions & 3 deletions dspy/adapters/types/base_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import re
from typing import Any
from typing import Any, Union, get_args, get_origin

import json_repair
import pydantic
Expand All @@ -26,12 +26,42 @@ def format(self) -> list[dict[str, Any]]:
```
"""

def format(self) -> list[dict[str, Any]]:
def format(self) -> Union[list[dict[str, Any]], str]:
raise NotImplementedError

@classmethod
def description(cls) -> str:
    """Return a description of the custom type.

    This implementation returns an empty string.
    """
    return ""

@classmethod
def extract_custom_type_from_annotation(cls, annotation):
    """Extract all custom types from the annotation.

    This is used to extract all custom types from the annotation of a field, while the annotation can
    have arbitrary level of nesting. For example, we detect `Tool` is in `list[dict[str, Tool]]`.

    Args:
        annotation: The type annotation to inspect.

    Returns:
        A list with every class found in `annotation` that is `cls` or a subclass of `cls`, in the order
        the type arguments are visited. Empty when nothing matches.
    """
    # Direct match. On some Python versions (3.9 / 3.10) a parameterized generic such as
    # `list[dict[str, Tool]]` passes `isinstance(annotation, type)` but makes `issubclass` raise
    # `TypeError`. Treat that as "not a direct match" and fall through to the recursive inspection.
    try:
        if isinstance(annotation, type) and issubclass(annotation, cls):
            return [annotation]
    except TypeError:
        pass

    # Anything that is not a parameterized generic has no type args to look into.
    if get_origin(annotation) is None:
        return []

    # Recurse into all type args
    result = []
    for arg in get_args(annotation):
        result.extend(cls.extract_custom_type_from_annotation(arg))
    return result

@pydantic.model_serializer()
def serialize_model(self):
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
formatted = self.format()
if isinstance(formatted, list):
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
return formatted


def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
Expand Down
Loading