forked from HKUDS/Vibe-Trading
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloader.py
More file actions
262 lines (207 loc) · 9.04 KB
/
loader.py
File metadata and controls
262 lines (207 loc) · 9.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Structured agent config loading utilities."""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, Mapping
from pydantic import ValidationError
from src.config.paths import get_config_path
from src.config.schema import AgentConfig, AgentConfigOverride, MCPServerConfig
# Module-level logger; used below for config load/merge warnings.
logger = logging.getLogger(__name__)

# PyYAML is optional. When it is not installed, ``yaml`` is bound to ``None``
# and ``_read_config_file`` rejects ``.yaml``/``.yml`` files with a ValueError,
# while JSON config files keep working.
try:
    import yaml
except ImportError:
    yaml = None  # type: ignore
def load_agent_config(config_path: Path | None = None) -> AgentConfig:
    """Return the agent config stored on disk, or defaults when it is unusable.

    Args:
        config_path: Explicit config file to read. ``None`` lets
            ``get_config_path`` pick the default location.

    Returns:
        An ``AgentConfig`` validated from the file's contents. A fresh default
        ``AgentConfig()`` is returned instead when the resolved file does not
        exist, or when reading, decoding or validating it raises ``OSError``,
        ``ValueError`` or pydantic's ``ValidationError``. In that failure case
        only the exception type is logged at WARNING level; the full exception
        text is logged at DEBUG level.
    """
    path = get_config_path(config_path)
    if path.exists():
        try:
            return AgentConfig.model_validate(_read_config_file(path))
        except (OSError, ValueError, ValidationError) as exc:
            logger.warning(
                "Failed to load agent config from %s: %s",
                path,
                type(exc).__name__,
            )
            logger.debug("Agent config load error details: %s", exc)
    # Missing file and failed load share the same fallback.
    return AgentConfig()
def merge_agent_config_overrides(
    config: AgentConfig,
    overrides: Mapping[str, Any] | None,
) -> AgentConfig:
    """Layer runtime overrides on top of ``config`` and return the result.

    The overrides are first parsed with the partial ``AgentConfigOverride``
    schema, so both snake_case and camelCase keys are accepted. Only fields
    that were explicitly supplied (``exclude_unset=True``) are merged onto the
    base config. If either the overrides or the merged payload fail
    validation, the failing locations are logged and ``config`` is returned
    unchanged.

    Args:
        config: Base agent config loaded from disk or defaults.
        overrides: Runtime overrides, typically from session-level config.
            ``None`` or an empty mapping is a no-op.

    Returns:
        A newly validated ``AgentConfig`` on success, otherwise ``config``
        itself.
    """
    if not overrides:
        return config

    try:
        partial = AgentConfigOverride.model_validate(dict(overrides))
    except ValidationError as exc:
        bad_locations = [str(err["loc"]) for err in exc.errors()]
        logger.warning(
            "Ignoring invalid agent config overrides (%s): %s — using base config",
            type(exc).__name__,
            bad_locations,
        )
        return config

    base_payload = config.model_dump(mode="json")
    # exclude_unset keeps schema defaults of the partial model from
    # clobbering values that are already present in the base config.
    override_payload = partial.model_dump(mode="json", exclude_unset=True)
    candidate = _merge_agent_config_dicts(base_payload, override_payload)

    try:
        return AgentConfig.model_validate(candidate)
    except ValidationError as exc:
        bad_locations = [str(err["loc"]) for err in exc.errors()]
        logger.warning(
            "Ignoring merged agent config overrides after validation failure (%s): %s — using base config",
            type(exc).__name__,
            bad_locations,
        )
        return config
# Keys in session overrides that carry subprocess definitions and therefore
# require operator-level trust rather than API-caller trust.
_SESSION_RESTRICTED_KEYS: frozenset[str] = frozenset({"mcpServers", "mcp_servers"})
def sanitize_session_overrides(overrides: Mapping[str, Any]) -> dict[str, Any]:
"""Strip operator-only keys from API-caller-supplied session overrides.
``mcpServers`` / ``mcp_servers`` define subprocess ``command``/``args``/``env``
and therefore grant execution-level capabilities. They must originate from
the operator-controlled config file on disk, not from unauthenticated or
semi-trusted API callers. Operators who deliberately want to allow session-
level MCP injection can set ``ALLOW_SESSION_MCP_SERVERS=1``.
Args:
overrides: Raw session config dict received from the API caller.
Returns:
A new dict with restricted keys removed (or the original mapping
converted to dict if the env opt-in is active).
"""
if os.environ.get("ALLOW_SESSION_MCP_SERVERS", "").strip().lower() in {"1", "true", "yes"}:
return dict(overrides)
restricted_present = _SESSION_RESTRICTED_KEYS & overrides.keys()
if restricted_present:
logger.warning(
"Stripped %s from session config overrides: MCP server definitions "
"require operator-level trust (disk config). "
"Set ALLOW_SESSION_MCP_SERVERS=1 to allow session-level injection.",
sorted(restricted_present),
)
return {k: v for k, v in overrides.items() if k not in _SESSION_RESTRICTED_KEYS}
def load_runtime_agent_config(
    config_path: Path | None = None,
    overrides: Mapping[str, Any] | None = None,
) -> AgentConfig:
    """Build the effective runtime config: disk config plus runtime overrides.

    This is a convenience composition of ``load_agent_config`` followed by
    ``merge_agent_config_overrides``. The overrides are passed through as-is;
    this function does not strip any keys itself.

    Args:
        config_path: Optional explicit config file path.
        overrides: Runtime override mapping applied on top of the file-based
            config. ``None`` means no overrides.

    Returns:
        The merged runtime config.
    """
    return merge_agent_config_overrides(load_agent_config(config_path), overrides)
def _read_config_file(path: Path) -> dict[str, Any]:
"""Read a supported config file format into a dictionary.
Args:
path: Config file path to decode.
Returns:
The decoded config object as a dictionary.
Raises:
ValueError: If the file format is unsupported, YAML support is
unavailable, or the decoded payload is not an object.
"""
suffix = path.suffix.lower()
text = path.read_text(encoding="utf-8")
if suffix == ".json":
data = json.loads(text)
elif suffix in {".yaml", ".yml"}:
if yaml is None:
raise ValueError("YAML config is not available because PyYAML is missing")
data = yaml.safe_load(text) or {}
else:
raise ValueError(f"Unsupported config file format: {suffix or '<none>'}")
if not isinstance(data, dict):
raise ValueError("Agent config must decode to a JSON/YAML object")
return data
def _merge_agent_config_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Merge two top-level agent config payloads, treating MCP servers specially.

    Every key except ``mcp_servers`` is deep-merged with ``_merge_dicts``.
    ``mcp_servers`` is handled per server:

    * absent from ``override``: the base servers are kept;
    * present but not a dict (e.g. ``None``): it replaces the base value;
    * a dict: each named server is merged with its base counterpart via
      ``_merge_mcp_server_dicts`` when both are dicts, or replaced outright
      otherwise. Servers only present in ``base`` are kept.

    Args:
        base: Full config payload (as produced by ``model_dump``).
        override: Partial override payload.

    Returns:
        A new merged payload dict.
    """
    merged = _merge_dicts(
        base,
        {key: value for key, value in override.items() if key != "mcp_servers"},
    )

    if "mcp_servers" not in override:
        return merged

    server_overrides = override["mcp_servers"]
    if not isinstance(server_overrides, dict):
        merged["mcp_servers"] = server_overrides
        return merged

    servers = dict(base.get("mcp_servers", {}))
    for server_name, patch in server_overrides.items():
        existing = servers.get(server_name)
        if isinstance(existing, dict) and isinstance(patch, dict):
            servers[server_name] = _merge_mcp_server_dicts(existing, patch)
        else:
            servers[server_name] = patch
    merged["mcp_servers"] = servers
    return merged
