Skip to content

Commit 52f356f

Browse files
sizhit2The tunix Authors
authored andcommitted
Add DeepSWE train script.
PiperOrigin-RevId: 873920307
1 parent df627a6 commit 52f356f

File tree

9 files changed

+1442
-16
lines changed

9 files changed

+1442
-16
lines changed

examples/deepswe/swe_agent.py

Lines changed: 486 additions & 0 deletions
Large diffs are not rendered by default.

examples/deepswe/swe_env.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import json
2+
import os
3+
from typing import Any, Optional, cast
4+
import numpy as np
5+
6+
try:
7+
import r2egym # pytype: disable=import-error
8+
from r2egym.agenthub.action import Action # pytype: disable=import-error
9+
from r2egym.agenthub.environment.env import EnvArgs, RepoEnv # pytype: disable=import-error
10+
except ImportError:
11+
r2egym = cast(Any, None)
12+
EnvArgs = cast(Any, None)
13+
RepoEnv = cast(Any, None)
14+
Action = cast(Any, None)
15+
16+
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
17+
from tunix.rl.agentic.rewards import reward_types
18+
19+
if r2egym:
20+
R2EGYM_PATH = os.path.dirname(r2egym.__file__)
21+
else:
22+
R2EGYM_PATH = ""
23+
# List of tools to be used in the environment.
24+
R2EGYM_COMMAND_FILES = [
25+
os.path.join(R2EGYM_PATH, "agenthub/tools/r2egym/file_editor.py"),
26+
os.path.join(R2EGYM_PATH, "agenthub/tools/search.py"),
27+
os.path.join(R2EGYM_PATH, "agenthub/tools/r2egym/execute_bash.py"),
28+
os.path.join(R2EGYM_PATH, "agenthub/tools/finish.py"),
29+
]
30+
31+
SWEAGENT_COMMAND_FILES = [
32+
os.path.join(R2EGYM_PATH, "agenthub/tools/str_replace_editor.py"),
33+
os.path.join(R2EGYM_PATH, "agenthub/tools/execute_bash.py"),
34+
os.path.join(R2EGYM_PATH, "agenthub/tools/submit.py"),
35+
]
36+
37+
38+
def _unpack_entry(entry: dict) -> dict:
39+
"""Utility to clean up and unpack the dataset entry."""
40+
unpacked_entry = {}
41+
for k, v in entry.items():
42+
if isinstance(v, np.ndarray):
43+
unpacked_entry[k] = v.item()
44+
elif isinstance(v, list):
45+
if len(v) != 1:
46+
raise ValueError(
47+
f"Can only convert a list of size 1; got size {len(v)}"
48+
)
49+
unpacked_entry[k] = v[0]
50+
else:
51+
unpacked_entry[k] = v
52+
return unpacked_entry
53+
54+
55+
class SWEEnv(BaseTaskEnv):
56+
"""Software Engineering Environment for code-related tasks."""
57+
58+
def __init__(
59+
self,
60+
entry: dict,
61+
group_id: int | None = None,
62+
pair_index: int | None = None,
63+
step_timeout: int = 90,
64+
reward_timeout: int = 300,
65+
backend: str = "kubernetes",
66+
delete_image: bool = False,
67+
verbose: bool = False,
68+
scaffold: str = "r2egym",
69+
max_steps: int = 1,
70+
):
71+
"""Initialize the SWE environment.
72+
73+
Args:
74+
entry: Dataset containing the tasks. If None, uses default dataset.
75+
group_id: ID of the group to which the task belongs.
76+
pair_index: Index of the pair to use. If None, selects a random pair.
77+
step_timeout: Timeout for each step in seconds.
78+
reward_timeout: Timeout for reward computation in seconds.
79+
backend: Backend to use for the environment.
80+
delete_image: Whether to delete the Docker image after closing.
81+
"""
82+
self.entry = _unpack_entry(entry)
83+
self.step_timeout = step_timeout
84+
self.reward_timeout = reward_timeout
85+
self.total_steps = 0
86+
self.delete_image = delete_image
87+
self.backend = backend
88+
self.env = None
89+
self.verbose = verbose
90+
self.scaffold = scaffold
91+
assert scaffold in [
92+
"r2egym",
93+
"sweagent",
94+
], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']"
95+
super().__init__(max_steps=max_steps)
96+
97+
if not hasattr(self, "extra_kwargs"):
98+
self.extra_kwargs = {}
99+
100+
self.extra_kwargs["group_id"] = group_id
101+
self.extra_kwargs["pair_index"] = pair_index
102+
103+
def _initial_observation(self) -> Any:
104+
if not self.env:
105+
# Initialize environment if not created yet.
106+
env_args = EnvArgs(ds=self.entry)
107+
self.env = RepoEnv(
108+
env_args,
109+
backend=self.backend,
110+
step_timeout=self.step_timeout,
111+
reward_timeout=self.reward_timeout,
112+
verbose=self.verbose,
113+
)
114+
else:
115+
self.env.reset()
116+
if self.scaffold == "r2egym":
117+
self.env.add_commands(R2EGYM_COMMAND_FILES)
118+
else:
119+
self.env.add_commands(SWEAGENT_COMMAND_FILES)
120+
self.total_steps = 0
121+
122+
# Polls docker runtime to get task instruction.
123+
return self.env.get_task_instruction()
124+
125+
def _step_impl(self, action: Any) -> EnvStepResult:
126+
if isinstance(action, str):
127+
action_obj = Action.from_string(action)
128+
else:
129+
action_obj = action
130+
131+
if not action_obj.function_name:
132+
return EnvStepResult(observation="", reward=0, done=False, info={})
133+
134+
# RepoEnv always returns 0 reward, must be evaluated by DockerRuntime.
135+
if not self.env:
136+
raise ValueError("Environment not initialized")
137+
obs, reward, done, info = self.env.step(action_obj)
138+
139+
self.total_steps += 1
140+
141+
return EnvStepResult(
142+
observation=str(obs), reward=reward, done=done, info=info
143+
)
144+
145+
def close(self) -> None:
146+
"""Close the environment and clean up resources."""
147+
if self.env is not None:
148+
self.env.close()
149+
150+
if self.delete_image and self.env:
151+
docker_image = self.env.runtime.docker_image
152+
os.system(f"docker rmi {docker_image}")
153+
154+
def compute_final_reward(self, *args) -> reward_types.RewardOutput:
155+
"""Run tests in the Docker container and return reward wrapped in an object."""
156+
# Get the raw float/int reward
157+
reward_val = float(self.env.compute_reward())
158+
159+
# Return it wrapped in the object the engine expects
160+
return reward_types.RewardOutput(reward=reward_val)
161+
162+
@staticmethod
163+
def from_dict(extra_info: dict | str) -> "SWEEnv":
164+
"""Create an environment instance from JSON configuration.
165+
166+
Args:
167+
extra_info: Dictionary containing configuration parameters. The entire
168+
dict will be used as 'entry', and any keys matching __init__
169+
parameters will be extracted and passed.
170+
171+
Returns:
172+
Initialized SWEEnv instance
173+
"""
174+
import inspect
175+
176+
if isinstance(extra_info, str):
177+
extra_info = json.loads(extra_info)
178+
179+
sig = inspect.signature(SWEEnv.__init__)
180+
init_params = {}
181+
for param_name, param in sig.parameters.items():
182+
if param_name == "self":
183+
continue
184+
if param_name in extra_info:
185+
init_params[param_name] = extra_info[param_name]
186+
# else if param has default value, use the default value
187+
init_params["entry"] = extra_info
188+
return SWEEnv(**init_params)

0 commit comments

Comments
 (0)