-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Add CodeAct module #8222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CodeAct module #8222
Changes from all commits
f62bc60
5ab5d99
4b6ae68
be7b428
a6cc1d0
0aa7319
2ad8696
7288cb4
192d0f2
38d7e8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import logging | ||
import inspect | ||
|
||
from typing import Callable, Union, Type | ||
from inspect import Signature | ||
|
||
import dspy | ||
from dspy.primitives.python_interpreter import PythonInterpreter | ||
from dspy.primitives.tool import Tool | ||
from dspy.signatures.signature import ensure_signature | ||
from dspy.predict.react import ReAct | ||
from dspy.predict.program_of_thought import ProgramOfThought | ||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
class CodeAct(ReAct, ProgramOfThought):
    """
    CodeAct is a module that utilizes the Code Interpreter and predefined tools to solve the problem.

    On each iteration the language model writes a Python snippet, the snippet is run in the
    interpreter, and the code together with its output (or error) is recorded in a trajectory.
    A final extractor reads that trajectory to produce the signature's output fields.
    """

    # The annotation uses the forward reference "dspy.Signature" so that nothing is resolved
    # at class-definition time; the value is handed to `ensure_signature` below.
    def __init__(self, signature: Union[str, Type["dspy.Signature"]], tools: list[Callable], max_iters: int = 5):
        """
        Initializes the CodeAct class with the specified signature, tools, and iteration limit.

        Args:
            signature (Union[str, Type[Signature]]): The signature of the module.
            tools (list[Callable]): The tool callables to be used. CodeAct only accepts functions and not callable objects.
            max_iters (int): The maximum number of iterations to generate the answer.

        Raises:
            ValueError: If any tool wraps something other than a plain function.

        Example:

        ```python
        from dspy.predict import CodeAct
        def factorial(n):
            if n == 1:
                return 1
            return n * factorial(n-1)

        act = CodeAct("n->factorial", tools=[factorial])
        act(n=5) # 120
        ```
        """
        self.signature = ensure_signature(signature)
        self.max_iters = max_iters

        tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
        # Tools are defined inside the interpreter from their source code (see `forward`).
        # A callable object would additionally need its class definition and the arguments
        # passed to its constructor, which are not visible here, so only functions are allowed.
        if not all(inspect.isfunction(tool.func) for tool in tools):
            raise ValueError("CodeAct only accepts functions and not callable objects.")
        tools = {tool.name: tool for tool in tools}

        instructions = self._build_instructions(self.signature, tools)

        # Signature for the code-writing step: the original inputs plus the trajectory so far,
        # producing the next code snippet and a flag telling the loop whether to stop.
        codeact_signature = (
            dspy.Signature({**self.signature.input_fields}, "\n".join(instructions))
            .append("trajectory", dspy.InputField(), type_=str)
            .append("generated_code", dspy.OutputField(desc="Python code that when executed, produces output relevant to answering the question"), type_=str)
            .append("finished", dspy.OutputField(desc="a boolean flag to determine if the process is done"), type_=bool)
        )

        # Signature for the final step: original inputs and outputs plus the full trajectory.
        extract_signature = dspy.Signature(
            {**self.signature.input_fields, **self.signature.output_fields},
            self.signature.instructions,
        ).append("trajectory", dspy.InputField(), type_=str)

        self.tools: dict[str, Tool] = tools
        self.codeact = dspy.Predict(codeact_signature)
        self.extractor = dspy.ChainOfThought(extract_signature)
        # NOTE: this currently raises an exception when dspy cannot find an available deno instance.
        self.interpreter = PythonInterpreter()

    def _build_instructions(self, signature, tools):
        """
        Build the instruction lines for the code-writing predictor.

        Args:
            signature: The task signature; its instructions and field names are quoted in the prompt.
            tools: Mapping of tool name to tool; each tool is listed as an available function.

        Returns:
            A list of instruction strings (joined with newlines by the caller).
        """
        instructions = [f"{signature.instructions}\n"] if signature.instructions else []
        inputs = ", ".join([f"`{k}`" for k in signature.input_fields])
        outputs = ", ".join([f"`{k}`" for k in signature.output_fields])

        # `finished=True` accompanies the last code snippet rather than replacing it: `forward`
        # still executes that snippet, then the extractor derives the outputs from the trajectory.
        instructions.append(
            f"You are an intelligent agent. For each episode, you will receive the fields {inputs} as input.\n"
            f"Your goal is to generate executable Python code that collects any necessary information for producing {outputs}.\n"
            "For each iteration, you will generate a code snippet that either solves the task or progresses towards the solution.\n"
            "Ensure any output you wish to extract from the code is printed to the console. The code should be enclosed in a fenced code block.\n"
            f"When all information for producing the outputs ({outputs}) are available to be extracted, mark `finished=True` besides the final Python code.\n"
            "You have access to the Python Standard Library and the following functions:"
        )

        for idx, tool in enumerate(tools.values(), start=1):
            instructions.append(f"({idx}) {tool}")

        return instructions

    def forward(self, **kwargs):
        """
        Run the generate-execute loop and extract the final outputs.

        Args:
            **kwargs: Values for the signature's input fields. An optional `max_iters`
                entry overrides the instance-level iteration limit for this call.

        Returns:
            dspy.Prediction: The extracted output fields plus the `trajectory` dict, which holds
            `generated_code_{i}`, `code_output_{i}` and `observation_{i}` entries per iteration.
        """
        try:
            # Define the tool functions in the interpreter. This is repeated on every call
            # because the interpreter is shut down at the end of each call.
            for tool in self.tools.values():
                self.interpreter(inspect.getsource(tool.func))

            trajectory = {}
            max_iters = kwargs.pop("max_iters", self.max_iters)
            for idx in range(max_iters):
                code_data = self.codeact(trajectory=trajectory, **kwargs)
                code, error = self._parse_code(code_data)

                if error:
                    # Nothing to execute; record the problem and ask the model again.
                    # The `finished` flag is intentionally not consulted on this path.
                    trajectory[f"observation_{idx}"] = f"Failed to parse the generated code: {error}"
                    continue

                trajectory[f"generated_code_{idx}"] = code
                output, error = self._execute_code(code)

                if not error:
                    trajectory[f"code_output_{idx}"] = output
                else:
                    trajectory[f"observation_{idx}"] = f"Failed to execute the generated code: {error}"

                if code_data.finished:
                    break

            extract = self._call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs)
            return dspy.Prediction(trajectory=trajectory, **extract)
        finally:
            # Always release the interpreter, including when the LM call, code execution,
            # or extraction raises, so its process is not left running.
            self.interpreter.shutdown()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import pytest | ||
