Skip to content

Add CodeAct module #8222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from dspy.predict.predict import Predict
from dspy.predict.program_of_thought import ProgramOfThought
from dspy.predict.react import ReAct, Tool
from dspy.predict.code_act import CodeAct
from dspy.predict.refine import Refine

__all__ = [
"majority",
"BestOfN",
"ChainOfThought",
"ChainOfThoughtWithHint",
"CodeAct",
"KNN",
"MultiChainComparison",
"Predict",
Expand Down
116 changes: 116 additions & 0 deletions dspy/predict/code_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
import inspect

from typing import Callable, Union, Type
from inspect import Signature

import dspy
from dspy.primitives.python_interpreter import PythonInterpreter
from dspy.primitives.tool import Tool
from dspy.signatures.signature import ensure_signature
from dspy.predict.react import ReAct
from dspy.predict.program_of_thought import ProgramOfThought

logger = logging.getLogger(__name__)

class CodeAct(ReAct, ProgramOfThought):
"""
CodeAct is a module that utilizes the Code Interpreter and predefined tools to solve the problem.
"""

def __init__(self, signature: Union[str, Type[Signature]], tools: list[Callable], max_iters=5):
"""
Initializes the CodeAct class with the specified model, temperature, and max tokens.

Args:
signature (Union[str, Type[Signature]]): The signature of the module.
tools (list[Callable]): The tool callables to be used. CodeAct only accepts functions and not callable objects.
max_iters (int): The maximum number of iterations to generate the answer.

Example:

```python
from dspy.predict import CodeAct
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we follow up with a code tutorial demonstrating a use case of CodeAct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I will file a follow up PR for documentation

def factorial(n):
if n == 1:
return 1
return n * factorial(n-1)

act = CodeAct("n->factorial", tools=[factorial])
act(n=5) # 120
```
"""
self.signature = ensure_signature(signature)
self.max_iters = max_iters

tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
if any(
not inspect.isfunction(tool.func) for tool in tools
):
raise ValueError("CodeAct only accepts functions and not callable objects.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious - why do we have this constraint?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is hard to define the callable object in the python interpreter. Unlike functions, we need to define class and what parameters are passed to the object initialization, which is not visible.

tools = {tool.name: tool for tool in tools}

instructions = self._build_instructions(self.signature, tools)

codeact_signature = (
dspy.Signature({**self.signature.input_fields}, "\n".join(instructions))
.append("trajectory", dspy.InputField(), type_=str)
.append("generated_code", dspy.OutputField(desc="python code that answers the question"), type_=str)
.append("finished", dspy.OutputField(desc="a boolean flag to determine if the process is done"), type_=bool)
)

extract_signature = dspy.Signature(
{**self.signature.input_fields, **self.signature.output_fields},
self.signature.instructions,
).append("trajectory", dspy.InputField(), type_=str)

self.tools: dict[str, Tool] = tools
self.codeact = dspy.Predict(codeact_signature)
self.extractor = dspy.ChainOfThought(extract_signature)
# It will raises exception when dspy cannot find available deno instance by now.
self.interpreter = PythonInterpreter()

def _build_instructions(self, signature, tools):
instructions = [f"{signature.instructions}\n"] if signature.instructions else []
inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])

instructions.append(
f"You are an intelligent agent. For each episode, you will receive the fields {inputs} as input.\n"
f"Your goal is to generate executable Python code that collects any necessary information for producing {outputs}.\n"
"For each iteration, you will generate a code snippet that either solves the task or progresses towards the solution.\n"
"Ensure any output you wish to extract from the code is printed to the console. The code should be enclosed in a fenced code block.\n"
f"When all information for producing the outputs ({outputs}) are available to be extracted, mark `finished=True` besides the final Python code.\n"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should that be "besides the final Python code" or finished=True and generated_code is not None be mutual exclusive?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have an extractor at last, I designed it in a way that the code passed with finished=True is executed to minimize the interaction count.

"You have access to the Python Standard Library and the following functions:"
)

for idx, tool in enumerate(tools.values()):
instructions.append(f"({idx + 1}) {tool}")

return instructions

