Skip to content

Add CodeAct module #8222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from dspy.predict.predict import Predict
from dspy.predict.program_of_thought import ProgramOfThought
from dspy.predict.react import ReAct, Tool
from dspy.predict.code_act import CodeAct
from dspy.predict.refine import Refine

__all__ = [
"majority",
"BestOfN",
"ChainOfThought",
"ChainOfThoughtWithHint",
"CodeAct",
"KNN",
"MultiChainComparison",
"Predict",
Expand Down
121 changes: 121 additions & 0 deletions dspy/predict/code_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging
import inspect

from typing import Callable, Union, Type
from inspect import Signature

import dspy
from dspy.primitives.python_interpreter import PythonInterpreter
from dspy.primitives.tool import Tool
from dspy.signatures.signature import ensure_signature
from dspy.predict.react import ReAct
from dspy.predict.program_of_thought import ProgramOfThought

logger = logging.getLogger(__name__)

class CodeAct(ReAct, ProgramOfThought):
"""
CodeAct is a module that utilizes the Code Interpreter and predefined tools to solve the problem.
"""

def __init__(self, signature: Union[str, Type[Signature]], tools: list[Callable], max_iters: int = 5):
"""
Initializes the CodeAct class with the specified model, temperature, and max tokens.

Args:
signature (Union[str, Type[Signature]]): The signature of the module.
tools (list[Callable]): The tool callables to be used. CodeAct only accepts functions and not callable objects.
max_iters (int): The maximum number of iterations to generate the answer.

Example:

```python
from dspy.predict import CodeAct
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we follow up with a code tutorial demonstrating a use case of CodeAct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I will file a follow up PR for documentation

def factorial(n):
if n == 1:
return 1
return n * factorial(n-1)

act = CodeAct("n->factorial", tools=[factorial])
act(n=5) # 120
```
"""
self.signature = ensure_signature(signature)
self.max_iters = max_iters

tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
if any(
not inspect.isfunction(tool.func) for tool in tools
):
raise ValueError("CodeAct only accepts functions and not callable objects.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious - why do we have this constraint?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is hard to define the callable object in the python interpreter. Unlike functions, we need to define class and what parameters are passed to the object initialization, which is not visible.

tools = {tool.name: tool for tool in tools}

instructions = self._build_instructions(self.signature, tools)

codeact_signature = (
dspy.Signature({**self.signature.input_fields}, "\n".join(instructions))
.append("trajectory", dspy.InputField(), type_=str)
.append("generated_code", dspy.OutputField(desc="Python code that when executed, produces output relevant to answering the question"), type_=str)
.append("finished", dspy.OutputField(desc="a boolean flag to determine if the process is done"), type_=bool)
)

extract_signature = dspy.Signature(
{**self.signature.input_fields, **self.signature.output_fields},
self.signature.instructions,
).append("trajectory", dspy.InputField(), type_=str)

self.tools: dict[str, Tool] = tools
self.codeact = dspy.Predict(codeact_signature)
self.extractor = dspy.ChainOfThought(extract_signature)
# It will raises exception when dspy cannot find available deno instance by now.
self.interpreter = PythonInterpreter()

def _build_instructions(self, signature, tools):
instructions = [f"{signature.instructions}\n"] if signature.instructions else []
inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])

instructions.append(
f"You are an intelligent agent. For each episode, you will receive the fields {inputs} as input.\n"
f"Your goal is to generate executable Python code that collects any necessary information for producing {outputs}.\n"
"For each iteration, you will generate a code snippet that either solves the task or progresses towards the solution.\n"
"Ensure any output you wish to extract from the code is printed to the console. The code should be enclosed in a fenced code block.\n"
f"When all information for producing the outputs ({outputs}) are available to be extracted, mark `finished=True` besides the final Python code.\n"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should that be "besides the final Python code" or finished=True and generated_code is not None be mutual exclusive?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have an extractor at last, I designed it in a way that the code passed with finished=True is executed to minimize the interaction count.

"You have access to the Python Standard Library and the following functions:"
)

for idx, tool in enumerate(tools.values()):
instructions.append(f"({idx + 1}) {tool}")

return instructions

def forward(self, **kwargs):
# Define the tool funcitons in the interpreter
for tool in self.tools.values():
self.interpreter(inspect.getsource(tool.func))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we move it to init, and only run this line for call-time tools? (the new change in ReAct)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since python interpreter is shutdown every forward, this needs to happen in forward too.


trajectory = {}
max_iters = kwargs.pop("max_iters", self.max_iters)
for idx in range(max_iters):
code_data = self.codeact(trajectory=trajectory, **kwargs)
output = None
code, error = self._parse_code(code_data)

if error:
trajectory[f"observation_{idx}"] = f"Failed to parse the generated code: {error}"
continue