import shutil | ||
|
||
import dspy | ||
from dspy import Signature | ||
from dspy.predict import CodeAct | ||
from dspy.utils import DummyLM | ||
|
||
# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
is_deno_available = shutil.which("deno") is not None
# Marker applied to every test that needs the interpreter; skips instead of failing without deno.
skip_if_deno_not_available = pytest.mark.skipif(
    not is_deno_available, reason="Deno is not installed or not in PATH"
)
|
||
|
||
# Minimal question -> answer signature used by the single-output tests.
# Documented with comments only: no docstring is added to the class.
class BasicQA(Signature):
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||
# Tool function handed to CodeAct in the tests below; the generated code calls it by name.
def add(a: float, b: float) -> float:
    "add two numbers"
    return a + b
|
||
@skip_if_deno_not_available
def test_codeact_code_generation():
    """A single generated snippet is executed and the answer is extracted from its output."""
    lm = DummyLM(
        [
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
                "finished": True,
            },
            {"reasoning": "Reason_B", "answer": "2"},
        ]
    )
    dspy.settings.configure(lm=lm)
    program = CodeAct(BasicQA, tools=[add])
    res = program(question="What is 1+1?")
    assert res.answer == "2"
    assert res.trajectory == {
        'code_output_0': '"2\\n"',
        'generated_code_0': 'result = add(1,1)\nprint(result)',
    }
    # The interpreter must have been shut down by the end of the call.
    assert program.interpreter.deno_process is None
