|
| 1 | +# Copyright (c) Microsoft. All rights reserved. |
| 2 | + |
| 3 | +import copy |
| 4 | +import logging |
| 5 | +from typing import Any, Dict |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import requests |
| 9 | +from add_instruction import add_chat_all_tips, add_chat_instruction |
| 10 | +from agl_envs import make_env_manager |
| 11 | + |
| 12 | +from agentlightning import LLM, NamedResources, Rollout, configure_logger, emit_reward, operation |
| 13 | +from agentlightning.utils.otel import make_link_attributes |
| 14 | +from contrib.agentlightning.contrib.agent.env_agent import EnvAgent |
| 15 | +from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder |
| 16 | + |
| 17 | +configure_logger() |
| 18 | +logger = configure_logger(name=__name__, level=logging.ERROR) |
| 19 | + |
| 20 | + |
| 21 | +def do_compress(text): |
| 22 | + url = "http://127.0.0.1:8000/key_cal/" |
| 23 | + headers = {"Content-Type": "application/json"} # 明确指定 JSON 格式 |
| 24 | + data = {"text": text} |
| 25 | + response = requests.post(url, json=data, headers=headers) # 使用 json 参数 |
| 26 | + return response.json() |
| 27 | + |
| 28 | + |
| 29 | +url_mem = "http://127.0.0.1:8001/mem/" |
| 30 | + |
| 31 | + |
| 32 | +def retrieve_memory(idx, key): |
| 33 | + response = requests.post(url_mem, json={"key": key, "idx": idx}) |
| 34 | + count, data = response.json() |
| 35 | + return count, data |
| 36 | + |
| 37 | + |
| 38 | +def reset_memory(mem_list_num): |
| 39 | + requests.post(url_mem, json={"key": [], "idx": mem_list_num, "content": "Reset"}) # 用于初始化多个 memory slot |
| 40 | + |
| 41 | + |
| 42 | +def add_memory(idx, key, content, score): |
| 43 | + requests.post(url_mem, json={"key": key, "idx": idx, "content": content, "score": score}) |
| 44 | + |
| 45 | + |
| 46 | +def gather_chats(prompt): |
| 47 | + chat_list = [] |
| 48 | + for item in prompt: |
| 49 | + role = item.type |
| 50 | + content = item.content |
| 51 | + if "System" in role: |
| 52 | + continue |
| 53 | + elif "User" in role: |
| 54 | + role = "user" |
| 55 | + else: |
| 56 | + role = "assistant" |
| 57 | + chat_list.append(f"{role}: {content}") |
| 58 | + text = " ".join(chat_list) |
| 59 | + return text |
| 60 | + |
| 61 | + |
| 62 | +class EMPO2Agent(EnvAgent): |
| 63 | + def _get_all_tip_prompt(self, prompt, tip_list): |
| 64 | + prompt_type = self.config.captioner.prompt_type |
| 65 | + if prompt_type == "chat": |
| 66 | + return add_chat_all_tips(prompt, tip_list) |
| 67 | + else: |
| 68 | + raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')") |
| 69 | + |
| 70 | + def _get_tip_generation_prompt(self, prompt): |
| 71 | + return add_chat_instruction(prompt, "tip") |
| 72 | + |
| 73 | + async def rollout_async( |
| 74 | + self, |
| 75 | + task: Dict[str, Any], |
| 76 | + resources: NamedResources, |
| 77 | + rollout: Rollout, |
| 78 | + ) -> float | None: |
| 79 | + rollout_id = rollout.rollout_id |
| 80 | + logger.info(f"[Rollout {rollout_id}] Task: {task}") |
| 81 | + |
| 82 | + reward_scale = float(self.config["reawrd_scale"]) |
| 83 | + |
| 84 | + # Setup LLM + agent |
| 85 | + llm: LLM = resources.get("main_llm") |
| 86 | + print("Training with model:", llm.model, "on endpoint:", llm.endpoint) |
| 87 | + self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4) |
| 88 | + |
| 89 | + if rollout.mode == "train": |
| 90 | + train_mode = task["train_mode"] |
| 91 | + global_steps = task["global_steps"] |
| 92 | + else: |
| 93 | + train_mode = "on-policy" |
| 94 | + |
| 95 | + if rollout.mode == "train" and (train_mode == "off-policy" or train_mode == "on-policy-with-tips"): |
| 96 | + use_tips = True |
| 97 | + else: |
| 98 | + use_tips = False |
| 99 | + |
| 100 | + variation_idx = task["variation_idx"] |
| 101 | + |
| 102 | + try: |
| 103 | + # Setup environment |
| 104 | + prompt_builder = HistoryPromptBuilder( |
| 105 | + max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type |
| 106 | + ) |
| 107 | + |
| 108 | + self.env = make_env_manager(self.config.env_name, task, self.config) |
| 109 | + env_obs, infos, available_actions_hint = self.env.reset() |
| 110 | + |
| 111 | + prompt_builder.init(self.env) |
| 112 | + prompt_builder.update_observation(env_obs) |
| 113 | + # prompt_builder.update_admissible_actions(available_actions_hint) |
| 114 | + |
| 115 | + prompt = prompt_builder.get_prompt() |
| 116 | + |
| 117 | + episode_reward, done = 0.0, False |
| 118 | + |
| 119 | + history_actions_for_mem = [] |
| 120 | + tip_list = [] |
| 121 | + step_count = 0 |
| 122 | + while not done: |
| 123 | + if use_tips: |
| 124 | + text = gather_chats(prompt) |
| 125 | + key = ( |
| 126 | + np.array(do_compress(text)["key"]) |
| 127 | + .reshape( |
| 128 | + -1, |
| 129 | + ) |
| 130 | + .tolist() |
| 131 | + ) |
| 132 | + count, mem_list = retrieve_memory(variation_idx, key) |
| 133 | + else: |
| 134 | + count, mem_list = 0, [] |
| 135 | + |
| 136 | + ret_tips, intrinsic_reward = "", 0.0 |
| 137 | + |
| 138 | + if use_tips: |
| 139 | + if count > 0: |
| 140 | + ret_tips = "Here are some memories you collected in your previous exploration:\n" |
| 141 | + for mem in mem_list: |
| 142 | + ret_tips += mem + "\n" |
| 143 | + |
| 144 | + tip_list.append(ret_tips) |
| 145 | + intrinsic_reward = 1 / (count + 1) |
| 146 | + else: |
| 147 | + tip_list.append("") |
| 148 | + intrinsic_reward = 1 |
| 149 | + |
| 150 | + try: |
| 151 | + if use_tips and any(t != "" for t in tip_list): |
| 152 | + llm_prompt = self._get_all_tip_prompt(prompt, tip_list) |
| 153 | + else: |
| 154 | + llm_prompt = prompt |
| 155 | + |
| 156 | + instructed_prompt = self._get_instructed_prompt(llm_prompt) |
| 157 | + |
| 158 | + # Main agent step |
| 159 | + with operation(step_count=step_count): |
| 160 | + result = await self.agent._model_client.create(instructed_prompt) |
| 161 | + output = result.content |
| 162 | + logger.info(f"[LLM output]: {output}") |
| 163 | + |
| 164 | + except Exception as e: |
| 165 | + logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True) |
| 166 | + break |
| 167 | + |
| 168 | + env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = ( |
| 169 | + self.env.step( |
| 170 | + output, |
| 171 | + use_reasoning=self.config.captioner.type == "cot", |
| 172 | + use_success_rate=self.config.use_success_rate, |
| 173 | + ) |
| 174 | + ) |
| 175 | + |
| 176 | + history_actions_for_mem.append(output) |
| 177 | + |
| 178 | + action_for_history = output if self.config.get("record_original_action", False) else executed_action |
| 179 | + |
| 180 | + prompt_builder.update_step_count() |
| 181 | + prompt_builder.update_action(action_for_history) |
| 182 | + prompt_builder.update_observation(env_obs) |
| 183 | + # prompt_builder.update_admissible_actions(available_actions_hint) |
| 184 | + |
| 185 | + prompt = prompt_builder.get_prompt() |
| 186 | + |
| 187 | + if rollout.mode == "train": |
| 188 | + step_reward = reward_scale * step_reward |
| 189 | + |
| 190 | + emit_reward( |
| 191 | + { |
| 192 | + "extrinsic_reward": step_reward, |
| 193 | + "intrinsic_reward": intrinsic_reward, |
| 194 | + }, |
| 195 | + primary_key="extrinsic_reward", |
| 196 | + attributes=make_link_attributes({"step_count": str(step_count)}), |
| 197 | + ) |
| 198 | + |
| 199 | + episode_reward += float(step_reward) |
| 200 | + done = np.logical_or(terminated, truncated) |
| 201 | + |
| 202 | + step_count += 1 |
| 203 | + |
| 204 | + if rollout.mode == "train": |
| 205 | + prompt_builder.prompt_type = "chat" |
| 206 | + prompt_builder.max_history = -1 |
| 207 | + full_prompt = prompt_builder.get_prompt() |
| 208 | + |
| 209 | + # Add tips as raw text (no tags) |
| 210 | + if use_tips and len(tip_list) > 0: |
| 211 | + tip_base_prompt = copy.deepcopy(full_prompt) |
| 212 | + tips_iter = iter(tip_list) |
| 213 | + for item in tip_base_prompt: |
| 214 | + if "User" in item.type: |
| 215 | + tip = next(tips_iter, None) |
| 216 | + if tip is None: |
| 217 | + break |
| 218 | + if tip != "": |
| 219 | + item.content += tip |
| 220 | + else: |
| 221 | + tip_base_prompt = full_prompt |
| 222 | + |
| 223 | + tip_generation_prompt = self._get_tip_generation_prompt(tip_base_prompt) |
| 224 | + |
| 225 | + self.agent._model_client.max_tokens = 512 |
| 226 | + result = await self.agent._model_client.create(tip_generation_prompt) |
| 227 | + tips = result.content |
| 228 | + |
| 229 | + logger.info(f"Tips: {tips}") |
| 230 | + |
| 231 | + #! Fill the ret and tip, then save memory |
| 232 | + #! Use final prompt state for ALL steps' keys |
| 233 | + final_prompt_text = gather_chats(prompt) |
| 234 | + final_key = ( |
| 235 | + np.array(do_compress(final_prompt_text)["key"]) |
| 236 | + .reshape( |
| 237 | + -1, |
| 238 | + ) |
| 239 | + .tolist() |
| 240 | + ) |
| 241 | + |
| 242 | + for i in range(len(history_actions_for_mem)): |
| 243 | + max_score = 100 * reward_scale |
| 244 | + content = ( |
| 245 | + tips |
| 246 | + + f"; At that timestep, the specific action your took was {history_actions_for_mem[i]}; Eventually you got the score {round(episode_reward, 1)}/{int(max_score)}." |
| 247 | + ) |
| 248 | + score = episode_reward |
| 249 | + add_memory(variation_idx, final_key, content, round(score, 1)) |
| 250 | + |
| 251 | + if self.config.use_success_rate: |
| 252 | + return self.env.get_success_score() * reward_scale |
| 253 | + else: |
| 254 | + return episode_reward |
| 255 | + |
| 256 | + finally: |
| 257 | + if self.env is not None: |
| 258 | + self.env.close() |
0 commit comments