From 5ce5961047ced641aa3f6cab1c11655bfd48cb50 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 19 May 2025 14:46:00 -0700 Subject: [PATCH 01/10] init --- dspy/__init__.py | 2 +- dspy/adapters/__init__.py | 4 +- dspy/adapters/types/__init__.py | 3 +- dspy/adapters/types/tool.py | 227 ++++++++++++++++++++++++++++++++ dspy/predict/react.py | 2 +- dspy/primitives/__init__.py | 2 - dspy/primitives/tool.py | 17 ++- dspy/utils/mcp.py | 2 +- 8 files changed, 251 insertions(+), 8 deletions(-) create mode 100644 dspy/adapters/types/tool.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 68f9065282..38a2e85748 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -8,7 +8,7 @@ from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCall # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.saving import load diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index 4ae8c4e262..c724d24b19 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -2,7 +2,7 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter from dspy.adapters.two_step_adapter import TwoStepAdapter -from dspy.adapters.types import History, Image, Audio, BaseType +from dspy.adapters.types import History, Image, Audio, BaseType, Tool, ToolCall __all__ = [ "Adapter", @@ -13,4 +13,6 @@ "Audio", "JSONAdapter", "TwoStepAdapter", + "Tool", + "ToolCall", ] diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py index c690a85537..ac42265e87 100644 --- a/dspy/adapters/types/__init__.py +++ b/dspy/adapters/types/__init__.py @@ -2,5 +2,6 @@ from dspy.adapters.types.image import Image from dspy.adapters.types.audio import Audio from dspy.adapters.types.base_type import BaseType +from dspy.adapters.types.tool import Tool, ToolCall -__all__ = ["History", "Image", "Audio", "BaseType"] +__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCall"] diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py new file mode 100644 index 0000000000..7132dc3384 --- /dev/null +++ b/dspy/adapters/types/tool.py @@ -0,0 +1,227 @@ +import asyncio +import inspect +from typing import TYPE_CHECKING, Any, Callable, Optional, get_origin, get_type_hints + +from jsonschema import ValidationError, validate +from pydantic import BaseModel, TypeAdapter, create_model + +from dspy.adapters.types.base_type import BaseType +from dspy.utils.callback import with_callbacks + +if TYPE_CHECKING: + import mcp + + +class Tool(BaseType): + """Tool class. + + This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports + functions for now. + """ + + def __init__( + self, + func: Callable, + name: Optional[str] = None, + desc: Optional[str] = None, + args: Optional[dict[str, Any]] = None, + arg_types: Optional[dict[str, Any]] = None, + arg_desc: Optional[dict[str, str]] = None, + ): + """Initialize the Tool class. + + Users can choose to specify the `name`, `desc`, `args`, and `arg_types`, or let the `dspy.Tool` + automatically infer the values from the function. For values that are specified by the user, automatic inference + will not be performed on them. + + Args: + func (Callable): The actual function that is being wrapped by the tool. + name (Optional[str], optional): The name of the tool. Defaults to None. + desc (Optional[str], optional): The description of the tool. Defaults to None. + args (Optional[dict[str, Any]], optional): The args and their schema of the tool, represented as a + dictionary from arg name to arg's json schema. Defaults to None. + arg_types (Optional[dict[str, Any]], optional): The argument types of the tool, represented as a dictionary + from arg name to the type of the argument. Defaults to None. + arg_desc (Optional[dict[str, str]], optional): Descriptions for each arg, represented as a + dictionary from arg name to description string. Defaults to None. + + Example: + + ```python + def foo(x: int, y: str = "hello"): + return str(x) + y + + tool = Tool(foo) + print(tool.args) + # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} + ``` + """ + super().__init__() # Initialize the Pydantic BaseModel + self.func = func + self.name = name + self.desc = desc + self.args = args + self.arg_types = arg_types + self.arg_desc = arg_desc + self.has_kwargs = False + + self._parse_function(func, arg_desc) + + def _parse_function(self, func: Callable, arg_desc: Optional[dict[str, str]] = None): + """Helper method that parses a function to extract the name, description, and args. + + This is a helper function that automatically infers the name, description, and args of the tool from the + provided function. In order to make the inference work, the function must have valid type hints. + """ + annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__ + name = getattr(func, "__name__", type(func).__name__) + desc = getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "") + args = {} + arg_types = {} + + # Use inspect.signature to get all arg names + sig = inspect.signature(annotations_func) + # Get available type hints + available_hints = get_type_hints(annotations_func) + # Build a dictionary of arg name -> type (defaulting to Any when missing) + hints = {param_name: available_hints.get(param_name, Any) for param_name in sig.parameters.keys()} + default_values = {param_name: sig.parameters[param_name].default for param_name in sig.parameters.keys()} + + # Process each argument's type to generate its JSON schema. + for k, v in hints.items(): + arg_types[k] = v + if k == "return": + continue + # Check if the type (or its origin) is a subclass of Pydantic's BaseModel + origin = get_origin(v) or v + if isinstance(origin, type) and issubclass(origin, BaseModel): + # Get json schema, and replace $ref with the actual schema + v_json_schema = resolve_json_schema_reference(v.model_json_schema()) + args[k] = v_json_schema + else: + args[k] = TypeAdapter(v).json_schema() + if default_values[k] is not inspect.Parameter.empty: + args[k]["default"] = default_values[k] + if arg_desc and k in arg_desc: + args[k]["description"] = arg_desc[k] + + self.name = self.name or name + self.desc = self.desc or desc + self.args = self.args or args + self.arg_types = self.arg_types or arg_types + self.has_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()) + + def _validate_and_parse_args(self, **kwargs): + # Validate the args value comply to the json schema. + for k, v in kwargs.items(): + if k not in self.args: + if self.has_kwargs: + continue + else: + raise ValueError(f"Arg {k} is not in the tool's args.") + try: + instance = v.model_dump() if hasattr(v, "model_dump") else v + type_str = self.args[k].get("type") + if type_str is not None and type_str != "Any": + validate(instance=instance, schema=self.args[k]) + except ValidationError as e: + raise ValueError(f"Arg {k} is invalid: {e.message}") + + # Parse the args to the correct type. + parsed_kwargs = {} + for k, v in kwargs.items(): + if k in self.arg_types and self.arg_types[k] != Any: + # Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type. + # This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]` + pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...)) + parsed = pydantic_wrapper.model_validate({"value": v}) + parsed_kwargs[k] = parsed.value + else: + parsed_kwargs[k] = v + return parsed_kwargs + + def format(self): + return [ + { + "type": "function", + "function": { + "name": self.name, + "description": self.desc, + "parameters": self.args, + "requirements": "Arguments must be provided in JSON format.", + }, + } + ] + + @with_callbacks + def __call__(self, **kwargs): + parsed_kwargs = self._validate_and_parse_args(**kwargs) + result = self.func(**parsed_kwargs) + if asyncio.iscoroutine(result): + raise ValueError("You are calling `__call__` on an async tool, please use `acall` instead.") + return result + + @with_callbacks + async def acall(self, **kwargs): + parsed_kwargs = self._validate_and_parse_args(**kwargs) + result = self.func(**parsed_kwargs) + if asyncio.iscoroutine(result): + return await result + else: + # We should allow calling a sync tool in the async path. + return result + + @classmethod + def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.types.Tool") -> "Tool": + """ + Build a DSPy tool from an MCP tool and a ClientSession. + + Args: + session: The MCP session to use. + tool: The MCP tool to convert. + + Returns: + A Tool object. + """ + from dspy.utils.mcp import convert_mcp_tool + + return convert_mcp_tool(session, tool) + + def __repr__(self): + return f"Tool(name={self.name}, desc={self.desc}, args={self.args})" + + def __str__(self): + desc = f", whose description is {self.desc}.".replace("\n", " ") if self.desc else "." + arg_desc = f"It takes arguments {self.args} in JSON format." + return f"{self.name}{desc} {arg_desc}" + + +class ToolCall(BaseType): + name: str + args: dict[str, Any] + + +def resolve_json_schema_reference(schema: dict) -> dict: + """Recursively resolve json model schema, expanding all references.""" + + # If there are no definitions to resolve, return the main schema + if "$defs" not in schema and "definitions" not in schema: + return schema + + def resolve_refs(obj: Any) -> Any: + if not isinstance(obj, (dict, list)): + return obj + if isinstance(obj, dict): + if "$ref" in obj: + ref_path = obj["$ref"].split("/")[-1] + return resolve_refs(schema["$defs"][ref_path]) + return {k: resolve_refs(v) for k, v in obj.items()} + + # Must be a list + return [resolve_refs(item) for item in obj] + + # Resolve all references in the main schema + resolved_schema = resolve_refs(schema) + # Remove the $defs key as it's no longer needed + resolved_schema.pop("$defs", None) + return resolved_schema diff --git a/dspy/predict/react.py b/dspy/predict/react.py index f6098672a9..bc7b3e23b6 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -4,8 +4,8 @@ from litellm import ContextWindowExceededError import dspy +from dspy.adapters.types.tool import Tool from dspy.primitives.program import Module -from dspy.primitives.tool import Tool from dspy.signatures.signature import ensure_signature logger = logging.getLogger(__name__) diff --git a/dspy/primitives/__init__.py b/dspy/primitives/__init__.py index 0fa160d260..00d3f47de4 100644 --- a/dspy/primitives/__init__.py +++ b/dspy/primitives/__init__.py @@ -4,7 +4,6 @@ from dspy.primitives.prediction import Completions, Prediction from dspy.primitives.program import Module, Program from dspy.primitives.python_interpreter import PythonInterpreter -from dspy.primitives.tool import Tool __all__ = [ "assertions", @@ -15,5 +14,4 @@ "Program", "Module", "PythonInterpreter", - "Tool", ] diff --git a/dspy/primitives/tool.py b/dspy/primitives/tool.py index bf4a208f5a..44e3744122 100644 --- a/dspy/primitives/tool.py +++ b/dspy/primitives/tool.py @@ -5,13 +5,14 @@ from jsonschema import ValidationError, validate from pydantic import BaseModel, TypeAdapter, create_model +from dspy.adapters.types.base_type import BaseType from dspy.utils.callback import with_callbacks if TYPE_CHECKING: import mcp -class Tool: +class Tool(BaseType): """Tool class. This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports @@ -55,6 +56,7 @@ def foo(x: int, y: str = "hello"): # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} ``` """ + super().__init__() # Initialize the Pydantic BaseModel self.func = func self.name = name self.desc = desc @@ -138,6 +140,19 @@ def _validate_and_parse_args(self, **kwargs): parsed_kwargs[k] = v return parsed_kwargs + def format(self, **kwargs): + return [ + { + "type": "function", + "function": { + "name": self.name, + "description": self.desc, + "parameters": self.args, + "requirements": "Arguments must be provided in JSON format.", + }, + } + ] + @with_callbacks def __call__(self, **kwargs): parsed_kwargs = self._validate_and_parse_args(**kwargs) diff --git a/dspy/utils/mcp.py b/dspy/utils/mcp.py index 54b9cca088..0b42bc1e88 100644 --- a/dspy/utils/mcp.py +++ b/dspy/utils/mcp.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Tuple, Type, Union -from dspy.primitives.tool import Tool, resolve_json_schema_reference +from dspy.adapters.types.tool import Tool, resolve_json_schema_reference if TYPE_CHECKING: import mcp From 3e2b99acfbbfc93c7dad7d8720e5fac58dfe5137 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 19 May 2025 21:48:19 -0700 Subject: [PATCH 02/10] Support tool and toolcall in input and output field types --- dspy/adapters/base.py | 2 +- dspy/adapters/types/base_type.py | 36 +++++++++++++++++-- dspy/adapters/types/tool.py | 37 +++++++++----------- dspy/adapters/utils.py | 10 +++++- tests/adapters/test_chat_adapter.py | 38 +++++++++++++++++++++ tests/adapters/test_json_adapter.py | 38 +++++++++++++++++++++ tests/{primitives => adapters}/test_tool.py | 4 ++- 7 files changed, 139 insertions(+), 26 deletions(-) rename tests/{primitives => adapters}/test_tool.py (99%) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 9ad6d12565..aa453e7498 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -352,4 +352,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: Returns: A dictionary of the output fields. """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py index f2983ef463..b57c75cb1a 100644 --- a/dspy/adapters/types/base_type.py +++ b/dspy/adapters/types/base_type.py @@ -1,6 +1,6 @@ import json import re -from typing import Any +from typing import Any, Union, get_args, get_origin import json_repair import pydantic @@ -26,12 +26,42 @@ def format(self) -> list[dict[str, Any]]: ``` """ - def format(self) -> list[dict[str, Any]]: + def format(self) -> Union[list[dict[str, Any]], str]: raise NotImplementedError + @classmethod + def description(cls) -> str: + """Description of the custom type""" + return "" + + @classmethod + def extract_custom_type_from_annotation(cls, annotation): + """Extract all custom types from the annotation. + + This is used to extract all custom types from the annotation of a field, while the annotation can + have arbitrary level of nesting. For example, we detect `Tool` is in `list[dict[str, Tool]]`. + """ + # Direct match + if isinstance(annotation, type) and issubclass(annotation, cls): + return [annotation] + + origin = get_origin(annotation) + if origin is None: + return [] + + result = [] + # Recurse into all type args + for arg in get_args(annotation): + result.extend(cls.extract_custom_type_from_annotation(arg)) + + return result + @pydantic.model_serializer() def serialize_model(self): - return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}" + formatted = self.format() + if isinstance(formatted, list): + return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}" + return formatted def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index 7132dc3384..8e7a9b7da3 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -19,6 +19,14 @@ class Tool(BaseType): functions for now. """ + func: Callable + name: Optional[str] = None + desc: Optional[str] = None + args: Optional[dict[str, Any]] = None + arg_types: Optional[dict[str, Any]] = None + arg_desc: Optional[dict[str, str]] = None + has_kwargs: bool = False + def __init__( self, func: Callable, @@ -56,15 +64,7 @@ def foo(x: int, y: str = "hello"): # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} ``` """ - super().__init__() # Initialize the Pydantic BaseModel - self.func = func - self.name = name - self.desc = desc - self.args = args - self.arg_types = arg_types - self.arg_desc = arg_desc - self.has_kwargs = False - + super().__init__(func=func, name=name, desc=desc, args=args, arg_types=arg_types, arg_desc=arg_desc) self._parse_function(func, arg_desc) def _parse_function(self, func: Callable, arg_desc: Optional[dict[str, str]] = None): @@ -141,17 +141,7 @@ def _validate_and_parse_args(self, **kwargs): return parsed_kwargs def format(self): - return [ - { - "type": "function", - "function": { - "name": self.name, - "description": self.desc, - "parameters": self.args, - "requirements": "Arguments must be provided in JSON format.", - }, - } - ] + return str(self) @with_callbacks def __call__(self, **kwargs): @@ -200,6 +190,13 @@ class ToolCall(BaseType): name: str args: dict[str, Any] + @classmethod + def description(cls) -> str: + return ( + "Tool call information, including the name of the tool and the arguments to be passed to it. " + "Arguments must be provided in JSON format." + ) + def resolve_json_schema_reference(schema: dict) -> dict: """Recursively resolve json model schema, expanding all references.""" diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 21540bd32c..2c27fec949 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -10,6 +10,7 @@ from pydantic import TypeAdapter from pydantic.fields import FieldInfo +from dspy.adapters.types.base_type import BaseType from dspy.signatures.utils import get_dspy_field_type @@ -200,7 +201,14 @@ def get_field_description_string(fields: dict) -> str: for idx, (k, v) in enumerate(fields.items()): field_message = f"{idx + 1}. `{k}`" field_message += f" ({get_annotation_name(v.annotation)})" - field_message += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" + desc = v.json_schema_extra["desc"] if v.json_schema_extra["desc"] != f"${{{k}}}" else "" + + custom_types = BaseType.extract_custom_type_from_annotation(v.annotation) + for custom_type in custom_types: + if len(custom_type.description()) > 0: + desc += f"\n Type description of {get_annotation_name(custom_type)}: {custom_type.description()}" + + field_message += f": {desc}" field_message += ( f"\nConstraints: {v.json_schema_extra['constraints']}" if v.json_schema_extra.get("constraints") else "" ) diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py index d25000e4ef..eba77a1397 100644 --- a/tests/adapters/test_chat_adapter.py +++ b/tests/adapters/test_chat_adapter.py @@ -6,6 +6,7 @@ import dspy import pydantic +import json @pytest.mark.parametrize( @@ -314,3 +315,40 @@ class MySignature(dspy.Signature): # The query image is formatted in the last user message assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"] + + +def test_chat_adapter_with_tool(): + class MySignature(dspy.Signature): + """Answer question with the help of the tools""" + + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + answer: str = dspy.OutputField() + tool_calls: list[dspy.ToolCall] = dspy.OutputField() + + def get_weather(city: str) -> str: + """Get the weather for a city""" + return f"The weather in {city} is sunny" + + def get_population(country: str, year: int) -> str: + """Get the population for a country""" + return f"The population of {country} in {year} is 1000000" + + tools = [dspy.Tool(get_weather), dspy.Tool(get_population)] + + adapter = dspy.ChatAdapter() + messages = adapter.format(MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools}) + + assert len(messages) == 2 + + # The output field type description should be included in the system message even if the output field is nested + assert dspy.ToolCall.description() in messages[0]["content"] + + # The user message should include the question and the tools + assert "What is the weather in Tokyo?" in messages[1]["content"] + assert "get_weather" in messages[1]["content"] + assert "get_population" in messages[1]["content"] + + # Tool arguments format should be included in the user message + assert "{'city': {'type': 'string'}}" in messages[1]["content"] + assert "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}" in messages[1]["content"] diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index 6cd185203d..decbc34f0b 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -5,6 +5,7 @@ from litellm.utils import Choices, Message, ModelResponse import dspy +import json def test_json_adapter_passes_structured_output_when_supported_by_model(): @@ -337,3 +338,40 @@ class MySignature(dspy.Signature): # The query image is formatted in the last user message assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"] + + +def test_json_adapter_with_tool(): + class MySignature(dspy.Signature): + """Answer question with the help of the tools""" + + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + answer: str = dspy.OutputField() + tool_calls: list[dspy.ToolCall] = dspy.OutputField() + + def get_weather(city: str) -> str: + """Get the weather for a city""" + return f"The weather in {city} is sunny" + + def get_population(country: str, year: int) -> str: + """Get the population for a country""" + return f"The population of {country} in {year} is 1000000" + + tools = [dspy.Tool(get_weather), dspy.Tool(get_population)] + + adapter = dspy.JSONAdapter() + messages = adapter.format(MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools}) + + assert len(messages) == 2 + + # The output field type description should be included in the system message even if the output field is nested + assert dspy.ToolCall.description() in messages[0]["content"] + + # The user message should include the question and the tools + assert "What is the weather in Tokyo?" in messages[1]["content"] + assert "get_weather" in messages[1]["content"] + assert "get_population" in messages[1]["content"] + + # Tool arguments format should be included in the user message + assert "{'city': {'type': 'string'}}" in messages[1]["content"] + assert "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}" in messages[1]["content"] diff --git a/tests/primitives/test_tool.py b/tests/adapters/test_tool.py similarity index 99% rename from tests/primitives/test_tool.py rename to tests/adapters/test_tool.py index 06f5b8d582..da8bb93c20 100644 --- a/tests/primitives/test_tool.py +++ b/tests/adapters/test_tool.py @@ -1,10 +1,12 @@ import asyncio from typing import Any, Optional +import dspy import pytest from pydantic import BaseModel -from dspy.primitives.tool import Tool +from dspy.adapters.types.tool import Tool +from unittest import mock # Test fixtures From b171e9f40969e14766fa7573e6a8f15cd7fee6ae Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 19 May 2025 22:00:25 -0700 Subject: [PATCH 03/10] move file --- dspy/primitives/tool.py | 222 ---------------------------------------- 1 file changed, 222 deletions(-) delete mode 100644 dspy/primitives/tool.py diff --git a/dspy/primitives/tool.py b/dspy/primitives/tool.py deleted file mode 100644 index 44e3744122..0000000000 --- a/dspy/primitives/tool.py +++ /dev/null @@ -1,222 +0,0 @@ -import asyncio -import inspect -from typing import TYPE_CHECKING, Any, Callable, Optional, get_origin, get_type_hints - -from jsonschema import ValidationError, validate -from pydantic import BaseModel, TypeAdapter, create_model - -from dspy.adapters.types.base_type import BaseType -from dspy.utils.callback import with_callbacks - -if TYPE_CHECKING: - import mcp - - -class Tool(BaseType): - """Tool class. - - This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports - functions for now. - """ - - def __init__( - self, - func: Callable, - name: Optional[str] = None, - desc: Optional[str] = None, - args: Optional[dict[str, Any]] = None, - arg_types: Optional[dict[str, Any]] = None, - arg_desc: Optional[dict[str, str]] = None, - ): - """Initialize the Tool class. - - Users can choose to specify the `name`, `desc`, `args`, and `arg_types`, or let the `dspy.Tool` - automatically infer the values from the function. For values that are specified by the user, automatic inference - will not be performed on them. - - Args: - func (Callable): The actual function that is being wrapped by the tool. - name (Optional[str], optional): The name of the tool. Defaults to None. - desc (Optional[str], optional): The description of the tool. Defaults to None. - args (Optional[dict[str, Any]], optional): The args and their schema of the tool, represented as a - dictionary from arg name to arg's json schema. Defaults to None. - arg_types (Optional[dict[str, Any]], optional): The argument types of the tool, represented as a dictionary - from arg name to the type of the argument. Defaults to None. - arg_desc (Optional[dict[str, str]], optional): Descriptions for each arg, represented as a - dictionary from arg name to description string. Defaults to None. - - Example: - - ```python - def foo(x: int, y: str = "hello"): - return str(x) + y - - tool = Tool(foo) - print(tool.args) - # Expected output: {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} - ``` - """ - super().__init__() # Initialize the Pydantic BaseModel - self.func = func - self.name = name - self.desc = desc - self.args = args - self.arg_types = arg_types - self.arg_desc = arg_desc - self.has_kwargs = False - - self._parse_function(func, arg_desc) - - def _parse_function(self, func: Callable, arg_desc: Optional[dict[str, str]] = None): - """Helper method that parses a function to extract the name, description, and args. - - This is a helper function that automatically infers the name, description, and args of the tool from the - provided function. In order to make the inference work, the function must have valid type hints. - """ - annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__ - name = getattr(func, "__name__", type(func).__name__) - desc = getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "") - args = {} - arg_types = {} - - # Use inspect.signature to get all arg names - sig = inspect.signature(annotations_func) - # Get available type hints - available_hints = get_type_hints(annotations_func) - # Build a dictionary of arg name -> type (defaulting to Any when missing) - hints = {param_name: available_hints.get(param_name, Any) for param_name in sig.parameters.keys()} - default_values = {param_name: sig.parameters[param_name].default for param_name in sig.parameters.keys()} - - # Process each argument's type to generate its JSON schema. - for k, v in hints.items(): - arg_types[k] = v - if k == "return": - continue - # Check if the type (or its origin) is a subclass of Pydantic's BaseModel - origin = get_origin(v) or v - if isinstance(origin, type) and issubclass(origin, BaseModel): - # Get json schema, and replace $ref with the actual schema - v_json_schema = resolve_json_schema_reference(v.model_json_schema()) - args[k] = v_json_schema - else: - args[k] = TypeAdapter(v).json_schema() - if default_values[k] is not inspect.Parameter.empty: - args[k]["default"] = default_values[k] - if arg_desc and k in arg_desc: - args[k]["description"] = arg_desc[k] - - self.name = self.name or name - self.desc = self.desc or desc - self.args = self.args or args - self.arg_types = self.arg_types or arg_types - self.has_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()) - - def _validate_and_parse_args(self, **kwargs): - # Validate the args value comply to the json schema. - for k, v in kwargs.items(): - if k not in self.args: - if self.has_kwargs: - continue - else: - raise ValueError(f"Arg {k} is not in the tool's args.") - try: - instance = v.model_dump() if hasattr(v, "model_dump") else v - type_str = self.args[k].get("type") - if type_str is not None and type_str != "Any": - validate(instance=instance, schema=self.args[k]) - except ValidationError as e: - raise ValueError(f"Arg {k} is invalid: {e.message}") - - # Parse the args to the correct type. - parsed_kwargs = {} - for k, v in kwargs.items(): - if k in self.arg_types and self.arg_types[k] != Any: - # Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type. - # This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]` - pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...)) - parsed = pydantic_wrapper.model_validate({"value": v}) - parsed_kwargs[k] = parsed.value - else: - parsed_kwargs[k] = v - return parsed_kwargs - - def format(self, **kwargs): - return [ - { - "type": "function", - "function": { - "name": self.name, - "description": self.desc, - "parameters": self.args, - "requirements": "Arguments must be provided in JSON format.", - }, - } - ] - - @with_callbacks - def __call__(self, **kwargs): - parsed_kwargs = self._validate_and_parse_args(**kwargs) - result = self.func(**parsed_kwargs) - if asyncio.iscoroutine(result): - raise ValueError("You are calling `__call__` on an async tool, please use `acall` instead.") - return result - - @with_callbacks - async def acall(self, **kwargs): - parsed_kwargs = self._validate_and_parse_args(**kwargs) - result = self.func(**parsed_kwargs) - if asyncio.iscoroutine(result): - return await result - else: - # We should allow calling a sync tool in the async path. - return result - - @classmethod - def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.types.Tool") -> "Tool": - """ - Build a DSPy tool from an MCP tool and a ClientSession. - - Args: - session: The MCP session to use. - tool: The MCP tool to convert. - - Returns: - A Tool object. - """ - from dspy.utils.mcp import convert_mcp_tool - - return convert_mcp_tool(session, tool) - - def __repr__(self): - return f"Tool(name={self.name}, desc={self.desc}, args={self.args})" - - def __str__(self): - desc = f", whose description is {self.desc}.".replace("\n", " ") if self.desc else "." - arg_desc = f"It takes arguments {self.args} in JSON format." - return f"{self.name}{desc} {arg_desc}" - - -def resolve_json_schema_reference(schema: dict) -> dict: - """Recursively resolve json model schema, expanding all references.""" - - # If there are no definitions to resolve, return the main schema - if "$defs" not in schema and "definitions" not in schema: - return schema - - def resolve_refs(obj: Any) -> Any: - if not isinstance(obj, (dict, list)): - return obj - if isinstance(obj, dict): - if "$ref" in obj: - ref_path = obj["$ref"].split("/")[-1] - return resolve_refs(schema["$defs"][ref_path]) - return {k: resolve_refs(v) for k, v in obj.items()} - - # Must be a list - return [resolve_refs(item) for item in obj] - - # Resolve all references in the main schema - resolved_schema = resolve_refs(schema) - # Remove the $defs key as it's no longer needed - resolved_schema.pop("$defs", None) - return resolved_schema From a2237ee0f22f9c2e640db8fc8e5d5033377cc263 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 22 May 2025 12:40:35 -0700 Subject: [PATCH 04/10] increment --- dspy/adapters/base.py | 34 ++++++++++++++++++++++++++++------ dspy/adapters/json_adapter.py | 3 +++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index aa453e7498..ece16ac6b6 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,7 +1,8 @@ -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import TYPE_CHECKING, Any, Optional, Type, get_origin from dspy.adapters.types import History from dspy.adapters.types.base_type import split_message_content_for_custom_types +from dspy.adapters.types.tool import Tool from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks @@ -20,7 +21,16 @@ def __init_subclass__(cls, **kwargs) -> None: cls.format = with_callbacks(cls.format) cls.parse = with_callbacks(cls.parse) - def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]: + def _call_preprocess( + self, + lm: "LM", + lm_kwargs: dict[str, Any], + signature: Type[Signature], + inputs: dict[str, Any], + ) -> dict[str, Any]: + tool_call_field_name = self._get_tool_call_field_name(signature) + + def _call_postprocess(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]: values = [] for output in outputs: @@ -46,10 +56,11 @@ def __call__( demos: list[dict[str, Any]], inputs: dict[str, Any], ) -> list[dict[str, Any]]: - inputs = self.format(signature, demos, inputs) + processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs) + inputs = self.format(processed_signature, demos, inputs) outputs = lm(messages=inputs, **lm_kwargs) - return self._call_post_process(outputs, signature) + return self._call_postprocess(outputs, processed_signature) async def acall( self, @@ -59,10 +70,11 @@ async def acall( demos: list[dict[str, Any]], inputs: dict[str, Any], ) -> list[dict[str, Any]]: - inputs = self.format(signature, demos, inputs) + processed_signature = self._call_preprocess(inputs, signature) + inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_post_process(outputs, signature) + return self._call_postprocess(outputs, processed_signature) def format( self, @@ -297,6 +309,16 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool: return name return None + def _get_tool_call_field_name(self, signature: Type[Signature]) -> bool: + for name, field in signature.input_fields.items(): + # Look for annotation `list[dspy.Tool]` or `dspy.Tool` + origin = get_origin(field.annotation) + if origin is list and field.annotation.__args__[0] == Tool: + return name + if field.annotation == Tool: + return name + return None + def format_conversation_history( self, signature: Type[Signature], diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 018bd5f415..c2565ed5b5 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -57,6 +57,9 @@ def __call__( lm_kwargs["response_format"] = {"type": "json_object"} return super().__call__(lm, lm_kwargs, signature, demos, inputs) + if litellm.supports_function_calling(model=lm.model): + lm_kwargs["tools"] = self._get_tool_calling_value(signature, inputs) + # Try structured output first, fall back to basic JSON if it fails. try: structured_output_model = _get_structured_outputs_response_format(signature) From e585ac0297f035a917721ad01ddc2b7a6703868a Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 22 May 2025 16:26:21 -0700 Subject: [PATCH 05/10] better ways --- dspy/__init__.py | 2 +- dspy/adapters/__init__.py | 4 +- dspy/adapters/base.py | 86 ++++++++++++++++++++++++----- dspy/adapters/json_adapter.py | 13 ++++- dspy/adapters/types/__init__.py | 4 +- dspy/adapters/types/tool.py | 48 ++++++++++++++-- dspy/clients/base_lm.py | 30 ++++++---- dspy/utils/dummies.py | 1 + tests/adapters/test_chat_adapter.py | 4 +- tests/adapters/test_json_adapter.py | 49 +++++++++++++++- 10 files changed, 199 insertions(+), 42 deletions(-) diff --git a/dspy/__init__.py b/dspy/__init__.py index 38a2e85748..b61d9d2b65 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -8,7 +8,7 @@ from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCall # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.saving import load diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index c724d24b19..15fc74a8a9 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -2,7 +2,7 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter from dspy.adapters.two_step_adapter import TwoStepAdapter -from dspy.adapters.types import History, Image, Audio, BaseType, Tool, ToolCall +from dspy.adapters.types import History, Image, Audio, BaseType, Tool, ToolCalls __all__ = [ "Adapter", @@ -14,5 +14,5 @@ "JSONAdapter", "TwoStepAdapter", "Tool", - "ToolCall", + "ToolCalls", ] diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index ece16ac6b6..611d4bfa72 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,11 +1,17 @@ +import logging from typing import TYPE_CHECKING, Any, Optional, Type, get_origin +import json_repair +import litellm + from dspy.adapters.types import History from dspy.adapters.types.base_type import split_message_content_for_custom_types -from dspy.adapters.types.tool import Tool +from dspy.adapters.types.tool import Tool, ToolCalls from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from dspy.clients.lm import LM @@ -27,21 +33,67 @@ def _call_preprocess( lm_kwargs: dict[str, Any], signature: Type[Signature], inputs: dict[str, Any], + use_native_function_calling: bool = False, ) -> dict[str, Any]: - tool_call_field_name = self._get_tool_call_field_name(signature) + if use_native_function_calling: + tool_call_input_field_name = self._get_tool_call_input_field_name(signature) + tool_call_output_field_name = self._get_tool_call_output_field_name(signature) - def _call_postprocess(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]: - values = [] + if tool_call_output_field_name and tool_call_input_field_name is None: + raise ValueError( + f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, " + "but did not provide any tools as the input. Please provide a list of tools as the input by adding an " + "input field with type `list[dspy.Tool]`." + ) - for output in outputs: - output_logprobs = None + if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model): + tools = inputs[tool_call_input_field_name] + tools = tools if isinstance(tools, list) else [tools] + + litellm_tools = [] + for tool in tools: + litellm_tools.append(tool.format_as_litellm_function_call()) + + lm_kwargs["tools"] = litellm_tools + + signature_for_native_function_calling = signature.delete(tool_call_output_field_name) + + return signature_for_native_function_calling + + return signature - if isinstance(output, dict): - output, output_logprobs = output["text"], output["logprobs"] + def _call_postprocess( + self, + signature: Type[Signature], + outputs: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + values = [] - value = self.parse(signature, output) + tool_call_output_field_name = self._get_tool_call_output_field_name(signature) - if output_logprobs is not None: + for output in outputs: + text = output["text"] + output_logprobs = output.get("logprobs") + tool_calls = output.get("tool_calls") + + if text: + value = self.parse(signature, text) + else: + value = {} + for field_name in signature.output_fields.keys(): + value[field_name] = None + + if tool_calls and tool_call_output_field_name: + tool_calls = [ + { + "name": v["function"]["name"], + "args": json_repair.loads(v["function"]["arguments"]), + } + for v in tool_calls + ] + value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) + + if output_logprobs: value["logprobs"] = output_logprobs values.append(value) @@ -60,7 +112,7 @@ def __call__( inputs = self.format(processed_signature, demos, inputs) outputs = lm(messages=inputs, **lm_kwargs) - return self._call_postprocess(outputs, processed_signature) + return self._call_postprocess(signature, outputs) async def acall( self, @@ -70,11 +122,11 @@ async def acall( demos: list[dict[str, Any]], inputs: dict[str, Any], ) -> list[dict[str, Any]]: - processed_signature = self._call_preprocess(inputs, signature) + processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs) inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_postprocess(outputs, processed_signature) + return self._call_postprocess(signature, outputs) def format( self, @@ -309,7 +361,7 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool: return name return None - def _get_tool_call_field_name(self, signature: Type[Signature]) -> bool: + def _get_tool_call_input_field_name(self, signature: Type[Signature]) -> bool: for name, field in signature.input_fields.items(): # Look for annotation `list[dspy.Tool]` or `dspy.Tool` origin = get_origin(field.annotation) @@ -319,6 +371,12 @@ def _get_tool_call_field_name(self, signature: Type[Signature]) -> bool: return name return None + def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool: + for name, field in signature.output_fields.items(): + if field.annotation == ToolCalls: + return name + return None + def format_conversation_history( self, signature: Type[Signature], diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index c2565ed5b5..e507c66e09 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -57,9 +57,6 @@ def __call__( lm_kwargs["response_format"] = {"type": "json_object"} return super().__call__(lm, lm_kwargs, signature, demos, inputs) - if litellm.supports_function_calling(model=lm.model): - lm_kwargs["tools"] = self._get_tool_calling_value(signature, inputs) - # Try structured output first, fall back to basic JSON if it fails. try: structured_output_model = _get_structured_outputs_response_format(signature) @@ -80,6 +77,16 @@ def __call__( f"`response_format` argument. Original error: {e}" ) from e + def _call_preprocess( + self, + lm: "LM", + lm_kwargs: dict[str, Any], + signature: Type[Signature], + inputs: dict[str, Any], + use_native_function_calling: bool = True, + ) -> dict[str, Any]: + return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling) + def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py index ac42265e87..11d09227c5 100644 --- a/dspy/adapters/types/__init__.py +++ b/dspy/adapters/types/__init__.py @@ -2,6 +2,6 @@ from dspy.adapters.types.image import Image from dspy.adapters.types.audio import Audio from dspy.adapters.types.base_type import BaseType -from dspy.adapters.types.tool import Tool, ToolCall +from dspy.adapters.types.tool import Tool, ToolCalls -__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCall"] +__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCalls"] diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index 8e7a9b7da3..28411cbeb6 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -143,6 +143,20 @@ def _validate_and_parse_args(self, **kwargs): def format(self): return str(self) + def format_as_litellm_function_call(self): + return { + "type": "function", + "function": { + "name": self.name, + "description": self.desc, + "parameters": { + "type": "object", + "properties": self.args, + "required": list(self.args.keys()), + }, + }, + } + @with_callbacks def __call__(self, **kwargs): parsed_kwargs = self._validate_and_parse_args(**kwargs) @@ -186,14 +200,40 @@ def __str__(self): return f"{self.name}{desc} {arg_desc}" -class ToolCall(BaseType): - name: str - args: dict[str, Any] +class ToolCalls(BaseType): + class ToolCall(BaseModel): + name: str + args: dict[str, Any] + + tool_calls: list[ToolCall] + + @classmethod + def from_dict_list(cls, dict_list: list[dict[str, Any]]) -> "ToolCalls": + """Convert a list of dictionaries to a ToolCalls instance. + + Args: + dict_list: A list of dictionaries, where each dictionary should have 'name' and 'args' keys. + + Returns: + A ToolCalls instance. + + Example: + + ```python + tool_calls_dict = [ + {"name": "search", "args": {"query": "hello"}}, + {"name": "translate", "args": {"text": "world"}} + ] + tool_calls = ToolCalls.from_dict_list(tool_calls_dict) + ``` + """ + tool_calls = [cls.ToolCall(**item) for item in dict_list] + return cls(tool_calls=tool_calls) @classmethod def description(cls) -> str: return ( - "Tool call information, including the name of the tool and the arguments to be passed to it. " + "Tool calls information, including the name of the tools and the arguments to be passed to it. " "Arguments must be provided in JSON format." ) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 3a74bdae9c..54ca558642 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -49,16 +49,16 @@ def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, c def _process_lm_response(self, response, prompt, messages, **kwargs): merged_kwargs = {**self.kwargs, **kwargs} - if merged_kwargs.get("logprobs"): - outputs = [ - { - "text": c.message.content if hasattr(c, "message") else c["text"], - "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"], - } - for c in response.choices - ] - else: - outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response.choices] + + outputs = [] + for c in response.choices: + output = {} + output["text"] = c.message.content if hasattr(c, "message") else c["text"] + if merged_kwargs.get("logprobs"): + output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] + if getattr(c.message, "tool_calls", None): + output["tool_calls"] = c.message.tool_calls + outputs.append(output) if settings.disable_history: return outputs @@ -191,8 +191,14 @@ def _inspect_history(history, n: int = 1): print(_blue(audio_str.strip())) print("\n") - print(_red("Response:")) - print(_green(outputs[0].strip())) + if outputs[0]["text"]: + print(_red("Response:")) + print(_green(outputs[0]["text"].strip())) + + if outputs[0].get("tool_calls"): + print(_red("Tool calls:")) + for tool_call in outputs[0]["tool_calls"]: + print(_green(f"{tool_call['function']['name']}: {tool_call['function']['arguments']}")) if len(outputs) > 1: choices_text = f" \t (and {len(outputs) - 1} other completions)" diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 26c06f74f9..e5d53fc348 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -119,6 +119,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): else: outputs.append(format_answer_fields(next(self.answers, {"answer": "No more responses"}))) + outputs = [{"text": output} for output in outputs] # Logging, with removed api key & where `cost` is None on cache hit. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} entry = {"prompt": prompt, "messages": messages, "kwargs": kwargs} diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py index eba77a1397..816c04a313 100644 --- a/tests/adapters/test_chat_adapter.py +++ b/tests/adapters/test_chat_adapter.py @@ -324,7 +324,7 @@ class MySignature(dspy.Signature): question: str = dspy.InputField() tools: list[dspy.Tool] = dspy.InputField() answer: str = dspy.OutputField() - tool_calls: list[dspy.ToolCall] = dspy.OutputField() + tool_calls: dspy.ToolCalls = dspy.OutputField() def get_weather(city: str) -> str: """Get the weather for a city""" @@ -342,7 +342,7 @@ def get_population(country: str, year: int) -> str: assert len(messages) == 2 # The output field type description should be included in the system message even if the output field is nested - assert dspy.ToolCall.description() in messages[0]["content"] + assert dspy.ToolCalls.description() in messages[0]["content"] # The user message should include the question and the tools assert "What is the weather in Tokyo?" in messages[1]["content"] diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index decbc34f0b..622458114c 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -347,7 +347,7 @@ class MySignature(dspy.Signature): question: str = dspy.InputField() tools: list[dspy.Tool] = dspy.InputField() answer: str = dspy.OutputField() - tool_calls: list[dspy.ToolCall] = dspy.OutputField() + tool_calls: dspy.ToolCalls = dspy.OutputField() def get_weather(city: str) -> str: """Get the weather for a city""" @@ -365,7 +365,7 @@ def get_population(country: str, year: int) -> str: assert len(messages) == 2 # The output field type description should be included in the system message even if the output field is nested - assert dspy.ToolCall.description() in messages[0]["content"] + assert dspy.ToolCalls.description() in messages[0]["content"] # The user message should include the question and the tools assert "What is the weather in Tokyo?" in messages[1]["content"] @@ -375,3 +375,48 @@ def get_population(country: str, year: int) -> str: # Tool arguments format should be included in the user message assert "{'city': {'type': 'string'}}" in messages[1]["content"] assert "{'country': {'type': 'string'}, 'year': {'type': 'integer'}}" in messages[1]["content"] + + with mock.patch("litellm.completion") as mock_completion: + lm = dspy.LM(model="openai/gpt-4o-mini") + adapter(lm, {}, MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools}) + + mock_completion.assert_called_once() + _, call_kwargs = mock_completion.call_args + + # Assert tool calls are included in the `tools` arg + assert len(call_kwargs["tools"]) > 0 + assert call_kwargs["tools"][0] == { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + }, + }, + "required": ["city"], + }, + }, + } + assert call_kwargs["tools"][1] == { + "type": "function", + "function": { + "name": "get_population", + "description": "Get the population for a country", + "parameters": { + "type": "object", + "properties": { + "country": { + "type": "string", + }, + "year": { + "type": "integer", + }, + }, + "required": ["country", "year"], + }, + }, + } From 5a09c2515dda88e6ecde8f635cda6842508bdbc1 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 22 May 2025 17:06:41 -0700 Subject: [PATCH 06/10] return a list from lm for backward compatibility --- dspy/adapters/base.py | 11 ++++++++--- dspy/adapters/two_step_adapter.py | 22 ++++++++++++++++++++-- dspy/clients/base_lm.py | 4 ++++ dspy/utils/dummies.py | 1 - 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 611d4bfa72..a8cced8827 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -72,9 +72,14 @@ def _call_postprocess( tool_call_output_field_name = self._get_tool_call_output_field_name(signature) for output in outputs: - text = output["text"] - output_logprobs = output.get("logprobs") - tool_calls = output.get("tool_calls") + output_logprobs = None + tool_calls = None + text = output + + if isinstance(output, dict): + text = output["text"] + output_logprobs = output.get("logprobs") + tool_calls = output.get("tool_calls") if text: value = self.parse(signature, text) diff --git a/dspy/adapters/two_step_adapter.py b/dspy/adapters/two_step_adapter.py index 9ee91458f3..64f1b80eda 100644 --- a/dspy/adapters/two_step_adapter.py +++ b/dspy/adapters/two_step_adapter.py @@ -1,7 +1,10 @@ from typing import Any, Optional, Type +import json_repair + from dspy.adapters.base import Adapter from dspy.adapters.chat_adapter import ChatAdapter +from dspy.adapters.types import ToolCalls from dspy.adapters.utils import get_field_description_string from dspy.clients import LM from dspy.signatures.field import InputField @@ -115,11 +118,16 @@ async def acall( values = [] + tool_call_output_field_name = self._get_tool_call_output_field_name(signature) for output in outputs: output_logprobs = None + tool_calls = None + text = output if isinstance(output, dict): - output, output_logprobs = output["text"], output["logprobs"] + text = output["text"] + output_logprobs = output.get("logprobs") + tool_calls = output.get("tool_calls") try: # Call the smaller LM to extract structured data from the raw completion text with ChatAdapter @@ -128,13 +136,23 @@ async def acall( lm_kwargs={}, signature=extractor_signature, demos=[], - inputs={"text": output}, + inputs={"text": text}, ) value = value[0] except Exception as e: raise ValueError(f"Failed to parse response from the original completion: {output}") from e + if tool_calls and tool_call_output_field_name: + tool_calls = [ + { + "name": v["function"]["name"], + "args": json_repair.loads(v["function"]["arguments"]), + } + for v in tool_calls + ] + value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) + if output_logprobs is not None: value["logprobs"] = output_logprobs diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index a4f7bff561..ecf85f40c5 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -61,6 +61,10 @@ def _process_lm_response(self, response, prompt, messages, **kwargs): output["tool_calls"] = c.message.tool_calls outputs.append(output) + if all(len(output) == 1 for output in outputs): + # Return a list if every output only has "text" key + outputs = [output["text"] for output in outputs] + if settings.disable_history: return outputs diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index e5d53fc348..26c06f74f9 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -119,7 +119,6 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): else: outputs.append(format_answer_fields(next(self.answers, {"answer": "No more responses"}))) - outputs = [{"text": output} for output in outputs] # Logging, with removed api key & where `cost` is None on cache hit. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} entry = {"prompt": prompt, "messages": messages, "kwargs": kwargs} From de57207be685cbff0dfd93ff8878ff2bf1871451 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 22 May 2025 17:10:49 -0700 Subject: [PATCH 07/10] fix tests --- dspy/clients/base_lm.py | 2 +- dspy/utils/inspect_history.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index ecf85f40c5..1def8bb83f 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -57,7 +57,7 @@ def _process_lm_response(self, response, prompt, messages, **kwargs): output["text"] = c.message.content if hasattr(c, "message") else c["text"] if merged_kwargs.get("logprobs"): output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"] - if getattr(c.message, "tool_calls", None): + if hasattr(c, "message") and getattr(c.message, "tool_calls", None): output["tool_calls"] = c.message.tool_calls outputs.append(output) diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 7adc62abbe..10333c4eb0 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -43,14 +43,18 @@ def pretty_print_history(history, n: int = 1): print(_blue(image_str.strip())) print("\n") - if outputs[0]["text"]: + if isinstance(outputs[0], dict): + if outputs[0]["text"]: + print(_red("Response:")) + print(_green(outputs[0]["text"].strip())) + + if outputs[0].get("tool_calls"): + print(_red("Tool calls:")) + for tool_call in outputs[0]["tool_calls"]: + print(_green(f"{tool_call['function']['name']}: {tool_call['function']['arguments']}")) + else: print(_red("Response:")) - print(_green(outputs[0]["text"].strip())) - - if outputs[0].get("tool_calls"): - print(_red("Tool calls:")) - for tool_call in outputs[0]["tool_calls"]: - print(_green(f"{tool_call['function']['name']}: {tool_call['function']['arguments']}")) + print(_green(outputs[0].strip())) if len(outputs) > 1: choices_text = f" \t (and {len(outputs) - 1} other completions)" From ed3102810dcec00dc51f249bdb0d4fbcd499ef2f Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Sat, 24 May 2025 16:10:34 -0700 Subject: [PATCH 08/10] better name --- dspy/adapters/types/tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index 28411cbeb6..31268abce2 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -208,7 +208,7 @@ class ToolCall(BaseModel): tool_calls: list[ToolCall] @classmethod - def from_dict_list(cls, dict_list: list[dict[str, Any]]) -> "ToolCalls": + def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls": """Convert a list of dictionaries to a ToolCalls instance. Args: @@ -227,7 +227,7 @@ def from_dict_list(cls, dict_list: list[dict[str, Any]]) -> "ToolCalls": tool_calls = ToolCalls.from_dict_list(tool_calls_dict) ``` """ - tool_calls = [cls.ToolCall(**item) for item in dict_list] + tool_calls = [cls.ToolCall(**item) for item in tool_calls_dicts] return cls(tool_calls=tool_calls) @classmethod From 85a9bed688a4d5b97ed79a627829f46522c0d0cc Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 29 May 2025 13:58:59 -0700 Subject: [PATCH 09/10] fix tests --- dspy/predict/code_act.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dspy/predict/code_act.py b/dspy/predict/code_act.py index 01b246916f..9a8465c870 100644 --- a/dspy/predict/code_act.py +++ b/dspy/predict/code_act.py @@ -1,15 +1,14 @@ -import logging import inspect - -from typing import Callable, Union, Type +import logging from inspect import Signature +from typing import Callable, Type, Union import dspy +from dspy.adapters.types.tool import Tool +from dspy.predict.program_of_thought import ProgramOfThought +from dspy.predict.react import ReAct from dspy.primitives.python_interpreter import PythonInterpreter -from dspy.primitives.tool import Tool from dspy.signatures.signature import ensure_signature -from dspy.predict.react import ReAct -from dspy.predict.program_of_thought import ProgramOfThought logger = logging.getLogger(__name__) From 12886725f92d52d677362cc9abb14e63551cfd7c Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 2 Jun 2025 14:15:02 -0700 Subject: [PATCH 10/10] fix invalid import path --- dspy/utils/langchain_tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dspy/utils/langchain_tool.py b/dspy/utils/langchain_tool.py index 3417b1f0f5..cce44660ad 100644 --- a/dspy/utils/langchain_tool.py +++ b/dspy/utils/langchain_tool.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any -from dspy.primitives.tool import Tool, convert_input_schema_to_tool_args + +from dspy.adapters.types.tool import Tool, convert_input_schema_to_tool_args if TYPE_CHECKING: from langchain.tools import BaseTool