|
| 1 | +# 使用 AgentScope-Tuner 训练邮件搜索智能体 |
| 2 | + |
| 3 | +本示例展示如何使用 AgentScope-Tuner 对邮件搜索任务(灵感来自 [ART](https://openpipe.ai/blog/art-e-mail-agent))进行强化微调,其 RFT 功能由 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 提供支持。 |
| 4 | + |
| 5 | +## 任务设定 |
| 6 | + |
| 7 | +智能体的目标是通过搜索邮件收件箱来回答用户查询。智能体需要: |
| 8 | +- 理解用户的问题 |
| 9 | +- 使用关键词搜索相关邮件 |
| 10 | +- 阅读邮件内容以提取信息 |
| 11 | +- 提供准确的答案并附上适当的来源引用 |
| 12 | + |
| 13 | +**智能体类型**:智能体(`EmailSearchAgent`)继承自 `ReActAgent`,遵循推理-行动循环来迭代解决任务。 |
| 14 | + |
| 15 | +**环境**:环境是一个包含来自 Enron 邮件数据集的 SQLite 数据库。每个任务提供: |
| 16 | +- `question`:用户的邮件搜索查询 |
| 17 | +- `inbox_address`:要搜索的邮件收件箱 |
| 18 | +- `query_date`:查询的日期上下文 |
| 19 | +- `answer`:预期答案(真实值),仅用于奖励计算 |
| 20 | +- `message_ids`:包含答案的相关邮件 ID,仅用于奖励计算 |
| 21 | + |
| 22 | +**可用工具**: |
| 23 | +- `search_emails`:通过关键词、收件箱地址和日期范围查找邮件。返回邮件摘要列表(message_id 和片段)。 |
| 24 | +- `read_email`:通过 message_id 读取特定邮件的完整内容。 |
| 25 | +- `generate_response`:提供带有来源的最终结构化答案(继承自 ReAct 智能体)。 |
| 26 | + |
| 27 | +## 数据集准备 |
| 28 | + |
| 29 | +数据集包含基于 [Enron 邮件数据集](https://huggingface.co/datasets/corbt/enron-emails) 的邮件查询。运行数据准备脚本以生成邮件数据库和数据集: |
| 30 | + |
| 31 | +```bash |
| 32 | +python prepare_data.py |
| 33 | +``` |
| 34 | + |
| 35 | +如果你想选择新的数据库路径,可以修改 [`prepare_data.py`](./prepare_data.py) 中的 `DEFAULT_DB_PATH`。同时,请记住在进入下一步之前设置环境变量 `DEFAULT_EMAIL_DB_PATH` 指向数据库路径: |
| 36 | + |
| 37 | +```bash |
| 38 | +export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db |
| 39 | +``` |
| 40 | + |
| 41 | +这将创建一个 SQLite 数据库和数据集: |
| 42 | + |
| 43 | +``` |
| 44 | +/path/to/enron_emails_dataset/ |
| 45 | + ├── data |
| 46 | + └── enron_emails.db # 邮件数据库 |
| 47 | + ├── train.parquet # 训练样本 |
| 48 | + └── test.parquet # 测试样本 |
| 49 | +``` |
| 50 | + |
| 51 | +每个样本如下所示: |
| 52 | + |
| 53 | +```json |
| 54 | +{ |
| 55 | + "id": 0, |
| 56 | + "question": "Were there any variances detected for hour 6 on 3/9/01?", |
| 57 | + "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.", |
| 58 | + "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"], |
| 59 | + "how_realistic": 0.800000011920929, |
| 60 | + "inbox_address": "pete.davis@enron.com", |
| 61 | + "query_date": "2001-03-16" |
| 62 | +} |
| 63 | +``` |
| 64 | + |
| 65 | +## 代码实现 |
| 66 | + |
| 67 | +本节提供代码实现的高级概览。详细实现请参考源代码。 |
| 68 | + |
| 69 | +### 智能体工作流 |
| 70 | + |
| 71 | +工作流函数 `run_email_search_agent` 实现智能体-环境交互循环: |
| 72 | + |
| 73 | +```python |
| 74 | +async def run_email_search_agent( |
| 75 | + task: Dict, |
| 76 | + model: ChatModelBase, |
| 77 | + auxiliary_models: Dict[str, ChatModelBase], |
| 78 | +) -> WorkflowOutput: |
| 79 | + # 解析任务并创建智能体 |
| 80 | + agent = EmailSearchAgent( |
| 81 | + name="email_search_agent", |
| 82 | + sys_prompt=system_prompt, |
| 83 | + model=model, |
| 84 | + max_iters=max_turns, |
| 85 | + ) |
| 86 | + |
| 87 | + # 使用结构化输出运行智能体 |
| 88 | + response = await agent.reply( |
| 89 | + msg=Msg("user", question, role="user"), |
| 90 | + structured_model=AnswerModel, |
| 91 | + ) |
| 92 | + |
| 93 | + return WorkflowOutput(response=response) |
| 94 | +``` |
| 95 | + |
| 96 | +智能体遵循 ReAct 模式:它推理任务,调用工具搜索和阅读邮件,最后生成包含答案和来源消息 ID 的结构化响应。 |
| 97 | + |
| 98 | +### 评判函数 |
| 99 | + |
| 100 | +评判函数 `email_search_judge` 使用 LLM-as-a-Judge 实现奖励计算: |
| 101 | + |
| 102 | +```python |
| 103 | +async def email_search_judge( |
| 104 | + task: Dict, |
| 105 | + response: Msg, |
| 106 | + auxiliary_models: Dict[str, ChatModelBase], |
| 107 | +) -> JudgeOutput: |
| 108 | + # 从响应中提取答案和来源 |
| 109 | + answer = answer_and_sources.get("answer") |
| 110 | + sources = answer_and_sources.get("sources", []) |
| 111 | + |
| 112 | + # 使用 LLM-as-a-Judge 评判正确性 |
| 113 | + judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0] |
| 114 | + judge_response = await judge_correctness( |
| 115 | + answer, query, judge_model |
| 116 | + ) |
| 117 | + |
| 118 | + # 基于以下因素计算奖励: |
| 119 | + # - 答案正确性(准确度:-1.0 到 1.0) |
| 120 | + # - 来源正确性(格式:部分奖励) |
| 121 | + # - 效率(对更少轮次、正确来源的奖励) |
| 122 | + result = {"accuracy": ..., "format": ...} # 基于 judge_response 计算 |
| 123 | + |
| 124 | + return JudgeOutput( |
| 125 | + reward=sum(result.values()), |
| 126 | + metrics=metrics, |
| 127 | + ) |
| 128 | +``` |
| 129 | + |
| 130 | +奖励函数考虑以下因素: |
| 131 | +- **答案正确性**:通过 LLM-as-a-Judge 比较智能体的答案与真实值进行评估 |
| 132 | +- **来源正确性**:智能体是否引用了正确的邮件消息 ID |
| 133 | +- **效率**:对找到/阅读正确邮件和更少轮次的奖励 |
| 134 | + |
| 135 | +详细实现请参考 [`main.py`](./main.py) 和 [`email_search_agent.py`](./email_search_agent.py)。 |
| 136 | + |
| 137 | +## 运行方法 |
| 138 | + |
| 139 | +### 前置要求 |
| 140 | + |
| 141 | +- 至少 4 张 NVIDIA GPU,CUDA 版本 ≥ 12.8 |
| 142 | + * 注意:对于 30B 评判模型,需要使用至少 4080 显存的 GPU;你也可以通过使用 `tensor_parallel_size > 1` 在多张 GPU 上运行模型以减少显存使用(默认情况下,`tensor_parallel_size=2`)。 |
| 143 | +- 按照 Trinity-RFT [安装指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) 从源码安装最新版本 |
| 144 | +- 下载模型检查点(示例): |
| 145 | + |
| 146 | + ```bash |
| 147 | + huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 |
| 148 | + huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # 评判模型 |
| 149 | + ``` |
| 150 | + |
| 151 | +### 配置 |
| 152 | + |
| 153 | +根据你的硬件调整配置文件([`config.yaml`](./config.yaml))。关键配置部分包括: |
| 154 | + |
| 155 | +- **TunerModelConfig**:将 `model_path` 设置为你的模型检查点路径 |
| 156 | +- **AlgorithmConfig**:配置 RL 算法参数(例如,`multi_step_grpo`、学习率、策略损失函数) |
| 157 | +- **DatasetConfig**:数据集路径在创建 `DatasetConfig` 对象时在 `main.py` 中指定 |
| 158 | +- **辅助模型**:为 LLM-as-a-Judge 配置评判模型设置 |
| 159 | + |
| 160 | +完整配置详情请参考 [Trinity-RFT 配置指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)。 |
| 161 | + |
| 162 | +### 启动命令 |
| 163 | + |
| 164 | +1. 准备数据集: |
| 165 | + |
| 166 | + ```bash |
| 167 | + python prepare_data.py |
| 168 | + export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db |
| 169 | + ``` |
| 170 | + |
| 171 | +2. 启动 [Ray](https://github.com/ray-project/ray): |
| 172 | + |
| 173 | + ```bash |
| 174 | + ray start --head |
| 175 | + ``` |
| 176 | + |
| 177 | +3. 运行训练脚本: |
| 178 | + |
| 179 | + ```bash |
| 180 | + python main.py |
| 181 | + ``` |
| 182 | + |
| 183 | +## 实验结果 |
| 184 | + |
| 185 | +### 定量结果 |
| 186 | + |
| 187 | +训练结果显示智能体性能随训练迭代次数的提升。关键指标包括: |
| 188 | + |
| 189 | +- **训练奖励**:训练样本上的平均奖励随着智能体学习更好的策略而增加 |
| 190 | +- **Rollout 准确度**:Rollout 样本上的平均准确度随着智能体学习更好的策略而增加 |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | + |
| 195 | + |
| 196 | + |
| 197 | +### 具体示例 |
| 198 | + |
| 199 | +智能体行为示例如下: |
| 200 | + |
| 201 | +**查询:** "What do the color codes mean in the curve assessment?" |
| 202 | + |
| 203 | +我们展示智能体响应的最后几轮: |
| 204 | + |
| 205 | +智能体执行多次搜索尝试以找到相关邮件。经过一些不成功的搜索后,智能体尝试: |
| 206 | + |
| 207 | +**工具调用:** |
| 208 | +```json |
| 209 | +{ |
| 210 | + "type": "tool_use", |
| 211 | + "name": "search_emails", |
| 212 | + "input": { |
| 213 | + "inbox_address": "steven.kean@enron.com", |
| 214 | + "query_date": "2001-04-03", |
| 215 | + "keywords": ["curve", "assessment"] |
| 216 | + } |
| 217 | +} |
| 218 | +``` |
| 219 | + |
| 220 | +**工具结果:** |
| 221 | +```json |
| 222 | +{ |
| 223 | + "type": "tool_result", |
| 224 | + "name": "search_emails", |
| 225 | + "output": [ |
| 226 | + { |
| 227 | + "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>", |
| 228 | + "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..." |
| 229 | + }, |
| 230 | + { |
| 231 | + "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>", |
| 232 | + "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..." |
| 233 | + } |
| 234 | + ] |
| 235 | +} |
| 236 | +``` |
| 237 | + |
| 238 | +找到相关邮件后,智能体使用 `read_email` 工具读取两封邮件的完整内容: |
| 239 | + |
| 240 | +**工具调用:** |
| 241 | +```json |
| 242 | +{ |
| 243 | + "type": "tool_use", |
| 244 | + "name": "read_email", |
| 245 | + "input": { |
| 246 | + "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>" |
| 247 | + } |
| 248 | +} |
| 249 | +``` |
| 250 | + |
| 251 | +**工具调用:** |
| 252 | +```json |
| 253 | +{ |
| 254 | + "type": "tool_use", |
| 255 | + "name": "read_email", |
| 256 | + "input": { |
| 257 | + "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>" |
| 258 | + } |
| 259 | +} |
| 260 | +``` |
| 261 | + |
| 262 | +阅读邮件后,智能体生成最终的结构化响应: |
| 263 | + |
| 264 | +**工具调用:** |
| 265 | +```json |
| 266 | +{ |
| 267 | + "type": "tool_use", |
| 268 | + "name": "generate_response", |
| 269 | + "input": { |
| 270 | + "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.", |
| 271 | + "sources": [ |
| 272 | + "<2654330.1075846153519.JavaMail.evans@thyme>", |
| 273 | + "<12499967.1075846153495.JavaMail.evans@thyme>" |
| 274 | + ] |
| 275 | + } |
| 276 | +} |
| 277 | +``` |
| 278 | + |
| 279 | +评判器评估上面的答案为正确。 |
0 commit comments