|
||
|
||
# Signature with two output fields, used to check multi-output extraction.
# Documented with comments only: no docstring is added to the class.
class ExtremumFinder(Signature):
    input_list = dspy.InputField()
    maximum = dspy.OutputField(desc="The maximum of the given numbers")
    minimum = dspy.OutputField(desc="The minimum of the given numbers")
|
||
# Tool function: parses a comma-separated string of numbers and returns the largest and
# smallest as floats under the keys "maximum" and "minimum".
def extract_maximum_minimum(input_list: str) -> dict[str, float]:
    numbers = list(map(float, input_list.split(",")))
    return {"maximum": max(numbers), "minimum": min(numbers)}
|
||
@skip_if_deno_not_available
def test_codeact_support_multiple_fields():
    """A signature with several output fields has every field populated by the extractor."""
    lm = DummyLM(
        [
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nresult = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)\n```",
                "finished": True,
            },
            {"reasoning": "Reason_B", "maximum": "6", "minimum": "2"},
        ]
    )
    dspy.settings.configure(lm=lm)
    program = CodeAct(ExtremumFinder, tools=[extract_maximum_minimum])
    res = program(input_list="2, 3, 5, 6")
    assert res.maximum == "6"
    assert res.minimum == "2"
    assert res.trajectory == {
        'code_output_0': '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"',
        'generated_code_0': "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)",
    }
    # The interpreter must have been shut down by the end of the call.
    assert program.interpreter.deno_process is None
|
||
|
||
@skip_if_deno_not_available
def test_codeact_code_parse_failure():
    """Syntactically invalid code is recorded as an observation and the loop retries."""
    lm = DummyLM(
        [
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nparse(error\n```",
                "finished": False,
            },
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
                "finished": True,
            },
            {"reasoning": "Reason_B", "answer": "2"},
        ]
    )
    dspy.settings.configure(lm=lm)
    program = CodeAct(BasicQA, tools=[add])
    res = program(question="What is 1+1?")
    assert res.answer == "2"
    # The first iteration fails and is recorded; the second succeeds.
    assert res.trajectory == {
        'generated_code_0': 'parse(error',
        'observation_0': 'Failed to execute the generated code: Invalid Python syntax. message: ',
        'generated_code_1': 'result = add(1,1)\nprint(result)',
        'code_output_1': '"2\\n"',
    }
    assert program.interpreter.deno_process is None
|
||
|
||
@skip_if_deno_not_available
def test_codeact_code_execution_failure():
    """A runtime error in generated code is recorded as an observation and the loop retries."""
    lm = DummyLM(
        [
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nunknown+1\n```",
                "finished": False,
            },
            {
                "reasoning": "Reason_A",
                "generated_code": "```python\nresult = add(1,1)\nprint(result)\n```",
                "finished": True,
            },
            {"reasoning": "Reason_B", "answer": "2"},
        ]
    )
    dspy.settings.configure(lm=lm)
    program = CodeAct(BasicQA, tools=[add])
    res = program(question="What is 1+1?")
    assert res.answer == "2"
    # The first iteration raises NameError and is recorded; the second succeeds.
    assert res.trajectory == {
        'generated_code_0': 'unknown+1',
        'observation_0': 'Failed to execute the generated code: NameError: ["name \'unknown\' is not defined"]',
        'generated_code_1': 'result = add(1,1)\nprint(result)',
        'code_output_1': '"2\\n"',
    }
    assert program.interpreter.deno_process is None
|
||
|
||
# A callable object (not a plain function); used to check that CodeAct rejects such tools.
class CustomTool:
    def __call__(self, a: float, b: float) -> float:
        return a + b
|
||
@skip_if_deno_not_available
def test_codeact_tool_validation():
    """Passing a callable object instead of a function is rejected at construction time."""
    # TODO(review): also test the input flow - use a mock litellm.completion to capture the
    # prompt, and validate that the tool + code information are correctly included in it.
    with pytest.raises(ValueError, match="CodeAct only accepts functions and not callable objects."):
        CodeAct(BasicQA, tools=[CustomTool()])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we follow up with a code tutorial demonstrating a use case of CodeAct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, I will file a follow up PR for documentation