-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Support dspy.Tool
as input field type and dspy.ToolCall
as output field type
#8242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5ce5961
3e2b99a
b171e9f
a2237ee
e585ac0
acc60aa
5a09c25
de57207
ed31028
d3bdc4f
85a9bed
c2c4cf0
cc62448
1288672
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,17 @@ | ||
from typing import TYPE_CHECKING, Any, Optional, Type | ||
import logging | ||
from typing import TYPE_CHECKING, Any, Optional, Type, get_origin | ||
|
||
import json_repair | ||
import litellm | ||
|
||
from dspy.adapters.types import History | ||
from dspy.adapters.types.base_type import split_message_content_for_custom_types | ||
from dspy.adapters.types.tool import Tool, ToolCalls | ||
from dspy.signatures.signature import Signature | ||
from dspy.utils.callback import BaseCallback, with_callbacks | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if TYPE_CHECKING: | ||
from dspy.clients.lm import LM | ||
|
||
|
@@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None: | |
cls.format = with_callbacks(cls.format) | ||
cls.parse = with_callbacks(cls.parse) | ||
|
||
def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]: | ||
def _call_preprocess( | ||
self, | ||
lm: "LM", | ||
lm_kwargs: dict[str, Any], | ||
signature: Type[Signature], | ||
inputs: dict[str, Any], | ||
use_native_function_calling: bool = False, | ||
) -> dict[str, Any]: | ||
if use_native_function_calling: | ||
tool_call_input_field_name = self._get_tool_call_input_field_name(signature) | ||
tool_call_output_field_name = self._get_tool_call_output_field_name(signature) | ||
|
||
if tool_call_output_field_name and tool_call_input_field_name is None: | ||
raise ValueError( | ||
f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, " | ||
"but did not provide any tools as the input. Please provide a list of tools as the input by adding an " | ||
"input field with type `list[dspy.Tool]`." | ||
) | ||
|
||
if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model): | ||
tools = inputs[tool_call_input_field_name] | ||
tools = tools if isinstance(tools, list) else [tools] | ||
|
||
litellm_tools = [] | ||
for tool in tools: | ||
litellm_tools.append(tool.format_as_litellm_function_call()) | ||
|
||
lm_kwargs["tools"] = litellm_tools | ||
|
||
signature_for_native_function_calling = signature.delete(tool_call_output_field_name) | ||
|
||
return signature_for_native_function_calling | ||
|
||
return signature | ||
|
||
def _call_postprocess( | ||
self, | ||
signature: Type[Signature], | ||
outputs: list[dict[str, Any]], | ||
) -> list[dict[str, Any]]: | ||
values = [] | ||
|
||
tool_call_output_field_name = self._get_tool_call_output_field_name(signature) | ||
|
||
for output in outputs: | ||
output_logprobs = None | ||
tool_calls = None | ||
text = output | ||
|
||
if isinstance(output, dict): | ||
output, output_logprobs = output["text"], output["logprobs"] | ||
|
||
value = self.parse(signature, output) | ||
|
||
if output_logprobs is not None: | ||
text = output["text"] | ||
output_logprobs = output.get("logprobs") | ||
tool_calls = output.get("tool_calls") | ||
|
||
if text: | ||
value = self.parse(signature, text) | ||
else: | ||
value = {} | ||
for field_name in signature.output_fields.keys(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @chenmoneygithub qq - do we need to handle this tool-call specific logic here? I'm wondering that since it inherits from BaseType if we can add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Output handling is a bit different from input handling, which we can possibly generalize. However I would be cautious about doing it because we don't know yet if generalization makes sense here - for the |
||
value[field_name] = None | ||
|
||
if tool_calls and tool_call_output_field_name: | ||
tool_calls = [ | ||
{ | ||
"name": v["function"]["name"], | ||
"args": json_repair.loads(v["function"]["arguments"]), | ||
} | ||
for v in tool_calls | ||
] | ||
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) | ||
|
||
if output_logprobs: | ||
value["logprobs"] = output_logprobs | ||
|
||
values.append(value) | ||
|
@@ -46,10 +113,11 @@ def __call__( | |
demos: list[dict[str, Any]], | ||
inputs: dict[str, Any], | ||
) -> list[dict[str, Any]]: | ||
inputs = self.format(signature, demos, inputs) | ||
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs) | ||
inputs = self.format(processed_signature, demos, inputs) | ||
|
||
outputs = lm(messages=inputs, **lm_kwargs) | ||
return self._call_post_process(outputs, signature) | ||
return self._call_postprocess(signature, outputs) | ||
|
||
async def acall( | ||
self, | ||
|
@@ -59,10 +127,11 @@ async def acall( | |
demos: list[dict[str, Any]], | ||
inputs: dict[str, Any], | ||
) -> list[dict[str, Any]]: | ||
inputs = self.format(signature, demos, inputs) | ||
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs) | ||
inputs = self.format(processed_signature, demos, inputs) | ||
|
||
outputs = await lm.acall(messages=inputs, **lm_kwargs) | ||
return self._call_post_process(outputs, signature) | ||
return self._call_postprocess(signature, outputs) | ||
|
||
def format( | ||
self, | ||
|
@@ -297,6 +366,22 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool: | |
return name | ||
return None | ||
|
||
def _get_tool_call_input_field_name(self, signature: Type[Signature]) -> bool: | ||
for name, field in signature.input_fields.items(): | ||
# Look for annotation `list[dspy.Tool]` or `dspy.Tool` | ||
origin = get_origin(field.annotation) | ||
if origin is list and field.annotation.__args__[0] == Tool: | ||
return name | ||
if field.annotation == Tool: | ||
return name | ||
return None | ||
|
||
def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool: | ||
for name, field in signature.output_fields.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens when there are multiple fields with type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then only one field will be populated with value. We can raise a warning when there are multiple ToolCalls field, I kinda doubt if users will do that though. We made it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. had the same question for multimodal. in a single call (diff from could be different for ToolCalls tho, but might be safer to give that warning globally (maybe for the select ones in types?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with Arnav. For cases where users define multiple output fields with a type that can only have a single field, such as tool calls or multi-modal, can we raise an exception when the invalid signature is created rather than warning since this is a wrong usage of signature? |
||
if field.annotation == ToolCalls: | ||
return name | ||
return None | ||
|
||
def format_conversation_history( | ||
self, | ||
signature: Type[Signature], | ||
|
@@ -352,4 +437,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: | |
Returns: | ||
A dictionary of the output fields. | ||
""" | ||
raise NotImplementedError | ||
raise NotImplementedError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar q for the preprocess too, could this instead be wrapped in Tool's
format
and make use ofsplit_message_content_for_custom_types
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for the tool calling, here we need to handle the special case of native function calling: https://platform.openai.com/docs/guides/function-calling?api-mode=chat, so we need to modify the LM call args in addition to the messages.