Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions tests/tools/test_mcp_tool_issue_948.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from unittest.mock import patch

from tools.mcp_tool import _format_connect_error, _resolve_stdio_command


def test_resolve_stdio_command_falls_back_to_hermes_node_bin(tmp_path):
node_bin = tmp_path / "node" / "bin"
node_bin.mkdir(parents=True)
npx_path = node_bin / "npx"
npx_path.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8")
npx_path.chmod(0o755)

with patch.dict("os.environ", {"HERMES_HOME": str(tmp_path)}, clear=False):
command, env = _resolve_stdio_command("npx", {"PATH": "/usr/bin"})

assert command == str(npx_path)
assert env["PATH"].split(os.pathsep)[0] == str(node_bin)


def test_resolve_stdio_command_respects_explicit_empty_path():
seen_paths = []

def _fake_which(_cmd, path=None):
seen_paths.append(path)
if path is None:
return "/usr/bin/python"
return None

with patch("tools.mcp_tool.shutil.which", side_effect=_fake_which):
command, env = _resolve_stdio_command("python", {"PATH": ""})

assert command == "python"
assert env["PATH"] == ""
assert seen_paths == [""]


def test_format_connect_error_unwraps_exception_group():
error = ExceptionGroup(
"unhandled errors in a TaskGroup",
[FileNotFoundError(2, "No such file or directory", "node")],
)

message = _format_connect_error(error)

assert "missing executable 'node'" in message
94 changes: 92 additions & 2 deletions tools/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import math
import os
import re
import shutil
import threading
import time
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -176,6 +177,91 @@ def _sanitize_error(text: str) -> str:
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)


def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
resolved_command = os.path.expanduser(str(command).strip())
resolved_env = dict(env or {})

if os.sep not in resolved_command:
path_arg = resolved_env["PATH"] if "PATH" in resolved_env else None
which_hit = shutil.which(resolved_command, path=path_arg)
if which_hit:
resolved_command = which_hit
elif resolved_command in {"npx", "npm", "node"}:
hermes_home = os.path.expanduser(
os.getenv(
"HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
)
)
candidates = [
os.path.join(hermes_home, "node", "bin", resolved_command),
os.path.join(
os.path.expanduser("~"), ".local", "bin", resolved_command
),
]
for candidate in candidates:
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
resolved_command = candidate
break

command_dir = os.path.dirname(resolved_command)
if command_dir:
parts = [p for p in (resolved_env.get("PATH") or "").split(os.pathsep) if p]
resolved_env["PATH"] = (
os.pathsep.join(parts)
if command_dir in parts
else os.pathsep.join([command_dir, *parts]) if parts else command_dir
)

return resolved_command, resolved_env


def _format_connect_error(exc: BaseException) -> str:
def _find_missing(current: BaseException) -> Optional[str]:
nested = getattr(current, "exceptions", None)
if nested:
for child in nested:
missing = _find_missing(child)
if missing:
return missing
return None
if isinstance(current, FileNotFoundError):
if getattr(current, "filename", None):
return str(current.filename)
message = str(current)
match = re.search(r"No such file or directory: '([^']+)'", message)
if match:
return match.group(1)
return None

def _flatten_messages(current: BaseException) -> List[str]:
nested = getattr(current, "exceptions", None)
if nested:
flattened: List[str] = []
for child in nested:
flattened.extend(_flatten_messages(child))
return flattened
text = str(current).strip()
return [text or current.__class__.__name__]

missing = _find_missing(exc)
if missing:
message = f"missing executable '{missing}'"
if os.path.basename(missing) in {"npx", "npm", "node"}:
message += (
" (ensure Node.js is installed and PATH includes its bin directory, "
"or set mcp_servers.<name>.command to an absolute path and include "
"that directory in mcp_servers.<name>.env.PATH)"
)
return _sanitize_error(message)

messages = _flatten_messages(exc)
deduped: List[str] = []
for item in messages:
if item not in deduped:
deduped.append(item)
return _sanitize_error("; ".join(deduped[:3]))


# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -612,6 +698,7 @@ async def _run_stdio(self, config: dict):
)

safe_env = _build_safe_env(user_env)
command, safe_env = _resolve_stdio_command(command, safe_env)
server_params = StdioServerParameters(
command=command,
args=args,
Expand Down Expand Up @@ -1344,9 +1431,12 @@ async def _discover_all():
for name, result in zip(server_names, results):
if isinstance(result, Exception):
failed_count += 1
command = new_servers.get(name, {}).get("command")
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, result,
"Failed to connect to MCP server '%s'%s: %s",
name,
f" (command={command})" if command else "",
_format_connect_error(result),
)
elif isinstance(result, list):
all_tools.extend(result)
Expand Down
Loading