-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtrae_agent.py
More file actions
258 lines (221 loc) · 10 KB
/
trae_agent.py
File metadata and controls
258 lines (221 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""TraeAgent for software engineering tasks."""
import asyncio
import contextlib
import os
import subprocess
from typing import override
from trae_agent.agent.agent_basics import AgentError, AgentExecution
from trae_agent.agent.base_agent import BaseAgent
from trae_agent.prompt.agent_prompt import TRAE_AGENT_SYSTEM_PROMPT
from trae_agent.tools import tools_registry
from trae_agent.tools.base import Tool, ToolResult
from trae_agent.utils.config import MCPServerConfig, TraeAgentConfig
from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
from trae_agent.utils.mcp_client import MCPClient
# Registry names of the built-in tools that TraeAgent.new_task falls back to
# when it is given no tool_names and the agent has no tools yet.
TraeAgentToolNames: list[str] = [
    "str_replace_based_edit_tool",
    "sequentialthinking",
    "json_edit_tool",
    "task_done",
    "bash",
]
class TraeAgent(BaseAgent):
    """Trae Agent specialized for software engineering tasks."""

    def __init__(
        self,
        trae_agent_config: TraeAgentConfig,
        docker_config: dict | None = None,
        docker_keep: bool = True,
    ):
        """Initialize TraeAgent.

        Args:
            trae_agent_config: Agent configuration. Besides what BaseAgent reads from it,
                this class uses its ``mcp_servers_config`` (MCP server definitions keyed
                by server name) and ``allow_mcp_servers`` (names of servers that may be
                connected).
            docker_config: Optional configuration for running in a Docker environment.
                When set, the task prompt refers to the project root as ``/workspace``.
            docker_keep: Forwarded to BaseAgent; presumably controls whether the Docker
                container is kept after the run — confirm against BaseAgent.
        """
        # Per-task settings; new_task() overwrites them from its extra_args.
        self.project_path: str = ""
        self.base_commit: str | None = None
        # Stored as a string and compared against "true" in _is_task_completed.
        self.must_patch: str = "false"
        self.patch_path: str | None = None
        self.mcp_servers_config: dict[str, MCPServerConfig] | None = (
            trae_agent_config.mcp_servers_config if trae_agent_config.mcp_servers_config else None
        )
        self.allow_mcp_servers: list[str] | None = (
            trae_agent_config.allow_mcp_servers if trae_agent_config.allow_mcp_servers else []
        )
        self.mcp_tools: list[Tool] = []
        self.mcp_clients: list[MCPClient] = []  # Keep track of MCP clients for cleanup
        self.docker_config = docker_config
        super().__init__(
            agent_config=trae_agent_config, docker_config=docker_config, docker_keep=docker_keep
        )

    async def initialise_mcp(self):
        """Discover tools from the allowed MCP servers and add them to the agent's tools.

        MCP discovery is asynchronous, so it cannot run inside ``__init__``; await this
        after constructing the agent.
        """
        await self.discover_mcp_tools()
        if self.mcp_tools:
            self._tools.extend(self.mcp_tools)

    async def discover_mcp_tools(self):
        """Connect to every allowed MCP server and collect its tools into ``self.mcp_tools``.

        Servers not listed in ``allow_mcp_servers`` are skipped. A server that fails to
        connect is cleaned up and skipped; the remaining servers are still attempted.
        Successfully connected clients are stored in ``self.mcp_clients`` so that
        ``cleanup_mcp_clients`` can close them later.
        """
        if not self.mcp_servers_config or self.allow_mcp_servers is None:
            return

        for mcp_server_name, mcp_server_config in self.mcp_servers_config.items():
            if mcp_server_name not in self.allow_mcp_servers:
                continue
            mcp_client = MCPClient()
            try:
                # self.mcp_tools is passed in; presumably connect_and_discover appends the
                # server's tools to it (initialise_mcp relies on that list being filled).
                await mcp_client.connect_and_discover(
                    mcp_server_name,
                    mcp_server_config,
                    self.mcp_tools,
                    self._llm_client.provider.value,
                )
            except (Exception, asyncio.CancelledError):
                # CancelledError is a BaseException, so it is listed explicitly.
                # NOTE(review): swallowing CancelledError means a real cancellation of
                # this coroutine is ignored; it is presumably caught because the MCP
                # client can raise it on a failed connection — confirm before changing.
                with contextlib.suppress(Exception):
                    await mcp_client.cleanup(mcp_server_name)
                continue
            # Store client for later cleanup
            self.mcp_clients.append(mcp_client)

    @override
    def new_task(
        self,
        task: str,
        extra_args: dict[str, str] | None = None,
        tool_names: list[str] | None = None,
    ):
        """Create a new task.

        Args:
            task: The task description, recorded in the trajectory.
            extra_args: Must contain ``"project_path"``. May also contain ``"issue"``
                (added to the user prompt as the problem statement) and
                ``"base_commit"``, ``"must_patch"``, ``"patch_path"`` (copied onto the
                agent's attributes of the same name).
            tool_names: Registry names of the tools to use. If given, the agent's tools
                are rebuilt from exactly these names. If omitted, the agent keeps its
                current tools, or uses ``TraeAgentToolNames`` when it has none.

        Raises:
            AgentError: If ``extra_args`` is empty or has no ``"project_path"``.
        """
        # Validate first so that a rejected call leaves the agent's state untouched.
        if not extra_args:
            raise AgentError("Project path and issue information are required.")
        if "project_path" not in extra_args:
            raise AgentError("Project path is required")

        self._task: str = task

        if tool_names is None and len(self._tools) == 0:
            tool_names = TraeAgentToolNames
        if tool_names is not None:
            # Get the model provider from the model config
            provider = self._model_config.model_provider.provider
            self._tools: list[Tool] = [
                tools_registry[tool_name](model_provider=provider) for tool_name in tool_names
            ]

        self._initial_messages: list[LLMMessage] = [
            LLMMessage(role="system", content=self.get_system_prompt())
        ]

        self.project_path = extra_args["project_path"]
        # Under Docker the model is given the in-container path (/workspace) rather than
        # the host path.
        if self.docker_config:
            user_message = "[Project root path]:/workspace\n\n"
        else:
            user_message = f"[Project root path]:\n{self.project_path}\n\n"
        if "issue" in extra_args:
            user_message += f"[Problem statement]: We're currently solving the following issue within our repository. Here's the issue text:\n{extra_args['issue']}\n"

        for attr in ("base_commit", "must_patch", "patch_path"):
            if attr in extra_args:
                setattr(self, attr, extra_args[attr])

        self._initial_messages.append(LLMMessage(role="user", content=user_message))

        # If trajectory recorder is set, start recording
        if self._trajectory_recorder:
            self._trajectory_recorder.start_recording(
                task=task,
                provider=self._llm_client.provider.value,
                model=self._model_config.model,
                max_steps=self._max_steps,
            )

    @override
    async def execute_task(self) -> AgentExecution:
        """Execute the task, finalize trajectory recording, and optionally save the patch.

        When ``patch_path`` is set, the project's git diff (see ``get_git_diff``) is
        written to that file.
        """
        execution = await super().execute_task()

        # Finalize trajectory recording if recorder is available
        if self._trajectory_recorder:
            self._trajectory_recorder.finalize_recording(
                success=execution.success, final_result=execution.final_result
            )

        if self.patch_path is not None:
            # get_git_diff decodes git's output as UTF-8, so write it back as UTF-8.
            # newline="" disables newline translation, so the patch is not altered
            # (e.g. "\n" -> "\r\n" on Windows).
            with open(self.patch_path, "w", encoding="utf-8", newline="") as patch_f:
                _ = patch_f.write(self.get_git_diff())

        return execution

    def get_system_prompt(self) -> str:
        """Get the system prompt for TraeAgent."""
        return TRAE_AGENT_SYSTEM_PROMPT

    @override
    def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
        """TraeAgent adds no reflection message after tool results."""
        return None

    def get_git_diff(self) -> str:
        """Return the git diff of the project, or ``""`` if it cannot be produced.

        Without ``base_commit`` this runs ``git diff``, which shows working-tree changes
        not yet staged. With ``base_commit`` it runs ``git diff <base_commit> HEAD``, which
        shows committed changes only. Returns ``""`` if ``project_path`` is not a
        directory, git is not installed, or git exits with an error.
        """
        if not os.path.isdir(self.project_path):
            return ""

        command = ["git", "--no-pager", "diff"]
        if self.base_commit:
            # NOTE(review): "<base> HEAD" ignores uncommitted edits; if the agent does not
            # commit its changes this diff is empty — confirm this is intended.
            command += [self.base_commit, "HEAD"]

        try:
            # Run git inside the project via cwd= rather than os.chdir(), which would
            # change the working directory of the whole process.
            return subprocess.check_output(command, cwd=self.project_path).decode()
        except (subprocess.CalledProcessError, FileNotFoundError):
            return ""

    # Copyright (c) 2024 paul-gauthier
    # SPDX-License-Identifier: Apache-2.0
    # Original remove_patches_to_tests function was released under Apache-2.0 License, with the full license text
    # available at https://github.com/Aider-AI/aider-swe-bench/blob/6e98cd6c3b2cbcba12976d6ae1b07f847480cb74/LICENSE.txt
    # Original function is at https://github.com/Aider-AI/aider-swe-bench/blob/6e98cd6c3b2cbcba12976d6ae1b07f847480cb74/tests.py#L45

    def remove_patches_to_tests(self, model_patch: str) -> str:
        """
        Remove any changes to the tests directory from the provided patch.
        This is to ensure that the model_patch does not disturb the repo's
        tests when doing acceptance testing with the `test_patch`.
        """
        lines = model_patch.splitlines(keepends=True)
        filtered_lines: list[str] = []
        test_patterns = ["/test/", "/tests/", "/testing/", "test_", "tox.ini"]
        is_tests = False

        for line in lines:
            # Each file section of a git diff begins with "diff --git a/<path> b/<path>";
            # the decision made here applies to every line of that section.
            if line.startswith("diff --git a/"):
                target_path = line.split()[-1]
                # NOTE(review): "test_" is a plain substring test, so it also matches
                # e.g. "b/src/latest_version.py"; kept to match upstream behavior.
                is_tests = target_path.startswith("b/") and any(
                    p in target_path for p in test_patterns
                )

            if not is_tests:
                filtered_lines.append(line)

        return "".join(filtered_lines)

    @override
    def llm_indicates_task_completed(self, llm_response: LLMResponse) -> bool:
        """Check if the LLM indicates that the task is completed (called ``task_done``)."""
        if llm_response.tool_calls is None:
            return False
        return any(tool_call.name == "task_done" for tool_call in llm_response.tool_calls)

    @override
    def _is_task_completed(self, llm_response: LLMResponse) -> bool:
        """Enhanced task completion detection.

        When ``must_patch`` is ``"true"``, the task counts as complete only if the git
        diff still contains changes after test-related files are removed.
        """
        if self.must_patch != "true":
            return True
        patch = self.remove_patches_to_tests(self.get_git_diff())
        return bool(patch.strip())

    @override
    def task_incomplete_message(self) -> str:
        """Return a message indicating that the task is incomplete."""
        return "ERROR! Your Patch is empty. Please provide a patch that fixes the problem."

    @override
    async def cleanup_mcp_clients(self) -> None:
        """Clean up all MCP clients to prevent async context leaks."""
        for client in self.mcp_clients:
            with contextlib.suppress(Exception):
                # Use a generic server name for cleanup since we don't track which server each client is for
                await client.cleanup("cleanup")
        self.mcp_clients.clear()