Skip to content

Commit 0b40cb7

Browse files
authored
New example: EMPO2 (#524)
1 parent c746af2 commit 0b40cb7

19 files changed

Lines changed: 1101 additions & 36 deletions
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import copy
4+
import logging
5+
from typing import Any, Dict
6+
7+
import numpy as np
8+
import requests
9+
from add_instruction import add_chat_all_tips, add_chat_instruction
10+
from agl_envs import make_env_manager
11+
12+
from agentlightning import LLM, NamedResources, Rollout, configure_logger, emit_reward, operation
13+
from agentlightning.utils.otel import make_link_attributes
14+
from contrib.agentlightning.contrib.agent.env_agent import EnvAgent
15+
from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder
16+
17+
configure_logger()
18+
logger = configure_logger(name=__name__, level=logging.ERROR)
19+
20+
21+
def do_compress(text):
22+
url = "http://127.0.0.1:8000/key_cal/"
23+
headers = {"Content-Type": "application/json"} # 明确指定 JSON 格式
24+
data = {"text": text}
25+
response = requests.post(url, json=data, headers=headers) # 使用 json 参数
26+
return response.json()
27+
28+
29+
url_mem = "http://127.0.0.1:8001/mem/"
30+
31+
32+
def retrieve_memory(idx, key):
33+
response = requests.post(url_mem, json={"key": key, "idx": idx})
34+
count, data = response.json()
35+
return count, data
36+
37+
38+
def reset_memory(mem_list_num):
39+
requests.post(url_mem, json={"key": [], "idx": mem_list_num, "content": "Reset"}) # 用于初始化多个 memory slot
40+
41+
42+
def add_memory(idx, key, content, score):
43+
requests.post(url_mem, json={"key": key, "idx": idx, "content": content, "score": score})
44+
45+
46+
def gather_chats(prompt):
47+
chat_list = []
48+
for item in prompt:
49+
role = item.type
50+
content = item.content
51+
if "System" in role:
52+
continue
53+
elif "User" in role:
54+
role = "user"
55+
else:
56+
role = "assistant"
57+
chat_list.append(f"{role}: {content}")
58+
text = " ".join(chat_list)
59+
return text
60+
61+
62+
class EMPO2Agent(EnvAgent):
63+
def _get_all_tip_prompt(self, prompt, tip_list):
64+
prompt_type = self.config.captioner.prompt_type
65+
if prompt_type == "chat":
66+
return add_chat_all_tips(prompt, tip_list)
67+
else:
68+
raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')")
69+
70+
def _get_tip_generation_prompt(self, prompt):
71+
return add_chat_instruction(prompt, "tip")
72+
73+
async def rollout_async(
74+
self,
75+
task: Dict[str, Any],
76+
resources: NamedResources,
77+
rollout: Rollout,
78+
) -> float | None:
79+
rollout_id = rollout.rollout_id
80+
logger.info(f"[Rollout {rollout_id}] Task: {task}")
81+
82+
reward_scale = float(self.config["reawrd_scale"])
83+
84+
# Setup LLM + agent
85+
llm: LLM = resources.get("main_llm")
86+
print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
87+
self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4)
88+
89+
if rollout.mode == "train":
90+
train_mode = task["train_mode"]
91+
global_steps = task["global_steps"]
92+
else:
93+
train_mode = "on-policy"
94+
95+
if rollout.mode == "train" and (train_mode == "off-policy" or train_mode == "on-policy-with-tips"):
96+
use_tips = True
97+
else:
98+
use_tips = False
99+
100+
variation_idx = task["variation_idx"]
101+
102+
try:
103+
# Setup environment
104+
prompt_builder = HistoryPromptBuilder(
105+
max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type
106+
)
107+
108+
self.env = make_env_manager(self.config.env_name, task, self.config)
109+
env_obs, infos, available_actions_hint = self.env.reset()
110+
111+
prompt_builder.init(self.env)
112+
prompt_builder.update_observation(env_obs)
113+
# prompt_builder.update_admissible_actions(available_actions_hint)
114+
115+
prompt = prompt_builder.get_prompt()
116+
117+
episode_reward, done = 0.0, False
118+
119+
history_actions_for_mem = []
120+
tip_list = []
121+
step_count = 0
122+
while not done:
123+
if use_tips:
124+
text = gather_chats(prompt)
125+
key = (
126+
np.array(do_compress(text)["key"])
127+
.reshape(
128+
-1,
129+
)
130+
.tolist()
131+
)
132+
count, mem_list = retrieve_memory(variation_idx, key)
133+
else:
134+
count, mem_list = 0, []
135+
136+
ret_tips, intrinsic_reward = "", 0.0
137+
138+
if use_tips:
139+
if count > 0:
140+
ret_tips = "Here are some memories you collected in your previous exploration:\n"
141+
for mem in mem_list:
142+
ret_tips += mem + "\n"
143+
144+
tip_list.append(ret_tips)
145+
intrinsic_reward = 1 / (count + 1)
146+
else:
147+
tip_list.append("")
148+
intrinsic_reward = 1
149+
150+
try:
151+
if use_tips and any(t != "" for t in tip_list):
152+
llm_prompt = self._get_all_tip_prompt(prompt, tip_list)
153+
else:
154+
llm_prompt = prompt
155+
156+
instructed_prompt = self._get_instructed_prompt(llm_prompt)
157+
158+
# Main agent step
159+
with operation(step_count=step_count):
160+
result = await self.agent._model_client.create(instructed_prompt)
161+
output = result.content
162+
logger.info(f"[LLM output]: {output}")
163+
164+
except Exception as e:
165+
logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True)
166+
break
167+
168+
env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = (
169+
self.env.step(
170+
output,
171+
use_reasoning=self.config.captioner.type == "cot",
172+
use_success_rate=self.config.use_success_rate,
173+
)
174+
)
175+
176+
history_actions_for_mem.append(output)
177+
178+
action_for_history = output if self.config.get("record_original_action", False) else executed_action
179+
180+
prompt_builder.update_step_count()
181+
prompt_builder.update_action(action_for_history)
182+
prompt_builder.update_observation(env_obs)
183+
# prompt_builder.update_admissible_actions(available_actions_hint)
184+
185+
prompt = prompt_builder.get_prompt()
186+
187+
if rollout.mode == "train":
188+
step_reward = reward_scale * step_reward
189+
190+
emit_reward(
191+
{
192+
"extrinsic_reward": step_reward,
193+
"intrinsic_reward": intrinsic_reward,
194+
},
195+
primary_key="extrinsic_reward",
196+
attributes=make_link_attributes({"step_count": str(step_count)}),
197+
)
198+
199+
episode_reward += float(step_reward)
200+
done = np.logical_or(terminated, truncated)
201+
202+
step_count += 1
203+
204+
if rollout.mode == "train":
205+
prompt_builder.prompt_type = "chat"
206+
prompt_builder.max_history = -1
207+
full_prompt = prompt_builder.get_prompt()
208+
209+
# Add tips as raw text (no tags)
210+
if use_tips and len(tip_list) > 0:
211+
tip_base_prompt = copy.deepcopy(full_prompt)
212+
tips_iter = iter(tip_list)
213+
for item in tip_base_prompt:
214+
if "User" in item.type:
215+
tip = next(tips_iter, None)
216+
if tip is None:
217+
break
218+
if tip != "":
219+
item.content += tip
220+
else:
221+
tip_base_prompt = full_prompt
222+
223+
tip_generation_prompt = self._get_tip_generation_prompt(tip_base_prompt)
224+
225+
self.agent._model_client.max_tokens = 512
226+
result = await self.agent._model_client.create(tip_generation_prompt)
227+
tips = result.content
228+
229+
logger.info(f"Tips: {tips}")
230+
231+
#! Fill the ret and tip, then save memory
232+
#! Use final prompt state for ALL steps' keys
233+
final_prompt_text = gather_chats(prompt)
234+
final_key = (
235+
np.array(do_compress(final_prompt_text)["key"])
236+
.reshape(
237+
-1,
238+
)
239+
.tolist()
240+
)
241+
242+
for i in range(len(history_actions_for_mem)):
243+
max_score = 100 * reward_scale
244+
content = (
245+
tips
246+
+ f"; At that timestep, the specific action your took was {history_actions_for_mem[i]}; Eventually you got the score {round(episode_reward, 1)}/{int(max_score)}."
247+
)
248+
score = episode_reward
249+
add_memory(variation_idx, final_key, content, round(score, 1))
250+
251+
if self.config.use_success_rate:
252+
return self.env.get_success_score() * reward_scale
253+
else:
254+
return episode_reward
255+
256+
finally:
257+
if self.env is not None:
258+
self.env.close()

