forked from HKUDS/Vibe-Trading
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathservice.py
More file actions
390 lines (327 loc) · 15 KB
/
service.py
File metadata and controls
390 lines (327 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
"""Session lifecycle orchestration for message flow, attempt creation, and execution scheduling.
V5: Uses AgentLoop instead of the fixed pipeline behind the generate skill.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
# Bounded worker pool for the blocking agent work submitted via
# run_in_executor in SessionService._run_with_agent (registry construction
# and AgentLoop.run). At most four such jobs run at once; the rest queue here
# instead of exhausting asyncio's default executor.
_AGENT_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    thread_name_prefix="agent",
    max_workers=4,
)
from src.session.events import EventBus
from src.session.models import (
Attempt,
AttemptStatus,
Message,
Session,
)
from src.session.search import get_shared_index
from src.session.store import SessionStore
class SessionService:
    """Session lifecycle service.

    Persists messages, creates attempts for user messages, and runs each
    attempt in the background through the V5 ``AgentLoop``, publishing
    progress on the event bus.

    Attributes:
        store: Session persistence store.
        event_bus: SSE event bus.
        runs_dir: Root runs directory.
    """

    def __init__(
        self,
        store: SessionStore,
        event_bus: EventBus,
        runs_dir: Path,
    ) -> None:
        """Initialize the session service.

        Args:
            store: Session persistence store.
            event_bus: SSE event bus.
            runs_dir: Root runs directory.
        """
        self.store = store
        self.event_bus = event_bus
        self.runs_dir = runs_dir
        # session_id -> AgentLoop currently executing an attempt for that session.
        self._active_loops: Dict[str, "AgentLoop"] = {}
        # Strong references to fire-and-forget attempt tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task can be garbage
        # collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()
        self._search_index = get_shared_index()

    def _spawn_background(self, coro: Any) -> asyncio.Task:
        """Schedule ``coro`` as a task and hold a reference to it until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def create_session(self, title: str = "", config: Optional[Dict[str, Any]] = None) -> Session:
        """Create a new session.

        Persists the session, indexes its title for search, and emits
        ``session.created``.

        Args:
            title: Session title.
            config: Session configuration (defaults to an empty dict).

        Returns:
            The newly created Session.
        """
        session = Session(title=title, config=config or {})
        self.store.create_session(session)
        self._search_index.index_session(session.session_id, title)
        self.event_bus.emit(session.session_id, "session.created", {"session_id": session.session_id, "title": title})
        return session

    def get_session(self, session_id: str) -> Optional[Session]:
        """Return a session by ID, or None if the store has no such session."""
        return self.store.get_session(session_id)

    def list_sessions(self, limit: int = 50) -> list[Session]:
        """List up to ``limit`` sessions from the store."""
        return self.store.list_sessions(limit)

    def delete_session(self, session_id: str) -> bool:
        """Delete a session.

        Any AgentLoop still running for the session is sent a cancel signal
        first, so it does not keep working on a session that no longer exists.

        Returns:
            The store's result for the deletion.
        """
        self.cancel_current(session_id)
        self.event_bus.clear(session_id)
        return self.store.delete_session(session_id)

    async def send_message(
        self,
        session_id: str,
        content: str,
        role: str = "user",
        *,
        include_shell_tools: bool = False,
    ) -> Dict[str, Any]:
        """Send a message to a session and trigger execution for user messages.

        Args:
            session_id: Session ID.
            content: Message content.
            role: Message role. Only ``"user"`` messages create an attempt.
            include_shell_tools: Whether this attempt may use shell tools. The
                flag is also stored in the session config so that resumed
                attempts reuse it.

        Returns:
            Dictionary containing ``message_id`` and, for user messages,
            ``attempt_id``.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.store.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        message = Message(session_id=session_id, role=role, content=content)
        self.store.append_message(message)
        self._search_index.index_message(session_id, role, content)
        self.event_bus.emit(session_id, "message.received", {"message_id": message.message_id, "role": role, "content": content})

        if role != "user":
            return {"message_id": message.message_id}

        attempt = Attempt(session_id=session_id, parent_attempt_id=session.last_attempt_id, prompt=content)
        self.store.create_attempt(attempt)

        session.config["include_shell_tools"] = include_shell_tools
        session.last_attempt_id = attempt.attempt_id
        session.updated_at = datetime.now().isoformat()
        self.store.update_session(session)

        self.event_bus.emit(session_id, "attempt.created", {"attempt_id": attempt.attempt_id, "prompt": content})
        self._spawn_background(self._run_attempt(session, attempt, include_shell_tools=include_shell_tools))

        return {"message_id": message.message_id, "attempt_id": attempt.attempt_id}

    async def resume_attempt(self, session_id: str, attempt_id: str, user_input: str) -> Dict[str, Any]:
        """Resume an attempt that is waiting for user input.

        The reply is stored as a user message linked to the attempt, appended
        to the attempt prompt, and the attempt is re-run in the background.

        Args:
            session_id: Session ID.
            attempt_id: Attempt ID.
            user_input: User reply content.

        Returns:
            Dictionary containing status and attempt_id.

        Raises:
            ValueError: If the session or attempt does not exist, or the
                attempt is not in ``WAITING_USER`` status.
        """
        session = self.store.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        attempt = self.store.get_attempt(session_id, attempt_id)
        if not attempt:
            raise ValueError(f"Attempt {attempt_id} not found")

        if attempt.status != AttemptStatus.WAITING_USER:
            raise ValueError(f"Attempt {attempt_id} is not waiting for user input")

        message = Message(session_id=session_id, role="user", content=user_input, linked_attempt_id=attempt_id)
        self.store.append_message(message)
        # Keep the search index in step with the message store, as send_message does.
        self._search_index.index_message(session_id, "user", user_input)

        # Append the user's reply to the prompt and rerun the attempt.
        attempt.prompt = f"{attempt.prompt}\n\nUser reply: {user_input}"
        attempt.status = AttemptStatus.RUNNING
        self.store.update_attempt(attempt)

        self.event_bus.emit(session_id, "attempt.resumed", {"attempt_id": attempt_id, "user_input": user_input})
        include_shell_tools = bool(session.config.get("include_shell_tools", False))
        self._spawn_background(self._run_attempt(session, attempt, include_shell_tools=include_shell_tools))

        return {"status": "resumed", "attempt_id": attempt_id}

    def get_messages(self, session_id: str, limit: int = 100) -> list[Message]:
        """Return up to ``limit`` messages of the session history."""
        return self.store.get_messages(session_id, limit)

    def get_attempts(self, session_id: str) -> list[Attempt]:
        """Return all execution attempts of a session."""
        return self.store.list_attempts(session_id)

    def get_attempt(self, session_id: str, attempt_id: str) -> Optional[Attempt]:
        """Return a single execution attempt, or None if not found."""
        return self.store.get_attempt(session_id, attempt_id)

    def cancel_current(self, session_id: str) -> bool:
        """Cancel the currently running AgentLoop for a session.

        Args:
            session_id: Session ID.

        Returns:
            True if an active loop existed and ``cancel()`` was called on it;
            False if no loop is registered for the session. Whether the loop
            stops immediately depends on ``AgentLoop.cancel`` (presumably a
            cooperative signal — confirm).
        """
        agent = self._active_loops.get(session_id)
        if agent is None:
            return False
        agent.cancel()
        return True

    async def _run_attempt(self, session: Session, attempt: Attempt, *, include_shell_tools: bool = False) -> None:
        """Execute an attempt in the background and publish its outcome.

        Marks the attempt running, executes it with ``_run_with_agent``,
        persists the final status together with ``run_dir`` and metrics,
        appends an assistant reply message, and emits ``attempt.completed`` or
        ``attempt.failed``. Any ``Exception`` raised along the way is recorded
        on the attempt and reported as ``attempt.failed`` instead of
        propagating out of the task.
        """
        session_id = session.session_id
        attempt.mark_running()
        self.store.update_attempt(attempt)
        self.event_bus.emit(session_id, "attempt.started", {"attempt_id": attempt.attempt_id})

        try:
            messages = self.store.get_messages(session_id)
            result = await self._run_with_agent(
                attempt,
                messages=messages,
                include_shell_tools=include_shell_tools,
                session_config=dict(session.config),
            )

            if result.get("status") == "success":
                attempt.mark_completed(summary=result.get("content", ""))
            else:
                attempt.mark_failed(error=result.get("reason", "unknown"))

            attempt.run_dir = result.get("run_dir")
            # _run_with_agent loads metrics.csv into result["metrics"]; copy it onto
            # the attempt so it is persisted and surfaced in the reply metadata.
            if result.get("metrics"):
                attempt.metrics = result["metrics"]
            self.store.update_attempt(attempt)

            reply_metadata: Dict[str, Any] = {}
            if attempt.run_dir:
                reply_metadata["run_id"] = Path(attempt.run_dir).name
            reply_metadata["status"] = attempt.status.value
            if attempt.metrics:
                reply_metadata["metrics"] = attempt.metrics

            reply = Message(
                session_id=session_id, role="assistant",
                content=self._format_result_message(attempt),
                linked_attempt_id=attempt.attempt_id,
                metadata=reply_metadata,
            )
            self.store.append_message(reply)
            self._search_index.index_message(session_id, "assistant", reply.content)

            self.event_bus.emit(
                session_id,
                "attempt.completed" if attempt.status == AttemptStatus.COMPLETED else "attempt.failed",
                {"attempt_id": attempt.attempt_id, "status": attempt.status.value,
                 "summary": attempt.summary, "error": attempt.error, "run_dir": attempt.run_dir},
            )
        except Exception as exc:
            attempt.mark_failed(error=str(exc))
            self.store.update_attempt(attempt)
            self.event_bus.emit(
                session_id,
                "attempt.failed",
                {"attempt_id": attempt.attempt_id, "status": attempt.status.value, "error": str(exc)},
            )

    async def _run_with_agent(
        self,
        attempt: Attempt,
        messages: Optional[list] = None,
        *,
        include_shell_tools: bool = False,
        session_config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Execute an attempt with the V5 AgentLoop.

        Registry construction and ``AgentLoop.run`` are blocking, so both are
        submitted to ``_AGENT_EXECUTOR`` via ``run_in_executor``. While the loop
        runs it is registered in ``_active_loops`` so ``cancel_current`` can
        reach it.

        Args:
            attempt: Current execution attempt.
            messages: Session message history; the last element is treated as
                the current turn and excluded from the history context.
            include_shell_tools: Whether the registry may include shell tools.
            session_config: Optional session-level config overrides. They are
                passed through ``sanitize_session_overrides`` and then merged by
                ``load_runtime_agent_config``. MCP server definitions under the
                ``mcpServers`` key are merged on top of the user config file, so
                each session can extend or override the global MCP server list.

        Returns:
            The AgentLoop result dictionary (status, run_dir, run_id, and related
            fields), with ``metrics`` added when the run directory contains a
            readable ``artifacts/metrics.csv``.
        """
        from src.tools import build_registry
        from src.providers.chat import ChatLLM
        from src.agent.loop import AgentLoop
        from src.memory.persistent import PersistentMemory
        from src.config.loader import load_runtime_agent_config, sanitize_session_overrides

        llm = ChatLLM()
        pm = PersistentMemory()
        session_id = attempt.session_id
        attempt_id = attempt.attempt_id
        loop = asyncio.get_running_loop()

        safe_overrides = sanitize_session_overrides(session_config) if session_config else session_config
        agent_config = load_runtime_agent_config(overrides=safe_overrides)

        def event_callback(event_type: str, data: Dict[str, Any]) -> None:
            """Forward AgentLoop events to the SSE event bus, tagged with the attempt ID."""
            # Build a new payload instead of mutating the dict the AgentLoop passed in.
            # NOTE(review): agent.run executes on an _AGENT_EXECUTOR worker thread, so
            # if AgentLoop calls this synchronously, EventBus.emit must be thread-safe —
            # confirm.
            self.event_bus.emit(session_id, event_type, {**data, "attempt_id": attempt_id})

        def _mcp_collision_warn(msg: str) -> None:
            """Forward MCP server-name collision warnings to the operator event channel."""
            self.event_bus.emit(session_id, "mcp.warning", {"attempt_id": attempt_id, "message": msg})

        registry = await loop.run_in_executor(
            _AGENT_EXECUTOR,
            lambda: build_registry(
                persistent_memory=pm,
                include_shell_tools=include_shell_tools,
                agent_config=agent_config,
                warn_callback=_mcp_collision_warn,
            ),
        )
        agent = AgentLoop(
            registry=registry,
            llm=llm,
            event_callback=event_callback,
            max_iterations=50,
            persistent_memory=pm,
        )
        self._active_loops[session_id] = agent

        # Build the message history context.
        history = self._convert_messages_to_history(messages) if messages else None

        try:
            result = await loop.run_in_executor(
                _AGENT_EXECUTOR,
                lambda: agent.run(
                    user_message=attempt.prompt,
                    history=history,
                    session_id=session_id,
                ),
            )
        finally:
            # Only unregister our own loop: a newer attempt for the same session may
            # have replaced this entry while we were running.
            if self._active_loops.get(session_id) is agent:
                del self._active_loops[session_id]

        # Load metrics from the run output when available.
        if result.get("run_dir"):
            metrics = self._load_metrics(Path(result["run_dir"]))
            if metrics:
                result["metrics"] = metrics

        return result

    @staticmethod
    def _convert_messages_to_history(messages: list, max_chars: int = 12000) -> list[Dict[str, Any]]:
        """Convert Session messages into OpenAI-format history.

        Only ``user`` and ``assistant`` messages with non-blank content are
        kept. ``Run directory: <path>`` mentions are rewritten to a readable
        ``[prev_run: <run_id>]`` marker instead of being removed, so the LLM can
        still see earlier artifact references during iterative updates. The
        result is trimmed by a character budget instead of a fixed message count.

        Args:
            messages: Session message list (Message objects or dicts). The last
                element is the current turn and is always excluded.
            max_chars: Character budget for the returned history (the default
                of 12000 is roughly 3000 tokens at ~4 characters per token).

        Returns:
            The newest contiguous run of messages whose combined content length
            fits within ``max_chars``, in chronological order. If the newest
            message alone exceeds the budget, the result is empty.
        """
        import re

        def _shorten_run_dir(match: re.Match) -> str:
            path_str = match.group(0).replace("Run directory:", "").strip()
            run_id = Path(path_str).name if path_str else ""
            return f"[prev_run: {run_id}]" if run_id else ""

        history: list[Dict[str, Any]] = []
        # Skip the last element: it is the current turn, which the caller sends separately.
        for msg in messages[:-1]:
            role = msg.role if hasattr(msg, "role") else msg.get("role", "user")
            raw = msg.content if hasattr(msg, "content") else msg.get("content", "")
            content = raw or ""  # tolerate None content
            if role not in ("user", "assistant") or not content.strip():
                continue
            content = re.sub(r"Run directory:\s*\S+", _shorten_run_dir, content).strip()
            if content:
                history.append({"role": role, "content": content})

        # Walk back from the newest message and stop at the first one that would
        # exceed the budget, so the kept history stays contiguous.
        total_chars = 0
        kept: list[Dict[str, Any]] = []
        for entry in reversed(history):
            entry_len = len(entry["content"])
            if total_chars + entry_len > max_chars:
                break
            kept.append(entry)
            total_chars += entry_len
        kept.reverse()
        return kept

    @staticmethod
    def _load_metrics(run_dir: Path) -> Optional[Dict[str, Any]]:
        """Load the first data row of ``artifacts/metrics.csv`` as numeric metrics.

        Cells that are empty or not parseable as ``float`` (e.g. dates or
        labels) are skipped individually, so one textual column does not
        discard every other metric.

        Args:
            run_dir: Run directory containing an ``artifacts`` subdirectory.

        Returns:
            Mapping of column name to float value (possibly empty), or None when
            the file is missing, unreadable, or has no data rows.
        """
        import csv

        metrics_path = run_dir / "artifacts" / "metrics.csv"
        if not metrics_path.is_file():
            return None
        try:
            with open(metrics_path, "r", encoding="utf-8", newline="") as f:
                first_row = next(csv.DictReader(f), None)
        except (OSError, UnicodeDecodeError, csv.Error):
            # Metrics are best-effort; an unreadable file simply yields no metrics.
            return None
        if first_row is None:
            return None

        metrics: Dict[str, Any] = {}
        for key, value in first_row.items():
            # DictReader stores surplus cells under key None (as a list) and
            # fills missing cells with None; skip those and empty strings.
            if not key or not isinstance(value, str) or not value:
                continue
            try:
                metrics[key] = float(value)
            except ValueError:
                continue
        return metrics

    @staticmethod
    def _format_result_message(attempt: Attempt) -> str:
        """Format the assistant reply for a finished attempt.

        Returns the attempt summary (or a default text) when it completed, and
        an ``Execution failed: ...`` message otherwise.
        """
        if attempt.status == AttemptStatus.COMPLETED:
            return attempt.summary or "Strategy execution completed."
        return f"Execution failed: {attempt.error or 'unknown error'}"