Skip to content

Commit c66c300

Browse files
authored
Merge pull request #70 from nforro/tools
Add basic text editor tools
2 parents 3955c44 + 3296048 commit c66c300

File tree

2 files changed

+276
-0
lines changed

2 files changed

+276
-0
lines changed

beeai/agents/tests/unit/test_tools.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616
SetZStreamReleaseTool,
1717
SetZStreamReleaseToolInput,
1818
)
19+
from tools.text import (
20+
CreateTool,
21+
CreateToolInput,
22+
ViewTool,
23+
ViewToolInput,
24+
InsertTool,
25+
InsertToolInput,
26+
StrReplaceTool,
27+
StrReplaceToolInput,
28+
)
1929

2030

2131
@pytest.mark.parametrize(
@@ -149,3 +159,123 @@ async def test_set_zstream_release(autorelease_spec):
149159
result = output.result
150160
assert not result
151161
assert autorelease_spec.read_text().splitlines()[3] == "Release: 4%{?dist}.%{autorelease -n}"
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_create(tmp_path):
166+
test_file = tmp_path / "test.txt"
167+
content = "Line 1\nLine 2\n"
168+
tool = CreateTool()
169+
output = await tool.run(input=CreateToolInput(file=test_file, content=content)).middleware(
170+
GlobalTrajectoryMiddleware(pretty=True)
171+
)
172+
result = output.result
173+
assert not result
174+
assert test_file.read_text() == content
175+
176+
177+
@pytest.mark.asyncio
178+
async def test_view(tmp_path):
179+
test_dir = tmp_path
180+
test_file = test_dir / "test.txt"
181+
content = "Line 1\nLine 2\nLine 3\n"
182+
test_file.write_text(content)
183+
tool = ViewTool()
184+
output = await tool.run(input=ViewToolInput(path=test_dir)).middleware(
185+
GlobalTrajectoryMiddleware(pretty=True)
186+
)
187+
result = output.result
188+
assert result == "test.txt\n"
189+
output = await tool.run(input=ViewToolInput(path=test_file)).middleware(
190+
GlobalTrajectoryMiddleware(pretty=True)
191+
)
192+
result = output.result
193+
assert result == content
194+
output = await tool.run(input=ViewToolInput(path=test_file, view_range=(2, -1))).middleware(
195+
GlobalTrajectoryMiddleware(pretty=True)
196+
)
197+
result = output.result
198+
assert (
199+
result
200+
== dedent(
201+
"""
202+
Line 2
203+
Line 3
204+
"""
205+
)[1:]
206+
)
207+
output = await tool.run(input=ViewToolInput(path=test_file, view_range=(1, 2))).middleware(
208+
GlobalTrajectoryMiddleware(pretty=True)
209+
)
210+
result = output.result
211+
assert (
212+
result
213+
== dedent(
214+
"""
215+
Line 1
216+
Line 2
217+
"""
218+
)[1:]
219+
)
220+
221+
222+
@pytest.mark.parametrize(
223+
"line, content",
224+
[
225+
(
226+
0,
227+
dedent(
228+
"""
229+
Inserted line
230+
Line 1
231+
Line 2
232+
Line 3
233+
"""
234+
)[1:],
235+
),
236+
(
237+
1,
238+
dedent(
239+
"""
240+
Line 1
241+
Inserted line
242+
Line 2
243+
Line 3
244+
"""
245+
)[1:],
246+
),
247+
],
248+
)
249+
@pytest.mark.asyncio
250+
async def test_insert(line, content, tmp_path):
251+
test_file = tmp_path / "test.txt"
252+
test_file.write_text("Line 1\nLine 2\nLine 3\n")
253+
tool = InsertTool()
254+
output = await tool.run(
255+
input=InsertToolInput(file=test_file, line=line, new_string="Inserted line")
256+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
257+
result = output.result
258+
assert not result
259+
assert test_file.read_text() == content
260+
261+
262+
@pytest.mark.asyncio
263+
async def test_str_replace(tmp_path):
264+
test_file = tmp_path / "test.txt"
265+
test_file.write_text("Line 1\nLine 2\nLine 3\n")
266+
tool = StrReplaceTool()
267+
output = await tool.run(
268+
input=StrReplaceToolInput(file=test_file, old_string="Line 2", new_string="LINE_2")
269+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
270+
result = output.result
271+
assert not result
272+
assert (
273+
test_file.read_text()
274+
== dedent(
275+
"""
276+
Line 1
277+
LINE_2
278+
Line 3
279+
"""
280+
)[1:]
281+
)

beeai/agents/tools/text.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import asyncio
2+
from pathlib import Path
3+
4+
from pydantic import BaseModel, Field
5+
6+
from beeai_framework.context import RunContext
7+
from beeai_framework.emitter import Emitter
8+
from beeai_framework.tools import StringToolOutput, Tool, ToolRunOptions
9+
10+
11+
class CreateToolInput(BaseModel):
12+
file: Path = Field(description="Absolute path to a file to create")
13+
content: str = Field(description="Content to write to the new file")
14+
15+
16+
class CreateTool(Tool[CreateToolInput, ToolRunOptions, StringToolOutput]):
17+
name = "create"
18+
description = """
19+
Creates a new file with the specified content.
20+
Returns error message on failure.
21+
"""
22+
input_schema = CreateToolInput
23+
24+
def _create_emitter(self) -> Emitter:
25+
return Emitter.root().child(
26+
namespace=["tool", "text", self.name],
27+
creator=self,
28+
)
29+
30+
async def _run(
31+
self, tool_input: CreateToolInput, options: ToolRunOptions | None, context: RunContext
32+
) -> StringToolOutput:
33+
try:
34+
await asyncio.to_thread(tool_input.file.write_text, tool_input.content)
35+
except Exception as e:
36+
return StringToolOutput(result=f"Failed to create file: {e}")
37+
return StringToolOutput()
38+
39+
40+
class ViewToolInput(BaseModel):
41+
path: Path = Field(description="Absolute path to a file or directory to view")
42+
view_range: tuple[int, int] | None = Field(
43+
description="""
44+
Tuple of two integers specifying the start and end line numbers to view.
45+
Line numbers are 1-indexed, and -1 for the end line means read to the end of the file.
46+
This argument only applies when viewing files, not directories.
47+
""",
48+
default=None,
49+
)
50+
51+
52+
class ViewTool(Tool[ViewToolInput, ToolRunOptions, StringToolOutput]):
53+
name = "view"
54+
description = """
55+
Outputs the contents of a file or lists the contents of a directory. Can read an entire file
56+
or a specific range of lines. Returns error message on failure.
57+
"""
58+
input_schema = ViewToolInput
59+
60+
def _create_emitter(self) -> Emitter:
61+
return Emitter.root().child(
62+
namespace=["tool", "text", self.name],
63+
creator=self,
64+
)
65+
66+
async def _run(
67+
self, tool_input: ViewToolInput, options: ToolRunOptions | None, context: RunContext
68+
) -> StringToolOutput:
69+
try:
70+
if tool_input.path.is_file():
71+
content = await asyncio.to_thread(tool_input.path.read_text)
72+
if tool_input.view_range is not None:
73+
start, end = tool_input.view_range
74+
lines = content.splitlines(keepends=True)
75+
content = "".join(lines[start - 1 : None if end < 0 else end])
76+
return StringToolOutput(result=content)
77+
return StringToolOutput(result="\n".join(e.name for e in tool_input.path.iterdir()) + "\n")
78+
except Exception as e:
79+
return StringToolOutput(result=f"Failed to view path: {e}")
80+
81+
82+
class InsertToolInput(BaseModel):
83+
file: Path = Field(description="Absolute path to a file to edit")
84+
line: int = Field(description="Line number after which to insert the text (0 for beginning of file)")
85+
new_string: str = Field(description="Text to insert")
86+
87+
88+
class InsertTool(Tool[InsertToolInput, ToolRunOptions, StringToolOutput]):
89+
name = "insert"
90+
description = """
91+
Inserts the specified text at a specific location in a file.
92+
Returns error message on failure.
93+
"""
94+
input_schema = InsertToolInput
95+
96+
def _create_emitter(self) -> Emitter:
97+
return Emitter.root().child(
98+
namespace=["tool", "text", self.name],
99+
creator=self,
100+
)
101+
102+
async def _run(
103+
self, tool_input: InsertToolInput, options: ToolRunOptions | None, context: RunContext
104+
) -> StringToolOutput:
105+
try:
106+
lines = (await asyncio.to_thread(tool_input.file.read_text)).splitlines(keepends=True)
107+
lines.insert(tool_input.line, tool_input.new_string + "\n")
108+
await asyncio.to_thread(tool_input.file.write_text, "".join(lines))
109+
except Exception as e:
110+
return StringToolOutput(result=f"Failed to insert text: {e}")
111+
return StringToolOutput()
112+
113+
114+
class StrReplaceToolInput(BaseModel):
115+
file: Path = Field(description="Absolute path to a file to edit")
116+
old_string: str = Field(
117+
description="Text to replace (must match exactly, including whitespace and indentation)"
118+
)
119+
new_string: str = Field(description="New text to insert in place of the old text")
120+
121+
122+
class StrReplaceTool(Tool[StrReplaceToolInput, ToolRunOptions, StringToolOutput]):
123+
name = "str_replace"
124+
description = """
125+
Replaces a specific string in the specified file with a new string.
126+
Returns error message on failure.
127+
"""
128+
input_schema = StrReplaceToolInput
129+
130+
def _create_emitter(self) -> Emitter:
131+
return Emitter.root().child(
132+
namespace=["tool", "text", self.name],
133+
creator=self,
134+
)
135+
136+
async def _run(
137+
self, tool_input: StrReplaceToolInput, options: ToolRunOptions | None, context: RunContext
138+
) -> StringToolOutput:
139+
try:
140+
content = await asyncio.to_thread(tool_input.file.read_text)
141+
await asyncio.to_thread(
142+
tool_input.file.write_text, content.replace(tool_input.old_string, tool_input.new_string)
143+
)
144+
except Exception as e:
145+
return StringToolOutput(result=f"Failed to replace text: {e}")
146+
return StringToolOutput()

0 commit comments

Comments
 (0)