def forward(self, **kwargs):
# Define the tool funcitons in the interpreter
for tool in self.tools.values():
self.interpreter(inspect.getsource(tool.func))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we move it to init, and only run this line for call-time tools? (the new change in ReAct)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since python interpreter is shutdown every forward, this needs to happen in forward too.


trajectory = {}
max_iters = kwargs.pop("max_iters", self.max_iters)
for idx in range(max_iters):
code_data = self.codeact(trajectory=trajectory, **kwargs)
output = None
code, error = self._parse_code(code_data)
if not error:
output, error = self._execute_code(code)

trajectory[f"generated_code_{idx}"] = code
trajectory[f"code_output_{idx}"] = output
else:
trajectory[f"observation_{idx}"] = f"Execution error in {error}"

if code_data.finished:
break

extract = self._call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs)
self.interpreter.shutdown()
return dspy.Prediction(trajectory=trajectory, **extract)
1 change: 1 addition & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):

for idx, tool in enumerate(tools.values()):
instr.append(f"({idx + 1}) {tool}")
instr.append("You should pass the tool argument in JSON format")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe revise to "When providing next_tool_args, the value inside the field must be in JSON format".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I'm confused about line 37, which has

                "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",

I thought that was removed? Am I thinking of a different PR?

Copy link
Collaborator Author

@TomeHirata TomeHirata May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 37 will be removed in #8190 👍


react_signature = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
Expand Down
2 changes: 1 addition & 1 deletion dspy/primitives/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __repr__(self):

def __str__(self):
desc = f", whose description is <desc>{self.desc}</desc>.".replace("\n", " ") if self.desc else "."
arg_desc = f"It takes arguments {self.args} in JSON format."
arg_desc = f"It takes arguments {self.args}."
return f"{self.name}{desc} {arg_desc}"


Expand Down
88 changes: 88 additions & 0 deletions tests/predict/test_code_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import shutil

import dspy
from dspy import Signature
from dspy.predict import CodeAct
from dspy.utils import DummyLM

# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
is_deno_available = shutil.which("deno") is not None
skip_if_deno_not_available = pytest.mark.skipif(
not is_deno_available, reason="Deno is not installed or not in PATH"
)


class BasicQA(Signature):
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


def add(a: float, b: float) -> float:
"add two numbers"
return a + b


@skip_if_deno_not_available
def test_codeact_code_generation():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "answer": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(BasicQA, tools=[add])
res = program(question="What is 1+1?")
assert res.answer == "2"
assert res.trajectory == {
'code_output_0': '"2\\n"',
'generated_code_0': 'result = add(1,1)\nprint(result)',
}
assert program.interpreter.deno_process is None


class ExtremumFinder(Signature):
input_list = dspy.InputField()
maximum = dspy.OutputField(desc="The maximum of the given numbers")
minimum = dspy.OutputField(desc="The minimum of the given numbers")

def extract_maximum_minimum(input_list: str) -> dict[str, float]:
numbers = list(map(float, input_list.split(",")))
return {"maximum": max(numbers), "minimum": min(numbers)}

@skip_if_deno_not_available
def test_codeact_support_multiple_fields():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "maximum": "6", "minimum": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(ExtremumFinder, tools=[extract_maximum_minimum])
res = program(input_list="2, 3, 5, 6")
assert res.maximum == "6"
assert res.minimum == "2"
assert res.trajectory == {
'code_output_0': '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"',
'generated_code_0': "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)",
}
assert program.interpreter.deno_process is None

class CustomTool:
def __call__(self, a: float, b: float) -> float:
return a + b

@skip_if_deno_not_available
def test_codeact_tool_validation():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also want to test the input flow - use a mock litellm.completion to capture the prompt, and validate our tool + code information are correctly included in the prompt.

with pytest.raises(ValueError, match="CodeAct only accepts functions and not callable objects."):
CodeAct(BasicQA, tools=[CustomTool()])
2 changes: 1 addition & 1 deletion tests/primitives/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_tool_str():
tool = Tool(dummy_function)
assert (
str(tool)
== "dummy_function, whose description is <desc>A dummy function for testing. Args: x: An integer parameter y: A string parameter </desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}} in JSON format."
== "dummy_function, whose description is <desc>A dummy function for testing. Args: x: An integer parameter y: A string parameter </desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'string', 'default': 'hello'}}."
)


Expand Down