diff --git a/src/ha_mcp/__main__.py b/src/ha_mcp/__main__.py index 81cdfe6fa..67e62c709 100644 --- a/src/ha_mcp/__main__.py +++ b/src/ha_mcp/__main__.py @@ -30,6 +30,8 @@ from collections.abc import Coroutine # noqa: E402 from typing import TYPE_CHECKING, Any # noqa: E402 +from fastmcp.exceptions import ToolError # noqa: E402 +from pydantic import ValidationError as PydanticValidationError # noqa: E402 from starlette.requests import Request # noqa: E402 from starlette.responses import PlainTextResponse # noqa: E402 @@ -365,6 +367,36 @@ def filter(self, record: logging.LogRecord) -> bool: return True +class ToolValidationLogFilter(logging.Filter): + """Demote fastmcp tool-failure tracebacks to single-line warnings. + + Pydantic ValidationError and tool-raised ToolError aren't server bugs, + so the traceback through fastmcp/pydantic internals is just noise. The + structured error detail is preserved in the WARNING message; stack is + intentionally dropped because these are user-input errors, not bugs. + """ + + def filter(self, record: logging.LogRecord) -> bool: + if record.name != "fastmcp.server.server" or not record.exc_info: + return True + + msg = record.getMessage() + err = record.exc_info[1] + if "Error validating tool" in msg and isinstance(err, PydanticValidationError): + record.msg = f"{msg}: {err.errors(include_url=False)}" + elif "Error calling tool" in msg and isinstance(err, ToolError): + record.msg = f"{msg}: {err}" + else: + return True + + record.args = () + record.levelno = logging.WARNING + record.levelname = "WARNING" + record.exc_info = None + record.exc_text = None + return True + + def _setup_logging(log_level_str: str, force: bool = False) -> None: """Configure root logger with consistent timestamp format.""" logging.basicConfig( @@ -376,6 +408,7 @@ def _setup_logging(log_level_str: str, force: bool = False) -> None: logging.getLogger("mcp.server.streamable_http").addFilter( StatelessSessionLogFilter() ) + logging.getLogger("fastmcp.server.server").addFilter(ToolValidationLogFilter()) def _get_timestamped_uvicorn_log_config() -> dict: diff --git a/tests/src/unit/test_tool_validation_log_filter.py b/tests/src/unit/test_tool_validation_log_filter.py new file mode 100644 index 000000000..756276674 --- /dev/null +++ b/tests/src/unit/test_tool_validation_log_filter.py @@ -0,0 +1,118 @@ +"""Unit tests for ToolValidationLogFilter.""" + +import logging + +import pytest +from fastmcp.exceptions import FastMCPError, ToolError +from pydantic import BaseModel, ValidationError + +from ha_mcp.__main__ import ToolValidationLogFilter + + +def _pydantic_validation_error() -> ValidationError: + class _Model(BaseModel): + age: int + + with pytest.raises(ValidationError) as excinfo: + _Model(age="nope") + return excinfo.value + + +class TestToolValidationLogFilter: + """Verify the filter demotes fastmcp tool-failure tracebacks to WARNING.""" + + def setup_method(self): + self.log_filter = ToolValidationLogFilter() + + def _make_record( + self, + name: str, + msg: str, + exc: BaseException | None, + ) -> logging.LogRecord: + exc_info = (type(exc), exc, None) if exc is not None else None + return logging.LogRecord( + name=name, + level=logging.ERROR, + pathname="", + lineno=0, + msg=msg, + args=(), + exc_info=exc_info, + ) + + def test_demotes_validation_error_to_warning(self): + err = _pydantic_validation_error() + record = self._make_record( + "fastmcp.server.server", + "Error validating tool 'ha_foo'", + err, + ) + assert self.log_filter.filter(record) is True + assert record.levelno == logging.WARNING + assert record.levelname == "WARNING" + assert record.exc_info is None + assert record.exc_text is None + # Structured error info folded into the message, no pydantic URL. + assert "age" in record.getMessage() + assert "errors.pydantic.dev" not in record.getMessage() + + def test_demotes_tool_error_to_warning(self): + err = ToolError("bad input") + record = self._make_record( + "fastmcp.server.server", + "Error calling tool 'ha_foo'", + err, + ) + assert self.log_filter.filter(record) is True + assert record.levelno == logging.WARNING + assert record.exc_info is None + assert "bad input" in record.getMessage() + + def test_passes_bare_exception_through_untouched(self): + err = RuntimeError("server bug") + record = self._make_record( + "fastmcp.server.server", + "Error calling tool 'ha_foo'", + err, + ) + original_exc_info = record.exc_info + assert self.log_filter.filter(record) is True + assert record.levelno == logging.ERROR + assert record.exc_info is original_exc_info + + def test_passes_non_tool_fastmcp_error_through(self): + # A hypothetical future FastMCPError subclass that is NOT a ToolError + # (e.g. AuthorizationError) should retain its traceback. + class FutureAuthError(FastMCPError): + pass + + err = FutureAuthError("unauthorized") + record = self._make_record( + "fastmcp.server.server", + "Error calling tool 'ha_foo'", + err, + ) + assert self.log_filter.filter(record) is True + assert record.levelno == logging.ERROR + assert record.exc_info is not None + + def test_leaves_other_loggers_unchanged(self): + err = ToolError("boom") + record = self._make_record( + "some.other.logger", + "Error calling tool 'ha_foo'", + err, + ) + assert self.log_filter.filter(record) is True + assert record.levelno == logging.ERROR + assert record.exc_info is not None + + def test_passes_record_without_exc_info(self): + record = self._make_record( + "fastmcp.server.server", + "Error calling tool 'ha_foo'", + None, + ) + assert self.log_filter.filter(record) is True + assert record.levelno == logging.ERROR diff --git a/tests/uat/README.md b/tests/uat/README.md index 05be8b233..48ac588b2 100644 --- a/tests/uat/README.md +++ b/tests/uat/README.md @@ -2,6 +2,23 @@ Executes MCP test scenarios on real AI agent CLIs (Claude, Gemini, OpenAI-compatible) against a Home Assistant test instance. Designed to be driven by a calling agent that generates scenarios dynamically, runs them, and evaluates results. +## Quick Start + +Two runners, two use cases: + +```bash +# Run the pre-built story catalog (most common) +uv run python tests/uat/stories/run_story.py --all --agents gemini + +# Run one ad-hoc scenario (must pipe JSON via stdin or use --scenario-file) +echo '{"test_prompt":"Search for light entities. Report how many you found."}' \ + | uv run python tests/uat/run_uat.py --agents gemini +``` + +Commands must be prefixed with `uv run python` — the repo targets Python 3.13 via uv. + +For OpenAI-compatible endpoints (LM Studio, Ollama, vLLM), the `--base-url` must include the `/v1` suffix, e.g. `http://172.19.0.1:1234/v1`. LM Studio requires the model to be loaded in its UI first; Ollama auto-loads on demand. + ## Architecture ``` @@ -48,26 +65,26 @@ Each prompt runs in a separate CLI invocation (fresh context, no PR knowledge). ```bash # Pipe scenario from stdin echo '{"test_prompt":"Search for light entities. Report how many you found."}' | \ - python tests/uat/run_uat.py --agents gemini + uv run python tests/uat/run_uat.py --agents gemini # From file -python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini +uv run python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini # Against already-running HA (skip container startup) -python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini +uv run python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini # Test a specific branch -echo '{"test_prompt":"..."}' | python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini +echo '{"test_prompt":"..."}' | uv run python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini # Local code (default) vs branch -python tests/uat/run_uat.py # uses: uv run --project . ha-mcp -python tests/uat/run_uat.py --branch pr-551 # uses: uvx --from git+...@pr-551 ha-mcp +uv run python tests/uat/run_uat.py # uses: uv run --project . ha-mcp +uv run python tests/uat/run_uat.py --branch pr-551 # uses: uvx --from git+...@pr-551 ha-mcp # OpenAI-compatible local LLM (LM Studio, Ollama, vLLM, etc.) -echo '{"test_prompt":"..."}' | python tests/uat/run_uat.py --agents openai --base-url http://localhost:1234/v1 +echo '{"test_prompt":"..."}' | uv run python tests/uat/run_uat.py --agents openai --base-url http://localhost:1234/v1 # With a specific model and API key -echo '{"test_prompt":"..."}' | python tests/uat/run_uat.py --agents openai \ +echo '{"test_prompt":"..."}' | uv run python tests/uat/run_uat.py --agents openai \ --base-url http://localhost:1234/v1 --model my-model --api-key sk-xxx ``` @@ -176,10 +193,10 @@ To check if a failure is a regression vs pre-existing: ```bash # Test the PR branch -echo '{"test_prompt":"..."}' | python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini +echo '{"test_prompt":"..."}' | uv run python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini # Compare against master -echo '{"test_prompt":"..."}' | python tests/uat/run_uat.py --branch master --agents gemini +echo '{"test_prompt":"..."}' | uv run python tests/uat/run_uat.py --branch master --agents gemini ``` ## Dependencies diff --git a/tests/uat/_inprocess.py b/tests/uat/_inprocess.py new file mode 100644 index 000000000..23eacb9b6 --- /dev/null +++ b/tests/uat/_inprocess.py @@ -0,0 +1,67 @@ +"""Shared in-process FastMCP client context for UAT. + +Used by the story runner (setup/verify/teardown) and pytest fixtures. +Constructing the server is ~1s; this lets callers share one instance +across many tool calls against the same HA instance. +""" + +from __future__ import annotations + +import contextlib +import os +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastmcp import Client + + +@contextlib.asynccontextmanager +async def inprocess_mcp_client( + ha_url: str, ha_token: str +) -> AsyncIterator[Client]: + """Build one in-process FastMCP client for setup/verify/teardown. + + Clearing ``ha_mcp.config._settings`` forces the next ``get_global_settings()`` + call to re-read the env vars we just set. The WebSocket disconnect tears + down any cached connection to the previous URL so the next tool call + reconnects to ``ha_url``. + + One client is shared across many tool calls to amortize the ~1s server + construction cost. The tradeoff: if the shared client's WebSocket gets into + a bad state mid-run, every subsequent call through it inherits the problem. + + Not safe for concurrent use: ``os.environ`` and ``ha_mcp.config._settings`` + are process-global, so overlapping callers would race on both. + """ + from fastmcp import Client + + import ha_mcp.config + from ha_mcp.client import HomeAssistantClient + from ha_mcp.client.websocket_client import websocket_manager + from ha_mcp.server import HomeAssistantSmartMCPServer + + prev_url = os.environ.get("HOMEASSISTANT_URL") + prev_token = os.environ.get("HOMEASSISTANT_TOKEN") + prev_settings = ha_mcp.config._settings + try: + os.environ["HOMEASSISTANT_URL"] = ha_url + os.environ["HOMEASSISTANT_TOKEN"] = ha_token + ha_mcp.config._settings = None + await websocket_manager.disconnect() + + client = HomeAssistantClient(base_url=ha_url, token=ha_token) + server = HomeAssistantSmartMCPServer(client=client) + async with Client(server.mcp) as mcp_client: + yield mcp_client + finally: + await websocket_manager.disconnect() + if prev_url is None: + os.environ.pop("HOMEASSISTANT_URL", None) + else: + os.environ["HOMEASSISTANT_URL"] = prev_url + if prev_token is None: + os.environ.pop("HOMEASSISTANT_TOKEN", None) + else: + os.environ["HOMEASSISTANT_TOKEN"] = prev_token + ha_mcp.config._settings = prev_settings diff --git a/tests/uat/_logging.py b/tests/uat/_logging.py new file mode 100644 index 000000000..65a3f6624 --- /dev/null +++ b/tests/uat/_logging.py @@ -0,0 +1,11 @@ +"""Shared logging setup for UAT entry points.""" + +from __future__ import annotations + +import logging + + +def configure_cli_logging() -> None: + """Silence third-party INFO chatter; keep our uat.* trace visible.""" + logging.basicConfig(level=logging.WARNING, format="%(message)s") + logging.getLogger("uat").setLevel(logging.INFO) diff --git a/tests/uat/ha_wait.py b/tests/uat/ha_wait.py index 1142f5398..8b5d4b4d2 100644 --- a/tests/uat/ha_wait.py +++ b/tests/uat/ha_wait.py @@ -8,13 +8,11 @@ from __future__ import annotations import logging -import sys import time -from collections.abc import Callable import requests -logger = logging.getLogger(__name__) +logger = logging.getLogger("uat.ha_wait") MIN_COMPONENTS = 50 MIN_ENTITIES = 50 @@ -23,16 +21,7 @@ ENTITY_TIMEOUT = 30 -def _log(msg: str) -> None: - print(msg, file=sys.stderr, flush=True) - - -def wait_for_ha_ready( - url: str, - token: str, - *, - log: Callable[[str], None] = _log, -) -> None: +def wait_for_ha_ready(url: str, token: str) -> None: """Wait until HA is fully ready: components loaded, entities registered. Raises TimeoutError if any gate is not reached within its timeout. @@ -40,7 +29,7 @@ def wait_for_ha_ready( headers = {"Authorization": f"Bearer {token}"} # Gate 1: API reachable and components loaded - log(f"Waiting for HA at {url} ...") + logger.info(f"Waiting for HA at {url} ...") api_responded = False last_component_count = 0 for attempt in range(API_TIMEOUT): @@ -52,13 +41,13 @@ def wait_for_ha_ready( component_count = len(data.get("components", [])) if component_count >= MIN_COMPONENTS: version = data.get("version", "unknown") - log( + logger.info( f"HA stabilized: {component_count} components, " f"version {version} ({attempt + 1}s)" ) break if component_count != last_component_count: - log(f" {component_count} components loaded, waiting for {MIN_COMPONENTS}+...") + logger.info(f" {component_count} components loaded, waiting for {MIN_COMPONENTS}+...") last_component_count = component_count except (requests.RequestException, ValueError) as exc: logger.debug("Readiness check failed (retrying): %s", exc) @@ -72,7 +61,7 @@ def wait_for_ha_ready( ) # Gate 2: Entities registered - log("Waiting for HA entities to register...") + logger.info("Waiting for HA entities to register...") last_entity_count = 0 for attempt in range(ENTITY_TIMEOUT): try: @@ -80,10 +69,10 @@ def wait_for_ha_ready( if r.status_code == 200: entity_count = len(r.json()) if entity_count >= MIN_ENTITIES: - log(f"HA ready: {entity_count} entities registered ({attempt + 1}s)") + logger.info(f"HA ready: {entity_count} entities registered ({attempt + 1}s)") break if entity_count != last_entity_count: - log(f" {entity_count} entities registered, waiting for {MIN_ENTITIES}+...") + logger.info(f" {entity_count} entities registered, waiting for {MIN_ENTITIES}+...") last_entity_count = entity_count except (requests.RequestException, ValueError) as exc: logger.debug("Readiness check failed (retrying): %s", exc) diff --git a/tests/uat/openai_agent.py b/tests/uat/openai_agent.py index 64bc047ad..e76cf8ec9 100644 --- a/tests/uat/openai_agent.py +++ b/tests/uat/openai_agent.py @@ -17,22 +17,37 @@ import argparse import asyncio import json +import logging +import re import sys -import traceback from pathlib import Path import openai from fastmcp import Client as MCPClient from mcp.types import Tool as MCPTool +# Allow `python tests/uat/openai_agent.py` (subprocess path from run_uat.py) +# to resolve the `uat` namespace package. +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from uat._logging import configure_cli_logging + DEFAULT_API_KEY = "no-key" DEFAULT_TIMEOUT = 120 DEFAULT_MAX_TOKENS = 8192 MAX_TOOL_LOOP_ITERATIONS = 20 -def log(msg: str) -> None: - print(msg, file=sys.stderr, flush=True) +_PYDANTIC_URL_LINE = re.compile( + r"\s*For further information visit https://errors\.pydantic\.dev/\S+" +) + + +def _strip_pydantic_url(text: str) -> str: + """Drop Pydantic's documentation URL footer from a stringified exception.""" + return _PYDANTIC_URL_LINE.sub("", text) + + +logger = logging.getLogger("uat.openai_agent") def mcp_tool_to_openai(tool: MCPTool) -> dict: @@ -54,7 +69,7 @@ async def detect_model(client: openai.AsyncOpenAI) -> str: if not models.data: raise RuntimeError("No models available at the API endpoint") model_id = models.data[0].id - log(f"Auto-detected model: {model_id}") + logger.info(f"Auto-detected model: {model_id}") return model_id @@ -201,7 +216,7 @@ async def tool_call_loop( f" [tool] {tool_name}: malformed arguments: " f"{tc.function.arguments!r}" ) - log(malformed_line) + logger.info(malformed_line) if tool_trace_sink is not None: tool_trace_sink.append(malformed_line.strip()) total_fail += 1 @@ -215,7 +230,7 @@ async def tool_call_loop( continue call_line = f" [tool] {tool_name}({tool_args})" - log(call_line) + logger.info(call_line) if tool_trace_sink is not None: tool_trace_sink.append(call_line.strip()) @@ -224,12 +239,15 @@ async def tool_call_loop( result_text = extract_tool_result_text(result) total_success += 1 except Exception as e: - result_text = f"Error: {str(e)}" + err_text = _strip_pydantic_url(str(e)) + result_text = f"Error: {err_text}" total_fail += 1 - fail_line = f" [tool] {tool_name} failed: {e}" - log(fail_line) + # Server-side WARNING log already shows the failure details; + # only record to the trace sink for test artifacts. if tool_trace_sink is not None: - tool_trace_sink.append(fail_line.strip()) + tool_trace_sink.append( + f"[tool] {tool_name} failed: {err_text}" + ) messages.append( { @@ -265,7 +283,7 @@ async def run_agent( # Read MCP config — same format as Claude's --mcp-config config = json.loads(Path(args.mcp_config).read_text()) # noqa: ASYNC240 - log("Starting MCP server...") + logger.info("Starting MCP server...") # fastmcp.Client accepts a config dict (same format as Claude's --mcp-config) async with MCPClient(config) as mcp_client: @@ -314,7 +332,7 @@ async def run_scenario_inline( """ if openai_tools is None: openai_tools = await fetch_openai_tools(mcp_client, max_tools=max_tools) - log(f"Loaded {len(openai_tools)} MCP tools") + logger.info(f"Loaded {len(openai_tools)} MCP tools") agent_prompt = ("/no_think\n\n" + prompt) if no_think else prompt messages = [{"role": "user", "content": agent_prompt}] @@ -346,14 +364,14 @@ async def create_and_warm_openai_client( """ client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout) resolved_model = model or await detect_model(client) - log(f"Using model: {resolved_model}") - log("Warming up model (may take a minute if not loaded)...") + logger.info(f"Using model: {resolved_model}") + logger.info("Warming up model (may take a minute if not loaded)...") await client.chat.completions.create( model=resolved_model, messages=[{"role": "user", "content": "hi"}], max_tokens=1, ) - log("Model ready") + logger.info("Model ready") return client, resolved_model @@ -366,31 +384,32 @@ async def _main_async(args: argparse.Namespace) -> None: model=args.model, ) except openai.BadRequestError as e: - log(f"ERROR: Model warmup failed (BadRequestError): {e}") + logger.error(f"Model warmup failed (BadRequestError): {e}") sys.exit(1) - except Exception as e: - log(f"ERROR ({type(e).__name__}): {e}\n{traceback.format_exc()}") + except Exception: + logger.exception("Model warmup failed") sys.exit(1) - log(f"MCP config: {args.mcp_config}") + logger.info(f"MCP config: {args.mcp_config}") try: try: result = await run_agent(client, model, args) finally: await client.close() - except Exception as e: - log(f"ERROR ({type(e).__name__}): {e}\n{traceback.format_exc()}") + except Exception: + logger.exception("Agent run failed") sys.exit(1) json.dump(result, sys.stdout, indent=2) print() if result.get("hit_iteration_limit"): - log("ERROR: hit max tool-call iterations without a final response") + logger.error("hit max tool-call iterations without a final response") sys.exit(1) def main() -> None: + configure_cli_logging() args = parse_args() asyncio.run(_main_async(args)) diff --git a/tests/uat/run_uat.py b/tests/uat/run_uat.py index 19ba39560..02e33a116 100644 --- a/tests/uat/run_uat.py +++ b/tests/uat/run_uat.py @@ -10,17 +10,20 @@ the file path — the calling agent only reads the full file when needed. Usage: - echo '{"test_prompt":"Search for light entities."}' | python tests/uat/run_uat.py --agents gemini - python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini - python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini + echo '{"test_prompt":"Search for light entities."}' | uv run python tests/uat/run_uat.py --agents gemini + uv run python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini + uv run python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini """ from __future__ import annotations import argparse import asyncio +import difflib import json +import logging import os +import re import shutil import subprocess import sys @@ -28,6 +31,7 @@ import time from collections.abc import Callable from pathlib import Path +from typing import NoReturn import requests from testcontainers.core.container import DockerContainer @@ -39,6 +43,7 @@ sys.path.insert(0, str(TESTS_DIR)) from test_constants import HA_TEST_IMAGE, TEST_TOKEN # noqa: E402 +from uat._logging import configure_cli_logging # noqa: E402 from uat.ha_wait import wait_for_ha_ready # noqa: E402 HA_IMAGE = HA_TEST_IMAGE @@ -47,11 +52,20 @@ DEFAULT_AGENTS = "claude,gemini" -# --------------------------------------------------------------------------- -# Logging (stderr only - stdout is reserved for JSON output) -# --------------------------------------------------------------------------- -def log(msg: str) -> None: - print(msg, file=sys.stderr, flush=True) +logger = logging.getLogger("uat.run_uat") + + +class SuggestingArgumentParser(argparse.ArgumentParser): + """argparse parser that suggests close matches for unknown flags.""" + + def error(self, message: str) -> NoReturn: + match = re.search(r"unrecognized arguments?: (--\S+)", message) + if match: + known = [opt for opt in self._option_string_actions if opt.startswith("--")] + suggestions = difflib.get_close_matches(match.group(1), known, n=1, cutoff=0.6) + if suggestions: + message = f"{message} (did you mean {suggestions[0]}?)" + super().error(message) # --------------------------------------------------------------------------- @@ -103,8 +117,8 @@ def __enter__(self) -> HAContainer: try: port = self.container.get_exposed_port(8123) self.url = f"http://localhost:{port}" - log(f"HA container started on {self.url}") - wait_for_ha_ready(self.url, self.token, log=log) + logger.info(f"HA container started on {self.url}") + wait_for_ha_ready(self.url, self.token) except Exception: self.__exit__(None, None, None) raise @@ -112,7 +126,7 @@ def __enter__(self) -> HAContainer: def __exit__(self, *exc: object) -> None: if self.container: - log("Stopping HA container...") + logger.info("Stopping HA container...") self.container.stop() if self.config_dir and self.config_dir.exists(): shutil.rmtree(self.config_dir, ignore_errors=True) @@ -174,7 +188,12 @@ def preflight_check_base_url(base_url: str, timeout: float = 5.0) -> str | None: def _build_mcp_env( ha_url: str, ha_token: str, extra_env: dict[str, str] | None ) -> dict[str, str]: - env = {"HOMEASSISTANT_URL": ha_url, "HOMEASSISTANT_TOKEN": ha_token} + # Override with --mcp-env LOG_LEVEL=INFO when debugging the server. + env = { + "HOMEASSISTANT_URL": ha_url, + "HOMEASSISTANT_TOKEN": ha_token, + "LOG_LEVEL": "WARNING", + } if extra_env: env.update(extra_env) return env @@ -482,7 +501,7 @@ async def run_agent_scenario( continue phase_key = phase.replace("_prompt", "") - log(f" [{agent_name}] Running {phase_key}...") + logger.info(f" [{agent_name}] Running {phase_key}...") if agent_name == "claude": assert stdio_config_path is not None @@ -517,7 +536,7 @@ async def run_agent_scenario( } results[phase_key] = result - log( + logger.info( f" [{agent_name}] {phase_key} completed (exit={result['exit_code']}, {result['duration_ms']}ms)" ) # Forward agent stderr on failure so the error is visible to the user @@ -525,9 +544,9 @@ async def run_agent_scenario( _BOX_CHARS = frozenset("│╭╰╮─▄█▀ \t") for line in result["stderr"].splitlines(): if "error" in line.lower(): - log(f" [{agent_name}] !! {line.strip()}") + logger.info(f" [{agent_name}] !! {line.strip()}") elif not all(c in _BOX_CHARS for c in line): - log(f" [{agent_name}] stderr: {line}") + logger.info(f" [{agent_name}] stderr: {line}") finally: # Cleanup temp files if stdio_config_path and stdio_config_path.exists(): @@ -661,6 +680,13 @@ async def run(args: argparse.Namespace) -> dict: if args.scenario_file: scenario = json.loads(Path(args.scenario_file).read_text()) # noqa: ASYNC240 else: + if sys.stdin.isatty(): + raise ValueError( + "No scenario provided. Pipe scenario JSON via stdin, or pass --scenario-file.\n" + " echo '{\"test_prompt\":\"...\"}' | uv run python tests/uat/run_uat.py --agents gemini\n" + " uv run python tests/uat/run_uat.py --scenario-file scenario.json --agents gemini\n" + "For the pre-built story catalog, use tests/uat/stories/run_story.py --all." + ) scenario = json.loads(sys.stdin.read()) if "test_prompt" not in scenario: @@ -673,7 +699,7 @@ async def run(args: argparse.Namespace) -> dict: available = check_agent_available(name) agents[name] = available if not available: - log(f"WARNING: {name} CLI not found, skipping") + logger.warning(f"{name} CLI not found, skipping") active_agents = [name for name, avail in agents.items() if avail] if not active_agents: @@ -709,14 +735,14 @@ async def run(args: argparse.Namespace) -> dict: ha_token = container.token try: - log(f"HA: {ha_url}") - log(f"MCP source: {mcp_source}" + (f" ({args.branch})" if args.branch else "")) - log(f"Agents: {', '.join(active_agents)}") + logger.info(f"HA: {ha_url}") + logger.info(f"MCP source: {mcp_source}" + (f" ({args.branch})" if args.branch else "")) + logger.info(f"Agents: {', '.join(active_agents)}") extra_env = parse_mcp_env( getattr(args, "mcp_env", None), base_url=getattr(args, "base_url", None), - on_default_applied=log, + on_default_applied=logger.info, ) # Run agents sequentially to avoid resource contention @@ -756,15 +782,17 @@ async def run(args: argparse.Namespace) -> dict: def main() -> None: - parser = argparse.ArgumentParser( + configure_cli_logging() + + parser = SuggestingArgumentParser( description="BAT Runner - Execute MCP test scenarios on AI agent CLIs", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - echo '{"test_prompt":"Search for light entities."}' | python tests/uat/run_uat.py --agents gemini - python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini - python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini - python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini + echo '{"test_prompt":"Search for light entities."}' | uv run python tests/uat/run_uat.py --agents gemini + uv run python tests/uat/run_uat.py --scenario-file /tmp/scenario.json --agents claude,gemini + uv run python tests/uat/run_uat.py --ha-url http://localhost:8123 --ha-token TOKEN --agents gemini + uv run python tests/uat/run_uat.py --branch feat/tool-errors --agents gemini """, ) parser.add_argument( @@ -835,7 +863,7 @@ def main() -> None: try: full_results = asyncio.run(run(args)) except ValueError as e: - log(f"ERROR: {e}") + logger.error(str(e)) sys.exit(1) except KeyboardInterrupt: sys.exit(130) diff --git a/tests/uat/stories/conftest.py b/tests/uat/stories/conftest.py index 260d64256..a9dce1767 100644 --- a/tests/uat/stories/conftest.py +++ b/tests/uat/stories/conftest.py @@ -30,11 +30,9 @@ from fastmcp import Client # noqa: E402 from test_constants import HA_TEST_IMAGE, TEST_TOKEN # noqa: E402 +from uat._inprocess import inprocess_mcp_client # noqa: E402 from uat.ha_wait import wait_for_ha_ready # noqa: E402 -from ha_mcp.client import HomeAssistantClient # noqa: E402 -from ha_mcp.server import HomeAssistantSmartMCPServer # noqa: E402 - logger = logging.getLogger(__name__) HA_IMAGE = HA_TEST_IMAGE @@ -128,18 +126,10 @@ def event_loop(): @pytest.fixture async def mcp_client(ha_container) -> AsyncGenerator[Client]: """FastMCP in-memory client for programmatic setup/teardown.""" - import ha_mcp.config - - ha_mcp.config._settings = None - - client = HomeAssistantClient( - base_url=ha_container["url"], token=ha_container["token"] - ) - server = HomeAssistantSmartMCPServer(client=client) - fastmcp_client = Client(server.mcp) - - async with fastmcp_client: - yield fastmcp_client + async with inprocess_mcp_client( + ha_container["url"], ha_container["token"] + ) as client: + yield client # --------------------------------------------------------------------------- diff --git a/tests/uat/stories/run_story.py b/tests/uat/stories/run_story.py index 44d4cefb9..8128032c9 100644 --- a/tests/uat/stories/run_story.py +++ b/tests/uat/stories/run_story.py @@ -69,13 +69,12 @@ sys.path.insert(0, str(SCRIPT_DIR)) # for scripts/ subdirectory imports from scripts.verify_story import verify_ha_checks # noqa: E402 +from uat._inprocess import inprocess_mcp_client # noqa: E402 +from uat._logging import configure_cli_logging # noqa: E402 from uat.ha_wait import wait_for_ha_ready # noqa: E402 +from uat.run_uat import SuggestingArgumentParser # noqa: E402 -logger = logging.getLogger(__name__) - - -def log(msg: str) -> None: - print(msg, file=sys.stderr, flush=True) +logger = logging.getLogger("uat.stories.run_story") # --------------------------------------------------------------------------- @@ -131,10 +130,10 @@ def _start_container(*, keep_alive: bool = False) -> dict: try: port = container.get_exposed_port(8123) url = f"http://localhost:{port}" - log(f"HA container started on {url}") + logger.info(f"HA container started on {url}") # Wait for HA to be fully ready (API + components + entities) - wait_for_ha_ready(url, TEST_TOKEN, log=log) + wait_for_ha_ready(url, TEST_TOKEN) except Exception: container.stop() shutil.rmtree(config_dir, ignore_errors=True) @@ -152,7 +151,7 @@ def _stop_container(ha: dict) -> None: """Stop HA container and clean up.""" import shutil - log("Stopping HA container...") + logger.info("Stopping HA container...") ha["container"].stop() shutil.rmtree(ha["config_dir"], ignore_errors=True) @@ -194,7 +193,7 @@ def _extract_tokens(session_file: str | None, agent: str) -> dict | None: ) + usage.get("cache_creation_input_tokens", 0) return totals except Exception as exc: - log(f" Token extraction failed: {exc}") + logger.warning(f" Token extraction failed: {exc}") return None return None @@ -223,7 +222,7 @@ def _extract_tool_calls(session_file: str | None, agent: str) -> int | None: count += 1 return count except Exception as exc: - log(f" Tool call extraction failed: {exc}") + logger.warning(f" Tool call extraction failed: {exc}") return None return None @@ -273,51 +272,25 @@ def _find_latest_session_file(agent: str, after: float) -> str | None: return None -# --------------------------------------------------------------------------- -# FastMCP in-memory setup # --------------------------------------------------------------------------- async def _run_mcp_steps( - ha_url: str, ha_token: str, steps: list[dict], phase: str + mcp_client: MCPClient, steps: list[dict], phase: str ) -> None: - """Execute setup or teardown steps via FastMCP in-memory client.""" - if not steps: - return - - import os - - import ha_mcp.config - from ha_mcp.client import HomeAssistantClient - from ha_mcp.client.websocket_client import websocket_manager - from ha_mcp.server import HomeAssistantSmartMCPServer - - # Point global settings at the test HA instance before resetting. - # The WebSocket client uses get_global_settings() (reads env vars), not the - # HomeAssistantClient base_url, so we must set env vars explicitly. - os.environ["HOMEASSISTANT_URL"] = ha_url - os.environ["HOMEASSISTANT_TOKEN"] = ha_token - ha_mcp.config._settings = None - - # Disconnect any cached WebSocket so it reconnects to the test instance. - await websocket_manager.disconnect() - - client = HomeAssistantClient(base_url=ha_url, token=ha_token) - server = HomeAssistantSmartMCPServer(client=client) - - from fastmcp import Client - - async with Client(server.mcp) as mcp_client: - for step in steps: - tool_name = step["tool"] - args = step.get("args", {}) - log(f" [{phase}] {tool_name}({args})") - try: - await mcp_client.call_tool(tool_name, args) - except Exception as e: - if phase == "setup": - log(f" [{phase}] {tool_name} FAILED: {e}") - raise - else: - log(f" [{phase}] {tool_name} failed (ok): {e}") + """Execute setup or teardown steps via a shared in-memory MCP client.""" + for step in steps: + tool_name = step["tool"] + args = step.get("args", {}) + logger.info(f" [{phase}] {tool_name}({args})") + try: + await mcp_client.call_tool(tool_name, args) + except Exception: + if phase == "setup": + logger.info(f" [{phase}] {tool_name} FAILED (see server log)") + raise + logger.warning( + f" [{phase}] {tool_name} failed, ignored (may poison shared client " + "for next story; see server log)" + ) # --------------------------------------------------------------------------- @@ -379,8 +352,8 @@ def _run_test_prompt( timeout=600, ) - if result.stderr: - print(result.stderr, file=sys.stderr, end="") + for line in result.stderr.splitlines(): + logger.info(line) summary = None if result.stdout.strip(): @@ -431,9 +404,8 @@ async def _run_test_prompt_inline( openai_tools=openai_tools, ) except Exception as e: - log(f" [{agent_name}] inline run failed ({type(e).__name__}): {e}") + logger.exception(f" [{agent_name}] inline run failed") tb = traceback.format_exc() - log(tb) duration_ms = int((time.time() - start) * 1000) return 1, _inline_failure_summary( agent_name, @@ -531,7 +503,7 @@ def _record_setup_failure( """ for _path, story in filtered: summary = _inline_failure_summary(agent, error_msg=error_msg) - all_results.append((agent, story["id"], story, 1, summary, None)) + all_results.append((agent, story["id"], story, 1, summary, None, False)) append_result( results_file, story, @@ -541,7 +513,7 @@ def _record_setup_failure( branch, summary, None, - exit_code=1, + passed=False, verify_results=None, ) @@ -610,7 +582,7 @@ def append_result( branch: str | None, bat_summary: dict, session_file: str | None = None, - exit_code: int = 0, + passed: bool = False, verify_results: list[dict] | None = None, ) -> None: """Append a single story result as one JSONL line.""" @@ -627,11 +599,7 @@ def append_result( "story": story["id"], "category": story["category"], "weight": story["weight"], - "passed": _compute_passed( - exit_code=exit_code, - tool_calls=aggregate.get("total_tool_calls"), - verify_results=verify_results, - ), + "passed": passed, "test_duration_ms": test_phase.get("duration_ms"), "total_duration_ms": aggregate.get("total_duration_ms"), "tool_calls": aggregate.get("total_tool_calls"), @@ -696,6 +664,7 @@ async def run_stories( For each agent: start container -> run all stories -> stop container. When --ha-url is provided, all agents share the external instance. """ + run_start = time.time() sha, describe = get_git_info() agent_list = [a.strip() for a in args.agents.split(",")] using_external_ha = bool(args.ha_url) @@ -710,28 +679,28 @@ async def run_stories( if not using_external_ha: err = preflight_check_docker() if err: - log(f"FATAL: {err}") + logger.critical(err) return 2 if args.base_url and "openai" in agent_list: err = preflight_check_base_url(args.base_url) if err: - log(f"FATAL: {err}") + logger.critical(err) return 2 - all_results: list[tuple[str, str, dict, int, dict | None, str | None]] = [] - # Each entry: (agent, story_id, story, exit_code, summary, session_file) + all_results: list[tuple[str, str, dict, int, dict | None, str | None, bool]] = [] + # Each entry: (agent, story_id, story, exit_code, summary, session_file, passed) mcp_env_dict = parse_mcp_env( getattr(args, "mcp_env", None), base_url=args.base_url, - on_default_applied=log, + on_default_applied=logger.info, ) effective_mcp_env: list[str] = [f"{k}={v}" for k, v in mcp_env_dict.items()] for agent in agent_list: - log(f"\n{'#' * 60}") - log(f"Agent: {agent}") - log(f"{'#' * 60}") + logger.info(f"\n{'#' * 60}") + logger.info(f"Agent: {agent}") + logger.info(f"{'#' * 60}") ha = None ha_url = args.ha_url @@ -774,7 +743,7 @@ async def run_stories( agent_stack.push_async_callback(openai_client.close) except Exception as e: error_msg = f"Failed to initialise OpenAI client: {type(e).__name__}: {e}" - log(f"[{agent}] {error_msg}") + logger.error(f"[{agent}] {error_msg}") _record_setup_failure( filtered, agent, @@ -787,16 +756,18 @@ async def run_stories( ) continue try: + source = f"uvx download @ {args.branch}" if args.branch else "local" + logger.info(f"[{agent}] Starting MCP server ({source})...") inline_mcp_client = await agent_stack.enter_async_context( _MCPClient(config) ) openai_tools = await fetch_openai_tools( inline_mcp_client, max_tools=args.max_tools ) - log(f"[{agent}] MCP server ready ({len(openai_tools)} tools)") + logger.info(f"[{agent}] MCP server ready ({len(openai_tools)} tools)") except Exception as e: error_msg = f"Failed to start MCP server: {type(e).__name__}: {e}" - log(f"[{agent}] {error_msg}") + logger.error(f"[{agent}] {error_msg}") _record_setup_failure( filtered, agent, @@ -810,26 +781,30 @@ async def run_stories( continue if ha and args.keep_container: - log(f"\n[{agent}] Container kept alive: {ha['url']}") - log(f"[{agent}] Token: {ha['token']}") - log(f"[{agent}] Config dir: {ha['config_dir']}") - log(f"[{agent}] Stop manually: docker stop ") + logger.info(f"\n[{agent}] Container kept alive: {ha['url']}") + logger.info(f"[{agent}] Token: {ha['token']}") + logger.info(f"[{agent}] Config dir: {ha['config_dir']}") + logger.info(f"[{agent}] Stop manually: docker stop ") + + shared_mcp = await agent_stack.enter_async_context( + inprocess_mcp_client(ha_url, ha_token) + ) for _path, story in filtered: sid = story["id"] - log(f"\n{'=' * 60}") - log(f"[{agent}] Story {sid}: {story['title']}") - log(f"{'=' * 60}") + logger.info(f"\n{'=' * 60}") + logger.info(f"[{agent}] Story {sid}: {story['title']}") + logger.info(f"{'=' * 60}") setup_steps = story.get("setup") or [] if setup_steps: - log( + logger.info( f"[{agent}/{sid}] Setup ({len(setup_steps)} steps via FastMCP)..." ) - await _run_mcp_steps(ha_url, ha_token, setup_steps, "setup") + await _run_mcp_steps(shared_mcp, setup_steps, "setup") - log(f"[{agent}/{sid}] Running test prompt...") - run_start = time.time() + logger.info(f"[{agent}/{sid}] Running test prompt...") + prompt_start = time.time() summary: dict | None if use_inline: assert ( @@ -872,7 +847,7 @@ async def run_stories( if claude_session_id: session_file = _find_session_file_by_id(claude_session_id) if not session_file: - session_file = _find_latest_session_file(agent, after=run_start) + session_file = _find_latest_session_file(agent, after=prompt_start) verify_results = None ha_checks = (story.get("verify") or {}).get("ha_checks") @@ -884,19 +859,29 @@ async def run_stories( .get("test", {}) .get("output", "") ) - log(f"[{agent}/{sid}] Verifying {len(ha_checks)} ha_check(s)...") + logger.info(f"[{agent}/{sid}] Verifying {len(ha_checks)} ha_check(s)...") verify_results = await verify_ha_checks( - ha_url, ha_token, ha_checks, agent_output + ha_url, ha_token, ha_checks, agent_output, shared_mcp ) failed_checks = [r for r in verify_results if not r["passed"]] if failed_checks: - log(f"[{agent}/{sid}] {len(failed_checks)}/{len(ha_checks)} check(s) FAILED") + logger.warning( + f"[{agent}/{sid}] {len(failed_checks)}/{len(ha_checks)} check(s) FAILED" + ) for r in failed_checks: - log(f" FAIL [{r['type']}] {r['detail']}") + logger.warning(f" FAIL [{r['type']}] {r['detail']}") else: - log(f"[{agent}/{sid}] All checks passed") + logger.info(f"[{agent}/{sid}] All checks passed") - all_results.append((agent, sid, story, rc, summary, session_file)) + agg = (summary or {}).get("agents", {}).get(agent, {}).get("aggregate", {}) + passed = _compute_passed( + exit_code=rc, + tool_calls=agg.get("total_tool_calls"), + verify_results=verify_results, + ) + all_results.append( + (agent, sid, story, rc, summary, session_file, passed) + ) if summary or verify_results is not None: append_result( @@ -908,36 +893,38 @@ async def run_stories( args.branch, summary or {}, session_file, - exit_code=rc, + passed=passed, verify_results=verify_results, ) if session_file: - log(f"[{agent}/{sid}] Session file: {session_file}") + logger.info(f"[{agent}/{sid}] Session file: {session_file}") # Summary - log(f"\n{'=' * 60}") - log("Summary") - log(f"{'=' * 60}") - for agent, sid, story, rc, _, session_file in all_results: - status = "PASS" if rc == 0 else "FAIL" + logger.info(f"\n{'=' * 60}") + logger.info("Summary") + logger.info(f"{'=' * 60}") + for agent, sid, story, _rc, _, session_file, passed in all_results: + status = "PASS" if passed else "FAIL" session_info = f" (session: {session_file})" if session_file else "" - log(f" [{status}] {agent}/{sid}: {story['title']}{session_info}") + logger.info(f" [{status}] {agent}/{sid}: {story['title']}{session_info}") - log(f"\nResults appended to {args.results_file}") + elapsed = time.time() - run_start + mins, secs = divmod(int(elapsed), 60) + logger.info(f"\nTotal time: {mins}m {secs}s") + logger.info(f"Results appended to {args.results_file}") - failed = sum(1 for _, _, _, rc, _, _ in all_results if rc != 0) + failed = sum(1 for *_, passed in all_results if not passed) total = len(all_results) if failed: - log(f"\n{failed}/{total} story runs failed") + logger.warning(f"\n{failed}/{total} story runs failed") return 1 - else: - log(f"\nAll {total} story runs passed") - return 0 + logger.info(f"\nAll {total} story runs passed") + return 0 def main() -> None: - parser = argparse.ArgumentParser( + parser = SuggestingArgumentParser( description="Run user acceptance stories via BAT", formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -1014,7 +1001,7 @@ def main() -> None: if "openai" in agent_list and not args.base_url: parser.error("--base-url is required when using the openai agent") - logging.basicConfig(level=logging.INFO, format="%(message)s") + configure_cli_logging() if args.all: stories = sorted(CATALOG_DIR.glob("s*.yaml")) @@ -1045,7 +1032,7 @@ def main() -> None: try: exit_code = asyncio.run(run_stories(args, filtered)) except KeyboardInterrupt: - log("\nInterrupted") + logger.info("\nInterrupted") sys.exit(130) sys.exit(exit_code) diff --git a/tests/uat/stories/scripts/verify_story.py b/tests/uat/stories/scripts/verify_story.py index e7ae58038..8c6916d27 100644 --- a/tests/uat/stories/scripts/verify_story.py +++ b/tests/uat/stories/scripts/verify_story.py @@ -4,11 +4,13 @@ import asyncio import re -from contextlib import asynccontextmanager -from typing import Any +from typing import TYPE_CHECKING, Any import httpx +if TYPE_CHECKING: + from fastmcp import Client + async def _retry(fn, attempts: int = 3, delay: float = 2.0) -> Any | None: """Call fn() up to `attempts` times, returning first non-None result.""" @@ -239,48 +241,6 @@ def _check_response_matches(check: dict, agent_output: str) -> dict: } -# --------------------------------------------------------------------------- -# Shared MCP context — one server instance per verify_ha_checks call -# --------------------------------------------------------------------------- - - -@asynccontextmanager -async def _mcp_context(ha_url: str, ha_token: str): - """Create a single shared MCP client for all async checks in one verify run.""" - import os - - from fastmcp import Client - - import ha_mcp.config - from ha_mcp.client import HomeAssistantClient - from ha_mcp.client.websocket_client import websocket_manager - from ha_mcp.server import HomeAssistantSmartMCPServer - - prev_url = os.environ.get("HOMEASSISTANT_URL") - prev_token = os.environ.get("HOMEASSISTANT_TOKEN") - prev_settings = ha_mcp.config._settings - try: - os.environ["HOMEASSISTANT_URL"] = ha_url - os.environ["HOMEASSISTANT_TOKEN"] = ha_token - ha_mcp.config._settings = None - await websocket_manager.disconnect() - - ha_client = HomeAssistantClient(base_url=ha_url, token=ha_token) - server = HomeAssistantSmartMCPServer(client=ha_client) - async with Client(server.mcp) as mcp_client: - yield mcp_client - finally: - if prev_url is None: - os.environ.pop("HOMEASSISTANT_URL", None) - else: - os.environ["HOMEASSISTANT_URL"] = prev_url - if prev_token is None: - os.environ.pop("HOMEASSISTANT_TOKEN", None) - else: - os.environ["HOMEASSISTANT_TOKEN"] = prev_token - ha_mcp.config._settings = prev_settings - - # --------------------------------------------------------------------------- # Public entry point # --------------------------------------------------------------------------- @@ -291,25 +251,24 @@ async def verify_ha_checks( ha_token: str, checks: list[dict], agent_output: str, + mcp_client: Client, ) -> list[dict]: - """Run all checks concurrently and return results list [{type, passed, detail, ...}].""" - headers = {"Authorization": f"Bearer {ha_token}"} + """Run all checks concurrently and return results list [{type, passed, detail, ...}]. - async def run_all(mcp_client=None) -> list[dict]: - async def run_check(check: dict) -> dict: - check_type = check["type"] - if check_type in SYNC_CHECKS: - return await SYNC_CHECKS[check_type](http, check) - if check_type in ASYNC_CHECKS: - return await ASYNC_CHECKS[check_type](mcp_client, check) - if check_type in RESPONSE_CHECKS: - return RESPONSE_CHECKS[check_type](check, agent_output) - return {**check, "passed": False, "detail": f"Unknown check type: {check_type}"} + Caller owns the ``mcp_client`` lifecycle so one in-process server can be + shared across many verify_ha_checks calls. + """ + headers = {"Authorization": f"Bearer {ha_token}"} - return list(await asyncio.gather(*[run_check(c) for c in checks])) + async def run_check(check: dict, http) -> dict: + check_type = check["type"] + if check_type in SYNC_CHECKS: + return await SYNC_CHECKS[check_type](http, check) + if check_type in ASYNC_CHECKS: + return await ASYNC_CHECKS[check_type](mcp_client, check) + if check_type in RESPONSE_CHECKS: + return RESPONSE_CHECKS[check_type](check, agent_output) + return {**check, "passed": False, "detail": f"Unknown check type: {check_type}"} async with httpx.AsyncClient(base_url=ha_url, headers=headers, timeout=10) as http: - if any(c["type"] in ASYNC_CHECKS for c in checks): - async with _mcp_context(ha_url, ha_token) as mcp_client: - return await run_all(mcp_client) - return await run_all() + return list(await asyncio.gather(*[run_check(c, http) for c in checks]))