def _merge_mcp_server_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Merge a partial override into one MCP server payload.

    When the override moves the server to a different transport family, the
    base's transport-specific fields are first reset via
    ``_default_mcp_server_payload``, so stale ``command``/``url``-style values
    do not leak into the new transport. Otherwise the override is deep-merged
    directly onto ``base``.

    Args:
        base: Existing server payload.
        override: Partial server override.

    Returns:
        The merged server payload.
    """
    starting_point = (
        _default_mcp_server_payload(base)
        if _override_switches_transport(base, override)
        else base
    )
    return _merge_dicts(starting_point, override)
def _override_switches_transport(base: dict[str, Any], override: dict[str, Any]) -> bool:
    """Report whether a partial override moves a server to another transport.

    Returns ``False`` when the override expresses no transport intent (see
    ``_resolve_override_transport``). Otherwise the intent is compared with the
    transport that ``MCPServerConfig`` resolves for ``base``. ``base`` is only
    validated when there is an intent to compare against.

    Args:
        base: Existing server payload.
        override: Partial server override.

    Returns:
        ``True`` if the inferred override transport differs from the base's.
    """
    requested = _resolve_override_transport(override)
    return requested is not None and (
        requested != MCPServerConfig.model_validate(base)._resolved_transport()
    )
def _resolve_override_transport(override: dict[str, Any]) -> str | None:
"""Infer transport intent from a partial MCP server override."""
explicit_type = override.get("type")
if explicit_type in {"stdio", "sse", "streamableHttp"}:
return str(explicit_type)
if any(key in override for key in ("command", "args", "env")):
return "stdio"
return None
def _default_mcp_server_payload(base: dict[str, Any]) -> dict[str, Any]:
"""Return a transport-neutral MCP server payload preserving non-transport defaults."""
enabled_tools = base.get("enabled_tools")
return {
"type": None,
"command": "",
"args": [],
"env": {},
"url": "",
"headers": {},
"tool_timeout": base.get("tool_timeout", 30.0),
"enabled_tools": list(enabled_tools) if isinstance(enabled_tools, list) else ["*"],
}
def _merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
"""Recursively merge two plain dictionaries.
Args:
base: Base dictionary.
override: Override dictionary applied on top of ``base``.
Returns:
A merged dictionary where nested mappings are merged recursively and
scalar values from ``override`` replace those in ``base``.
"""
merged = dict(base)
for key, value in override.items():
current = merged.get(key)
if isinstance(current, dict) and isinstance(value, dict):
merged[key] = _merge_dicts(current, value)
else:
merged[key] = value
return merged