|
| 1 | +"""ContextBuilder - GSSC流水线实现 |
| 2 | +
|
| 3 | +实现 Gather-Select-Structure-Compress 上下文构建流程: |
| 4 | +1. Gather: 从多源收集候选信息(历史、记忆、RAG、工具结果) |
| 5 | +2. Select: 基于优先级、相关性、多样性筛选 |
| 6 | +3. Structure: 组织成结构化上下文模板 |
| 7 | +4. Compress: 在预算内压缩与规范化 |
| 8 | +""" |
| 9 | + |
| 10 | +from typing import Dict, Any, List, Optional, Tuple |
| 11 | +from dataclasses import dataclass, field |
| 12 | +from datetime import datetime |
| 13 | +import tiktoken |
| 14 | +import math |
| 15 | + |
| 16 | +from ..core.message import Message |
| 17 | +from ..tools import MemoryTool, RAGTool |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class ContextPacket: |
| 22 | + """上下文信息包""" |
| 23 | + content: str |
| 24 | + timestamp: datetime = field(default_factory=datetime.now) |
| 25 | + metadata: Dict[str, Any] = field(default_factory=dict) |
| 26 | + token_count: int = 0 |
| 27 | + relevance_score: float = 0.0 # 0.0-1.0 |
| 28 | + |
| 29 | + def __post_init__(self): |
| 30 | + """自动计算token数""" |
| 31 | + if self.token_count == 0: |
| 32 | + self.token_count = count_tokens(self.content) |
| 33 | + |
| 34 | + |
| 35 | +@dataclass |
| 36 | +class ContextConfig: |
| 37 | + """上下文构建配置""" |
| 38 | + max_tokens: int = 8000 # 总预算 |
| 39 | + reserve_ratio: float = 0.15 # 生成余量(10-20%) |
| 40 | + min_relevance: float = 0.3 # 最小相关性阈值 |
| 41 | + enable_mmr: bool = True # 启用最大边际相关性(多样性) |
| 42 | + mmr_lambda: float = 0.7 # MMR平衡参数(0=纯多样性, 1=纯相关性) |
| 43 | + system_prompt_template: str = "" # 系统提示模板 |
| 44 | + enable_compression: bool = True # 启用压缩 |
| 45 | + |
| 46 | + def get_available_tokens(self) -> int: |
| 47 | + """获取可用token预算(扣除余量)""" |
| 48 | + return int(self.max_tokens * (1 - self.reserve_ratio)) |
| 49 | + |
| 50 | + |
| 51 | +class ContextBuilder: |
| 52 | + """上下文构建器 - GSSC流水线 |
| 53 | + |
| 54 | + 用法示例: |
| 55 | + ```python |
| 56 | + builder = ContextBuilder( |
| 57 | + memory_tool=memory_tool, |
| 58 | + rag_tool=rag_tool, |
| 59 | + config=ContextConfig(max_tokens=8000) |
| 60 | + ) |
| 61 | + |
| 62 | + context = builder.build( |
| 63 | + user_query="用户问题", |
| 64 | + conversation_history=[...], |
| 65 | + system_instructions="系统指令" |
| 66 | + ) |
| 67 | + ``` |
| 68 | + """ |
| 69 | + |
| 70 | + def __init__( |
| 71 | + self, |
| 72 | + memory_tool: Optional[MemoryTool] = None, |
| 73 | + rag_tool: Optional[RAGTool] = None, |
| 74 | + config: Optional[ContextConfig] = None |
| 75 | + ): |
| 76 | + self.memory_tool = memory_tool |
| 77 | + self.rag_tool = rag_tool |
| 78 | + self.config = config or ContextConfig() |
| 79 | + self._encoding = tiktoken.get_encoding("cl100k_base") |
| 80 | + |
| 81 | + def build( |
| 82 | + self, |
| 83 | + user_query: str, |
| 84 | + conversation_history: Optional[List[Message]] = None, |
| 85 | + system_instructions: Optional[str] = None, |
| 86 | + additional_packets: Optional[List[ContextPacket]] = None |
| 87 | + ) -> str: |
| 88 | + """构建完整上下文 |
| 89 | + |
| 90 | + Args: |
| 91 | + user_query: 用户查询 |
| 92 | + conversation_history: 对话历史 |
| 93 | + system_instructions: 系统指令 |
| 94 | + additional_packets: 额外的上下文包 |
| 95 | + |
| 96 | + Returns: |
| 97 | + 结构化上下文字符串 |
| 98 | + """ |
| 99 | + # 1. Gather: 收集候选信息 |
| 100 | + packets = self._gather( |
| 101 | + user_query=user_query, |
| 102 | + conversation_history=conversation_history or [], |
| 103 | + system_instructions=system_instructions, |
| 104 | + additional_packets=additional_packets or [] |
| 105 | + ) |
| 106 | + |
| 107 | + # 2. Select: 筛选与排序 |
| 108 | + selected_packets = self._select(packets, user_query) |
| 109 | + |
| 110 | + # 3. Structure: 组织成结构化模板 |
| 111 | + structured_context = self._structure( |
| 112 | + selected_packets=selected_packets, |
| 113 | + user_query=user_query, |
| 114 | + system_instructions=system_instructions |
| 115 | + ) |
| 116 | + |
| 117 | + # 4. Compress: 压缩与规范化(如果超预算) |
| 118 | + final_context = self._compress(structured_context) |
| 119 | + |
| 120 | + return final_context |
| 121 | + |
| 122 | + def _gather( |
| 123 | + self, |
| 124 | + user_query: str, |
| 125 | + conversation_history: List[Message], |
| 126 | + system_instructions: Optional[str], |
| 127 | + additional_packets: List[ContextPacket] |
| 128 | + ) -> List[ContextPacket]: |
| 129 | + """Gather: 收集候选信息""" |
| 130 | + packets = [] |
| 131 | + |
| 132 | + # P0: 系统指令(强约束) |
| 133 | + if system_instructions: |
| 134 | + packets.append(ContextPacket( |
| 135 | + content=system_instructions, |
| 136 | + metadata={"type": "instructions"} |
| 137 | + )) |
| 138 | + |
| 139 | + # P1: 从记忆中获取任务状态与关键结论 |
| 140 | + if self.memory_tool: |
| 141 | + try: |
| 142 | + # 搜索任务状态相关记忆 |
| 143 | + state_results = self.memory_tool.execute( |
| 144 | + "search", |
| 145 | + query="(任务状态 OR 子目标 OR 结论 OR 阻塞)", |
| 146 | + min_importance=0.7, |
| 147 | + limit=5 |
| 148 | + ) |
| 149 | + if state_results and "未找到" not in state_results: |
| 150 | + packets.append(ContextPacket( |
| 151 | + content=state_results, |
| 152 | + metadata={"type": "task_state", "importance": "high"} |
| 153 | + )) |
| 154 | + |
| 155 | + # 搜索与当前查询相关的记忆 |
| 156 | + related_results = self.memory_tool.execute( |
| 157 | + "search", |
| 158 | + query=user_query, |
| 159 | + limit=5 |
| 160 | + ) |
| 161 | + if related_results and "未找到" not in related_results: |
| 162 | + packets.append(ContextPacket( |
| 163 | + content=related_results, |
| 164 | + metadata={"type": "related_memory"} |
| 165 | + )) |
| 166 | + except Exception as e: |
| 167 | + print(f"⚠️ 记忆检索失败: {e}") |
| 168 | + |
| 169 | + # P2: 从RAG中获取事实证据 |
| 170 | + if self.rag_tool: |
| 171 | + try: |
| 172 | + rag_results = self.rag_tool.run({ |
| 173 | + "action": "search", |
| 174 | + "query": user_query, |
| 175 | + "top_k": 5 |
| 176 | + }) |
| 177 | + if rag_results and "未找到" not in rag_results and "错误" not in rag_results: |
| 178 | + packets.append(ContextPacket( |
| 179 | + content=rag_results, |
| 180 | + metadata={"type": "knowledge_base"} |
| 181 | + )) |
| 182 | + except Exception as e: |
| 183 | + print(f"⚠️ RAG检索失败: {e}") |
| 184 | + |
| 185 | + # P3: 对话历史(辅助材料) |
| 186 | + if conversation_history: |
| 187 | + # 只保留最近N条 |
| 188 | + recent_history = conversation_history[-10:] |
| 189 | + history_text = "\n".join([ |
| 190 | + f"[{msg.role}] {msg.content}" |
| 191 | + for msg in recent_history |
| 192 | + ]) |
| 193 | + packets.append(ContextPacket( |
| 194 | + content=history_text, |
| 195 | + metadata={"type": "history", "count": len(recent_history)} |
| 196 | + )) |
| 197 | + |
| 198 | + # 添加额外包 |
| 199 | + packets.extend(additional_packets) |
| 200 | + |
| 201 | + return packets |
| 202 | + |
| 203 | + def _select( |
| 204 | + self, |
| 205 | + packets: List[ContextPacket], |
| 206 | + user_query: str |
| 207 | + ) -> List[ContextPacket]: |
| 208 | + """Select: 基于分数与预算的筛选""" |
| 209 | + # 1) 计算相关性(关键词重叠) |
| 210 | + query_tokens = set(user_query.lower().split()) |
| 211 | + for packet in packets: |
| 212 | + content_tokens = set(packet.content.lower().split()) |
| 213 | + if len(query_tokens) > 0: |
| 214 | + overlap = len(query_tokens & content_tokens) |
| 215 | + packet.relevance_score = overlap / len(query_tokens) |
| 216 | + else: |
| 217 | + packet.relevance_score = 0.0 |
| 218 | + |
| 219 | + # 2) 计算新近性(指数衰减) |
| 220 | + def recency_score(ts: datetime) -> float: |
| 221 | + delta = max((datetime.now() - ts).total_seconds(), 0) |
| 222 | + tau = 3600 # 1小时时间尺度,可暴露到配置 |
| 223 | + return math.exp(-delta / tau) |
| 224 | + |
| 225 | + # 3) 计算复合分:0.7*相关性 + 0.3*新近性 |
| 226 | + scored_packets: List[Tuple[float, ContextPacket]] = [] |
| 227 | + for p in packets: |
| 228 | + rec = recency_score(p.timestamp) |
| 229 | + score = 0.7 * p.relevance_score + 0.3 * rec |
| 230 | + scored_packets.append((score, p)) |
| 231 | + |
| 232 | + # 4) 系统指令单独拿出,固定纳入 |
| 233 | + system_packets = [p for (_, p) in scored_packets if p.metadata.get("type") == "instructions"] |
| 234 | + remaining = [p for (s, p) in sorted(scored_packets, key=lambda x: x[0], reverse=True) |
| 235 | + if p.metadata.get("type") != "instructions"] |
| 236 | + |
| 237 | + # 5) 依据 min_relevance 过滤(对非系统包) |
| 238 | + filtered = [p for p in remaining if p.relevance_score >= self.config.min_relevance] |
| 239 | + |
| 240 | + # 6) 按预算填充 |
| 241 | + available_tokens = self.config.get_available_tokens() |
| 242 | + selected: List[ContextPacket] = [] |
| 243 | + used_tokens = 0 |
| 244 | + |
| 245 | + # 先放入系统指令(不排序) |
| 246 | + for p in system_packets: |
| 247 | + if used_tokens + p.token_count <= available_tokens: |
| 248 | + selected.append(p) |
| 249 | + used_tokens += p.token_count |
| 250 | + |
| 251 | + # 再按分数加入其余 |
| 252 | + for p in filtered: |
| 253 | + if used_tokens + p.token_count > available_tokens: |
| 254 | + continue |
| 255 | + selected.append(p) |
| 256 | + used_tokens += p.token_count |
| 257 | + |
| 258 | + return selected |
| 259 | + |
| 260 | + def _structure( |
| 261 | + self, |
| 262 | + selected_packets: List[ContextPacket], |
| 263 | + user_query: str, |
| 264 | + system_instructions: Optional[str] |
| 265 | + ) -> str: |
| 266 | + """Structure: 组织成结构化上下文模板""" |
| 267 | + sections = [] |
| 268 | + |
| 269 | + # [Role & Policies] - 系统指令 |
| 270 | + p0_packets = [p for p in selected_packets if p.metadata.get("type") == "instructions"] |
| 271 | + if p0_packets: |
| 272 | + role_section = "[Role & Policies]\n" |
| 273 | + role_section += "\n".join([p.content for p in p0_packets]) |
| 274 | + sections.append(role_section) |
| 275 | + |
| 276 | + # [Task] - 当前任务 |
| 277 | + sections.append(f"[Task]\n用户问题:{user_query}") |
| 278 | + |
| 279 | + # [State] - 任务状态 |
| 280 | + p1_packets = [p for p in selected_packets if p.metadata.get("type") == "task_state"] |
| 281 | + if p1_packets: |
| 282 | + state_section = "[State]\n关键进展与未决问题:\n" |
| 283 | + state_section += "\n".join([p.content for p in p1_packets]) |
| 284 | + sections.append(state_section) |
| 285 | + |
| 286 | + # [Evidence] - 事实证据 |
| 287 | + p2_packets = [ |
| 288 | + p for p in selected_packets |
| 289 | + if p.metadata.get("type") in {"related_memory", "knowledge_base", "retrieval", "tool_result"} |
| 290 | + ] |
| 291 | + if p2_packets: |
| 292 | + evidence_section = "[Evidence]\n事实与引用:\n" |
| 293 | + for p in p2_packets: |
| 294 | + evidence_section += f"\n{p.content}\n" |
| 295 | + sections.append(evidence_section) |
| 296 | + |
| 297 | + # [Context] - 辅助材料(历史等) |
| 298 | + p3_packets = [p for p in selected_packets if p.metadata.get("type") == "history"] |
| 299 | + if p3_packets: |
| 300 | + context_section = "[Context]\n对话历史与背景:\n" |
| 301 | + context_section += "\n".join([p.content for p in p3_packets]) |
| 302 | + sections.append(context_section) |
| 303 | + |
| 304 | + # [Output] - 输出约束 |
| 305 | + output_section = """[Output] |
| 306 | + 请按以下格式回答: |
| 307 | + 1. 结论(简洁明确) |
| 308 | + 2. 依据(列出支撑证据及来源) |
| 309 | + 3. 风险与假设(如有) |
| 310 | + 4. 下一步行动建议(如适用)""" |
| 311 | + sections.append(output_section) |
| 312 | + |
| 313 | + return "\n\n".join(sections) |
| 314 | + |
| 315 | + def _compress(self, context: str) -> str: |
| 316 | + """Compress: 压缩与规范化""" |
| 317 | + if not self.config.enable_compression: |
| 318 | + return context |
| 319 | + |
| 320 | + current_tokens = count_tokens(context) |
| 321 | + available_tokens = self.config.get_available_tokens() |
| 322 | + |
| 323 | + if current_tokens <= available_tokens: |
| 324 | + return context |
| 325 | + |
| 326 | + # 简单截断策略(保留前N个token) |
| 327 | + # 实际应用中可用LLM做高保真摘要 |
| 328 | + print(f"⚠️ 上下文超预算 ({current_tokens} > {available_tokens}),执行截断") |
| 329 | + |
| 330 | + # 按段落截断,保留结构 |
| 331 | + lines = context.split("\n") |
| 332 | + compressed_lines = [] |
| 333 | + used_tokens = 0 |
| 334 | + |
| 335 | + for line in lines: |
| 336 | + line_tokens = count_tokens(line) |
| 337 | + if used_tokens + line_tokens > available_tokens: |
| 338 | + break |
| 339 | + compressed_lines.append(line) |
| 340 | + used_tokens += line_tokens |
| 341 | + |
| 342 | + return "\n".join(compressed_lines) |
| 343 | + |
| 344 | + |
| 345 | +def count_tokens(text: str) -> int: |
| 346 | + """计算文本token数(使用tiktoken)""" |
| 347 | + try: |
| 348 | + encoding = tiktoken.get_encoding("cl100k_base") |
| 349 | + return len(encoding.encode(text)) |
| 350 | + except Exception: |
| 351 | + # 降级方案:粗略估算(1 token ≈ 4 字符) |
| 352 | + return len(text) // 4 |
| 353 | + |
0 commit comments