contrib/agentlightning/contrib/agent/env_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,10 @@ async def rollout_async(
127127
)
128128
)
129129

130+
action_for_history = output if self.config.get("record_original_action", False) else executed_action
131+
130132
prompt_builder.update_step_count()
131-
prompt_builder.update_action(executed_action)
133+
prompt_builder.update_action(action_for_history)
132134
prompt_builder.update_observation(env_obs)
133135
prompt_builder.update_admissible_actions(available_actions_hint)
134136

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from typing import Any, List
4+
5+
import torch
6+
7+
8+
def is_sublist(sub, full):
9+
n, m = len(sub), len(full)
10+
return any(full[i : i + n] == sub for i in range(m - n + 1))
11+
12+
13+
# Function to remove segments of a list between a start pattern and an end pattern
14+
def remove_pattern_ranges(seq: List[Any], start_pat: List[Any], end_pat: List[Any]) -> List[Any]:
15+
"""Remove every [start_pat ... end_pat] slice (inclusive) from seq."""
16+
17+
out: List[Any] = []
18+
i = 0
19+
n = len(seq)
20+
ls, le = len(start_pat), len(end_pat)
21+
22+
while i < n:
23+
# Check if the start pattern matches at the current position
24+
if i + ls <= n and seq[i : i + ls] == start_pat:
25+
# Look for the first occurrence of the end pattern after the start pattern
26+
j = i + ls
27+
found_end = -1
28+
while j + le <= n:
29+
if seq[j : j + le] == end_pat:
30+
found_end = j
31+
break # Stop when the end pattern is found
32+
j += 1
33+
34+
# If the end pattern is found, skip the whole segment from start to end
35+
if found_end != -1:
36+
i = found_end + le # Move the index past the end pattern
37+
continue # Skip the current iteration and go to the next
38+
else:
39+
# If the end pattern is not found, keep the current element and move one step forward
40+
out.append(seq[i])
41+
i += 1
42+
else:
43+
# If the start pattern is not found, just append the current element
44+
out.append(seq[i])
45+
i += 1
46+
47+
# Return the filtered list with the start-end pattern segments removed
48+
return out
49+
50+
51+
def low_prob_token_masking(batch, threshold: float = -5.0):
52+
response_mask = batch.batch["response_mask"] # [N, T]
53+
old_log_prob = batch.batch["old_log_probs"]
54+
55+
masked_old_log_prob = old_log_prob.masked_fill(response_mask == 0, 1e9)
56+
min_values, _ = torch.min(masked_old_log_prob, dim=1) # [N]
57+
58+
mask = min_values < threshold # [N]
59+
60+
combined_mask = mask.unsqueeze(1) & (response_mask == 1)
61+
62+
# advantages masking
63+
response_mask = response_mask.masked_fill(combined_mask, 0)
64+
batch.batch["response_mask"] = response_mask
65+
66+
print(f"Number of tokens masked: {combined_mask.sum().item()}")
67+
68+
return batch

0 commit comments

Comments
 (0)