trajectory[f"generated_code_{idx}"] = code
output, error = self._execute_code(code)

if not error:
trajectory[f"code_output_{idx}"] = output
else:
trajectory[f"observation_{idx}"] = f"Failed to execute the generated code: {error}"

if code_data.finished:
break

extract = self._call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs)
self.interpreter.shutdown()
return dspy.Prediction(trajectory=trajectory, **extract)
1 change: 1 addition & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):

for idx, tool in enumerate(tools.values()):
instr.append(f"({idx + 1}) {tool}")
instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format")

react_signature = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
Expand Down
2 changes: 1 addition & 1 deletion dspy/primitives/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __repr__(self):

def __str__(self):
desc = f", whose description is <desc>{self.desc}</desc>.".replace("\n", " ") if self.desc else "."
arg_desc = f"It takes arguments {self.args} in JSON format."
arg_desc = f"It takes arguments {self.args}."
return f"{self.name}{desc} {arg_desc}"


Expand Down
147 changes: 147 additions & 0 deletions tests/predict/test_code_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import pytest
import shutil

import dspy
from dspy import Signature
from dspy.predict import CodeAct
from dspy.utils import DummyLM

# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
is_deno_available = shutil.which("deno") is not None
skip_if_deno_not_available = pytest.mark.skipif(
not is_deno_available, reason="Deno is not installed or not in PATH"
)


class BasicQA(Signature):
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")

def add(a: float, b: float) -> float:
"add two numbers"
return a + b

@skip_if_deno_not_available
def test_codeact_code_generation():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "answer": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(BasicQA, tools=[add])
res = program(question="What is 1+1?")
assert res.answer == "2"
assert res.trajectory == {
'code_output_0': '"2\\n"',
'generated_code_0': 'result = add(1,1)\nprint(result)',
}
assert program.interpreter.deno_process is None


class ExtremumFinder(Signature):
input_list = dspy.InputField()
maximum = dspy.OutputField(desc="The maximum of the given numbers")
minimum = dspy.OutputField(desc="The minimum of the given numbers")

def extract_maximum_minimum(input_list: str) -> dict[str, float]:
numbers = list(map(float, input_list.split(",")))
return {"maximum": max(numbers), "minimum": min(numbers)}

@skip_if_deno_not_available
def test_codeact_support_multiple_fields():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "maximum": "6", "minimum": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(ExtremumFinder, tools=[extract_maximum_minimum])
res = program(input_list="2, 3, 5, 6")
assert res.maximum == "6"
assert res.minimum == "2"
assert res.trajectory == {
'code_output_0': '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"',
'generated_code_0': "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)",
}
assert program.interpreter.deno_process is None


@skip_if_deno_not_available
def test_codeact_code_parse_failure():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nparse(error\n```",
"finished": False,
},
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "answer": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(BasicQA, tools=[add])
res = program(question="What is 1+1?")
assert res.answer == "2"
assert res.trajectory == {
'generated_code_0': 'parse(error',
'observation_0': 'Failed to execute the generated code: Invalid Python syntax. message: ',
'generated_code_1': 'result = add(1,1)\nprint(result)',
'code_output_1': '"2\\n"',
}
assert program.interpreter.deno_process is None


@skip_if_deno_not_available
def test_codeact_code_execution_failure():
lm = DummyLM(
[
{
"reasoning": "Reason_A",
"generated_code": "```python\nunknown+1\n```",
"finished": False,
},
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
"finished": True,
},
{"reasoning": "Reason_B", "answer": "2"},
]
)
dspy.settings.configure(lm=lm)
program = CodeAct(BasicQA, tools=[add])
res = program(question="What is 1+1?")
assert res.answer == "2"
assert res.trajectory == {
'generated_code_0': 'unknown+1',
'observation_0': 'Failed to execute the generated code: NameError: ["name \'unknown\' is not defined"]',
'generated_code_1': 'result = add(1,1)\nprint(result)',
'code_output_1': '"2\\n"',
}
assert program.interpreter.deno_process is None


class CustomTool:
def __call__(self, a: float, b: float) -> float:
return a + b

@skip_if_deno_not_available
def test_codeact_tool_validation():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also want to test the input flow - use a mock litellm.completion to capture the prompt, and validate our tool + code information are correctly included in the prompt.

with pytest.raises(ValueError, match="CodeAct only accepts functions and not callable objects."):
CodeAct(BasicQA, tools=[CustomTool()])
2 changes: 1 addition & 1 deletion tests/primitives/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def add(x: int, y: int = 0) -> int:
tool = Tool(add)
assert (
str(tool)
== "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}} in JSON format."
== "add, whose description is <desc>Add two integers.</desc>. It takes arguments {'x': {'type': 'integer'}, 'y': {'type': 'integer', 'default': 0}}."
)


Expand Down