|
11 | 11 | from pydantic import ValidationError |
12 | 12 |
|
13 | 13 | from src.config.paths import get_config_path |
14 | | -from src.config.schema import AgentConfig, AgentConfigOverride |
| 14 | +from src.config.schema import AgentConfig, AgentConfigOverride, MCPServerConfig |
15 | 15 |
|
16 | 16 | logger = logging.getLogger(__name__) |
17 | 17 |
|
@@ -80,11 +80,19 @@ def merge_agent_config_overrides( |
80 | 80 | ) |
81 | 81 | return config |
82 | 82 |
|
83 | | - merged = _merge_dicts( |
| 83 | + merged = _merge_agent_config_dicts( |
84 | 84 | config.model_dump(mode="json"), |
85 | 85 | override_model.model_dump(mode="json", exclude_unset=True), |
86 | 86 | ) |
87 | | - return AgentConfig.model_validate(merged) |
| 87 | + try: |
| 88 | + return AgentConfig.model_validate(merged) |
| 89 | + except ValidationError as exc: |
| 90 | + logger.warning( |
| 91 | + "Ignoring merged agent config overrides after validation failure (%s): %s — using base config", |
| 92 | + type(exc).__name__, |
| 93 | + [str(e["loc"]) for e in exc.errors()], |
| 94 | + ) |
| 95 | + return config |
88 | 96 |
|
89 | 97 |
|
90 | 98 | # Keys in session overrides that carry subprocess definitions and therefore |
@@ -169,6 +177,70 @@ def _read_config_file(path: Path) -> dict[str, Any]: |
169 | 177 | return data |
170 | 178 |
|
171 | 179 |
|
| 180 | +def _merge_agent_config_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: |
| 181 | + """Merge top-level agent config payloads with MCP-aware server replacement.""" |
| 182 | + non_mcp_override = {key: value for key, value in override.items() if key != "mcp_servers"} |
| 183 | + merged = _merge_dicts(base, non_mcp_override) |
| 184 | + |
| 185 | + override_servers = override.get("mcp_servers") |
| 186 | + if not isinstance(override_servers, dict): |
| 187 | + if "mcp_servers" in override: |
| 188 | + merged["mcp_servers"] = override_servers |
| 189 | + return merged |
| 190 | + |
| 191 | + merged_servers = dict(base.get("mcp_servers", {})) |
| 192 | + for server_name, server_override in override_servers.items(): |
| 193 | + current_server = merged_servers.get(server_name) |
| 194 | + if isinstance(current_server, dict) and isinstance(server_override, dict): |
| 195 | + merged_servers[server_name] = _merge_mcp_server_dicts(current_server, server_override) |
| 196 | + else: |
| 197 | + merged_servers[server_name] = server_override |
| 198 | + |
| 199 | + merged["mcp_servers"] = merged_servers |
| 200 | + return merged |
| 201 | + |
| 202 | + |
| 203 | +def _merge_mcp_server_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: |
| 204 | + """Merge one MCP server payload, resetting incompatible transport fields when needed.""" |
| 205 | + if _override_switches_transport(base, override): |
| 206 | + return _merge_dicts(_default_mcp_server_payload(base), override) |
| 207 | + return _merge_dicts(base, override) |
| 208 | + |
| 209 | + |
| 210 | +def _override_switches_transport(base: dict[str, Any], override: dict[str, Any]) -> bool: |
| 211 | + """Return whether a partial override changes the server transport family.""" |
| 212 | + override_transport = _resolve_override_transport(override) |
| 213 | + if override_transport is None: |
| 214 | + return False |
| 215 | + base_transport = MCPServerConfig.model_validate(base).resolved_transport() |
| 216 | + return override_transport != base_transport |
| 217 | + |
| 218 | + |
| 219 | +def _resolve_override_transport(override: dict[str, Any]) -> str | None: |
| 220 | + """Infer transport intent from a partial MCP server override.""" |
| 221 | + explicit_type = override.get("type") |
| 222 | + if explicit_type in {"stdio", "sse", "streamableHttp"}: |
| 223 | + return str(explicit_type) |
| 224 | + if any(key in override for key in ("command", "args", "env")): |
| 225 | + return "stdio" |
| 226 | + return None |
| 227 | + |
| 228 | + |
| 229 | +def _default_mcp_server_payload(base: dict[str, Any]) -> dict[str, Any]: |
| 230 | + """Return a transport-neutral MCP server payload preserving non-transport defaults.""" |
| 231 | + enabled_tools = base.get("enabled_tools") |
| 232 | + return { |
| 233 | + "type": None, |
| 234 | + "command": "", |
| 235 | + "args": [], |
| 236 | + "env": {}, |
| 237 | + "url": "", |
| 238 | + "headers": {}, |
| 239 | + "tool_timeout": base.get("tool_timeout", 30.0), |
| 240 | + "enabled_tools": list(enabled_tools) if isinstance(enabled_tools, list) else ["*"], |
| 241 | + } |
| 242 | + |
| 243 | + |
172 | 244 | def _merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: |
173 | 245 | """Recursively merge two plain dictionaries. |
174 | 246 |
|
|
0 commit comments