diff --git a/dspy/predict/__init__.py b/dspy/predict/__init__.py index 00fd80d760..f1f54c3375 100644 --- a/dspy/predict/__init__.py +++ b/dspy/predict/__init__.py @@ -8,6 +8,7 @@ from dspy.predict.predict import Predict from dspy.predict.program_of_thought import ProgramOfThought from dspy.predict.react import ReAct, Tool +from dspy.predict.code_act import CodeAct from dspy.predict.refine import Refine __all__ = [ @@ -15,6 +16,7 @@ "BestOfN", "ChainOfThought", "ChainOfThoughtWithHint", + "CodeAct", "KNN", "MultiChainComparison", "Predict", diff --git a/dspy/predict/code_act.py b/dspy/predict/code_act.py new file mode 100644 index 0000000000..01b246916f --- /dev/null +++ b/dspy/predict/code_act.py @@ -0,0 +1,121 @@ +import logging +import inspect + +from typing import Callable, Union, Type +from inspect import Signature + +import dspy +from dspy.primitives.python_interpreter import PythonInterpreter +from dspy.primitives.tool import Tool +from dspy.signatures.signature import ensure_signature +from dspy.predict.react import ReAct +from dspy.predict.program_of_thought import ProgramOfThought + +logger = logging.getLogger(__name__) + +class CodeAct(ReAct, ProgramOfThought): + """ + CodeAct is a module that utilizes the Code Interpreter and predefined tools to solve the problem. + """ + + def __init__(self, signature: Union[str, Type[Signature]], tools: list[Callable], max_iters: int = 5): + """ + Initializes the CodeAct class with the specified model, temperature, and max tokens. + + Args: + signature (Union[str, Type[Signature]]): The signature of the module. + tools (list[Callable]): The tool callables to be used. CodeAct only accepts functions and not callable objects. + max_iters (int): The maximum number of iterations to generate the answer. + + Example: + + ```python + from dspy.predict import CodeAct + def factorial(n): + if n == 1: + return 1 + return n * factorial(n-1) + + act = CodeAct("n->factorial", tools=[factorial]) + act(n=5) # 120 + ``` + """ + self.signature = ensure_signature(signature) + self.max_iters = max_iters + + tools = [t if isinstance(t, Tool) else Tool(t) for t in tools] + if any( + not inspect.isfunction(tool.func) for tool in tools + ): + raise ValueError("CodeAct only accepts functions and not callable objects.") + tools = {tool.name: tool for tool in tools} + + instructions = self._build_instructions(self.signature, tools) + + codeact_signature = ( + dspy.Signature({**self.signature.input_fields}, "\n".join(instructions)) + .append("trajectory", dspy.InputField(), type_=str) + .append("generated_code", dspy.OutputField(desc="Python code that when executed, produces output relevant to answering the question"), type_=str) + .append("finished", dspy.OutputField(desc="a boolean flag to determine if the process is done"), type_=bool) + ) + + extract_signature = dspy.Signature( + {**self.signature.input_fields, **self.signature.output_fields}, + self.signature.instructions, + ).append("trajectory", dspy.InputField(), type_=str) + + self.tools: dict[str, Tool] = tools + self.codeact = dspy.Predict(codeact_signature) + self.extractor = dspy.ChainOfThought(extract_signature) + # It will raises exception when dspy cannot find available deno instance by now. + self.interpreter = PythonInterpreter() + + def _build_instructions(self, signature, tools): + instructions = [f"{signature.instructions}\n"] if signature.instructions else [] + inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) + outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) + + instructions.append( + f"You are an intelligent agent. For each episode, you will receive the fields {inputs} as input.\n" + f"Your goal is to generate executable Python code that collects any necessary information for producing {outputs}.\n" + "For each iteration, you will generate a code snippet that either solves the task or progresses towards the solution.\n" + "Ensure any output you wish to extract from the code is printed to the console. The code should be enclosed in a fenced code block.\n" + f"When all information for producing the outputs ({outputs}) are available to be extracted, mark `finished=True` besides the final Python code.\n" + "You have access to the Python Standard Library and the following functions:" + ) + + for idx, tool in enumerate(tools.values()): + instructions.append(f"({idx + 1}) {tool}") + + return instructions + + def forward(self, **kwargs): + # Define the tool funcitons in the interpreter + for tool in self.tools.values(): + self.interpreter(inspect.getsource(tool.func)) + + trajectory = {} + max_iters = kwargs.pop("max_iters", self.max_iters) + for idx in range(max_iters): + code_data = self.codeact(trajectory=trajectory, **kwargs) + output = None + code, error = self._parse_code(code_data) + + if error: + trajectory[f"observation_{idx}"] = f"Failed to parse the generated code: {error}" + continue + + trajectory[f"generated_code_{idx}"] = code + output, error = self._execute_code(code) + + if not error: + trajectory[f"code_output_{idx}"] = output + else: + trajectory[f"observation_{idx}"] = f"Failed to execute the generated code: {error}" + + if code_data.finished: + break + + extract = self._call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs) + self.interpreter.shutdown() + return dspy.Prediction(trajectory=trajectory, **extract) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index f6098672a9..d0b7542fd5 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -47,6 +47,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): for idx, tool in enumerate(tools.values()): instr.append(f"({idx + 1}) {tool}") + instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) diff --git a/dspy/primitives/tool.py b/dspy/primitives/tool.py index bf4a208f5a..7d750e3b6d 100644 --- a/dspy/primitives/tool.py +++ b/dspy/primitives/tool.py @@ -177,7 +177,7 @@ def __repr__(self): def __str__(self): desc = f", whose description is {self.desc}.".replace("\n", " ") if self.desc else "." - arg_desc = f"It takes arguments {self.args} in JSON format." + arg_desc = f"It takes arguments {self.args}." return f"{self.name}{desc} {arg_desc}" diff --git a/tests/predict/test_code_act.py b/tests/predict/test_code_act.py new file mode 100644 index 0000000000..2565932648 --- /dev/null +++ b/tests/predict/test_code_act.py @@ -0,0 +1,147 @@ +import pytest +import shutil + +import dspy +from dspy import Signature +from dspy.predict import CodeAct +from dspy.utils import DummyLM + +# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/ +is_deno_available = shutil.which("deno") is not None +skip_if_deno_not_available = pytest.mark.skipif( + not is_deno_available, reason="Deno is not installed or not in PATH" +) + + +class BasicQA(Signature): + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + +def add(a: float, b: float) -> float: + "add two numbers" + return a + b + +@skip_if_deno_not_available +def test_codeact_code_generation(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```", + "finished": True, + }, + {"reasoning": "Reason_B", "answer": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + program = CodeAct(BasicQA, tools=[add]) + res = program(question="What is 1+1?") + assert res.answer == "2" + assert res.trajectory == { + 'code_output_0': '"2\\n"', + 'generated_code_0': 'result = add(1,1)\nprint(result)', + } + assert program.interpreter.deno_process is None + + +class ExtremumFinder(Signature): + input_list = dspy.InputField() + maximum = dspy.OutputField(desc="The maximum of the given numbers") + minimum = dspy.OutputField(desc="The minimum of the given numbers") + +def extract_maximum_minimum(input_list: str) -> dict[str, float]: + numbers = list(map(float, input_list.split(","))) + return {"maximum": max(numbers), "minimum": min(numbers)} + +@skip_if_deno_not_available +def test_codeact_support_multiple_fields(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)\n```", + "finished": True, + }, + {"reasoning": "Reason_B", "maximum": "6", "minimum": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + program = CodeAct(ExtremumFinder, tools=[extract_maximum_minimum]) + res = program(input_list="2, 3, 5, 6") + assert res.maximum == "6" + assert res.minimum == "2" + assert res.trajectory == { + 'code_output_0': '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"', + 'generated_code_0': "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)", + } + assert program.interpreter.deno_process is None + + +@skip_if_deno_not_available +def test_codeact_code_parse_failure(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nparse(error\n```", + "finished": False, + }, + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```", + "finished": True, + }, + {"reasoning": "Reason_B", "answer": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + program = CodeAct(BasicQA, tools=[add]) + res = program(question="What is 1+1?") + assert res.answer == "2" + assert res.trajectory == { + 'generated_code_0': 'parse(error', + 'observation_0': 'Failed to execute the generated code: Invalid Python syntax. message: ', + 'generated_code_1': 'result = add(1,1)\nprint(result)', + 'code_output_1': '"2\\n"', + } + assert program.interpreter.deno_process is None + + +@skip_if_deno_not_available +def test_codeact_code_execution_failure(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nunknown+1\n```", + "finished": False, + }, + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```", + "finished": True, + }, + {"reasoning": "Reason_B", "answer": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + program = CodeAct(BasicQA, tools=[add]) + res = program(question="What is 1+1?") + assert res.answer == "2" + assert res.trajectory == { + 'generated_code_0': 'unknown+1', + 'observation_0': 'Failed to execute the generated code: NameError: ["name \'unknown\' is not defined"]', + 'generated_code_1': 'result = add(1,1)\nprint(result)', + 'code_output_1': '"2\\n"', + } + assert program.interpreter.deno_process is None + + +class CustomTool: + def __call__(self, a: float, b: float) -> float: + return a + b + +@skip_if_deno_not_available +def test_codeact_tool_validation(): + with pytest.raises(ValueError, match="CodeAct only accepts functions and not callable objects."): + CodeAct(BasicQA, tools=[CustomTool()]) diff --git a/tests/primitives/test_tool.py b/tests/primitives/test_tool.py index 3f71392d94..dfee495c6f 100644 --- a/tests/primitives/test_tool.py +++ b/tests/primitives/test_tool.py @@ -261,7 +261,7 @@ def add(x: int, y: int = 0) -> int: tool = Tool(add) assert ( str(tool) - == "add, whose description is Add two integers.. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}} in JSON format." + == "add, whose description is Add two integers.. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}." )