diff --git a/examples/plugins/cao-discord/README.md b/examples/plugins/cao-discord/README.md new file mode 100644 index 00000000..e40865da --- /dev/null +++ b/examples/plugins/cao-discord/README.md @@ -0,0 +1,52 @@ +# cao-discord + +`cao-discord` is a CAO plugin that forwards inter-agent messages to a Discord channel through a webhook, rendering your CAO workflow as a live group chat of bots in Discord. + +## Install + +From the repository root, inside the CAO development virtual environment: + +```bash +uv pip install -e examples/plugins/cao-discord +``` + +## Example `.env` + +```dotenv +CAO_DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/1234567890/abcdef... +CAO_DISCORD_TIMEOUT_SECONDS=5.0 +``` + +## Setup + +1. Create a webhook in Discord: Channel -> Edit Channel -> Integrations -> Webhooks -> New Webhook -> Copy URL. +2. Install the plugin: + ```bash + uv pip install -e examples/plugins/cao-discord + + # Or use if you prefer uv tool install + # (from project root) + uv tool install --reinstall . \ + --with-editable ./examples/plugins/cao-discord + ``` +3. Create a `.env` file in the directory where you will run `cao-server`, or export the variables in your shell: + ```dotenv + CAO_DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/1234567890/abcdef... + CAO_DISCORD_TIMEOUT_SECONDS=5.0 + ``` +4. Start the server: + ```bash + cao-server + ``` +5. Launch a multi-agent workflow such as `cao flow ...` and watch the Discord channel for forwarded inter-agent messages. + +## Configuration + +| Variable | Required | Description | +| --- | --- | --- | +| `CAO_DISCORD_WEBHOOK_URL` | Yes | Full Discord webhook URL in the form `https://discord.com/api/webhooks/{id}/{token}`. | +| `CAO_DISCORD_TIMEOUT_SECONDS` | No | HTTP timeout in seconds for webhook POSTs. Defaults to `5.0`. | + +## Troubleshooting + +If `CAO_DISCORD_WEBHOOK_URL` is missing, `PluginRegistry.load()` logs a warning during `cao-server` startup and skips registering the plugin for the lifetime of that server process. diff --git a/examples/plugins/cao-discord/cao_discord/__init__.py b/examples/plugins/cao-discord/cao_discord/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/plugins/cao-discord/cao_discord/plugin.py b/examples/plugins/cao-discord/cao_discord/plugin.py new file mode 100644 index 00000000..44fb6308 --- /dev/null +++ b/examples/plugins/cao-discord/cao_discord/plugin.py @@ -0,0 +1,74 @@ +"""Discord plugin lifecycle, hook handling, and webhook dispatch.""" + +import logging +import os + +import httpx +from dotenv import find_dotenv, load_dotenv + +from cli_agent_orchestrator.clients.database import get_terminal_metadata +from cli_agent_orchestrator.plugins import PostSendMessageEvent, hook +from cli_agent_orchestrator.plugins.base import CaoPlugin + +logger = logging.getLogger(__name__) + + +class DiscordPlugin(CaoPlugin): + """Discord webhook plugin for CAO inter-agent messaging events.""" + + _webhook_url: str + _client: httpx.AsyncClient + + async def setup(self) -> None: + """Load configuration and initialize the HTTP client.""" + + load_dotenv(find_dotenv(usecwd=True)) + + webhook_url = os.environ.get("CAO_DISCORD_WEBHOOK_URL") + if not webhook_url: + raise RuntimeError( + "CAO_DISCORD_WEBHOOK_URL is not set. " + "Set it in the environment or in a .env file before starting cao-server." + ) + + self._webhook_url = webhook_url + timeout = float(os.environ.get("CAO_DISCORD_TIMEOUT_SECONDS", "5.0")) + self._client = httpx.AsyncClient(timeout=timeout) + + async def teardown(self) -> None: + """Close the HTTP client when setup completed successfully.""" + + if hasattr(self, "_client"): + await self._client.aclose() + + @hook("post_send_message") + async def on_post_send_message(self, event: PostSendMessageEvent) -> None: + """Forward post-send-message events to the configured Discord webhook.""" + + display_name = self._resolve_display_name(event.sender) + await self._post(username=display_name, content=event.message) + + def _resolve_display_name(self, terminal_id: str) -> str: + """Resolve a human-friendly sender name from terminal metadata.""" + + metadata = get_terminal_metadata(terminal_id) + if metadata is None: + return terminal_id + return metadata.get("tmux_window") or terminal_id + + async def _post(self, *, username: str, content: str) -> None: + """Send a Discord webhook payload and swallow all HTTP failures.""" + + try: + response = await self._client.post( + self._webhook_url, + json={"username": username, "content": content}, + ) + if response.status_code >= 400: + logger.warning( + "Discord webhook POST failed: %s %s", + response.status_code, + response.text[:200], + ) + except httpx.HTTPError as exc: + logger.warning("Discord webhook POST raised: %s", exc) diff --git a/examples/plugins/cao-discord/env.template b/examples/plugins/cao-discord/env.template new file mode 100644 index 00000000..424ac018 --- /dev/null +++ b/examples/plugins/cao-discord/env.template @@ -0,0 +1,14 @@ +# cao-discord plugin configuration +# +# Copy this file to `.env` in the directory where you launch `cao-server` from +# (or any parent directory — python-dotenv walks upward from CWD), then fill in +# your webhook URL. Real shell environment variables override values set here. +# +# Create a webhook: Discord channel -> Edit Channel -> Integrations -> Webhooks +# -> New Webhook -> Copy Webhook URL. + +# Required. Full Discord webhook URL. +CAO_DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/REPLACE_ID/REPLACE_TOKEN + +# Optional. HTTP timeout (seconds) for webhook POSTs. Default: 5.0 +#CAO_DISCORD_TIMEOUT_SECONDS=5.0 diff --git a/examples/plugins/cao-discord/pyproject.toml b/examples/plugins/cao-discord/pyproject.toml new file mode 100644 index 00000000..fb0b9600 --- /dev/null +++ b/examples/plugins/cao-discord/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cao-discord" +version = "0.1.0" +description = "Discord webhook plugin for CLI Agent Orchestrator" +requires-python = ">=3.10" +dependencies = [ + "cli-agent-orchestrator", + "httpx>=0.27", + "python-dotenv>=1.0", +] + +[project.entry-points."cao.plugins"] +discord = "cao_discord.plugin:DiscordPlugin" diff --git a/examples/plugins/cao-discord/tests/__init__.py b/examples/plugins/cao-discord/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/plugins/cao-discord/tests/test_plugin.py b/examples/plugins/cao-discord/tests/test_plugin.py new file mode 100644 index 00000000..c467d72b --- /dev/null +++ b/examples/plugins/cao-discord/tests/test_plugin.py @@ -0,0 +1,325 @@ +"""Tests for Discord plugin configuration, lifecycle, and hook dispatch.""" + +import json + +import httpx +import pytest + +from cao_discord.plugin import DiscordPlugin +from cli_agent_orchestrator.plugins import PostSendMessageEvent + + +def _timeout_values(plugin: DiscordPlugin) -> tuple[float | None, float | None, float | None, float | None]: + """Return the configured timeout values from the plugin's HTTP client.""" + + timeout = plugin._client.timeout + return timeout.connect, timeout.read, timeout.write, timeout.pool + + +async def _replace_client_with_mock_transport( + plugin: DiscordPlugin, handler: httpx.MockTransport +) -> None: + """Swap in a mock transport-backed client and close the setup client first.""" + + await plugin._client.aclose() + plugin._client = httpx.AsyncClient(transport=handler) + + +@pytest.mark.asyncio +async def test_setup_raises_when_webhook_url_is_missing( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Missing configuration should raise a RuntimeError with guidance.""" + + monkeypatch.delenv("CAO_DISCORD_WEBHOOK_URL", raising=False) + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + + plugin = DiscordPlugin() + + with pytest.raises(RuntimeError, match="CAO_DISCORD_WEBHOOK_URL"): + await plugin.setup() + + +@pytest.mark.asyncio +async def test_setup_reads_webhook_url_from_dotenv( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """A .env file in the process CWD should populate the webhook URL.""" + + webhook_url = "https://discord.example/from-dotenv" + (tmp_path / ".env").write_text(f"CAO_DISCORD_WEBHOOK_URL={webhook_url}\n", encoding="utf-8") + + monkeypatch.delenv("CAO_DISCORD_WEBHOOK_URL", raising=False) + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + + plugin = DiscordPlugin() + await plugin.setup() + + assert plugin._webhook_url == webhook_url + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_setup_prefers_process_env_over_dotenv( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Process environment variables should override .env values.""" + + dotenv_url = "https://discord.example/from-dotenv" + env_url = "https://discord.example/from-env" + (tmp_path / ".env").write_text(f"CAO_DISCORD_WEBHOOK_URL={dotenv_url}\n", encoding="utf-8") + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", env_url) + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + + plugin = DiscordPlugin() + await plugin.setup() + + assert plugin._webhook_url == env_url + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_setup_uses_configured_timeout_or_default( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Timeout should default to 5.0 seconds and honor configured overrides.""" + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + + default_plugin = DiscordPlugin() + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/default-timeout") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + await default_plugin.setup() + + configured_plugin = DiscordPlugin() + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/custom-timeout") + monkeypatch.setenv("CAO_DISCORD_TIMEOUT_SECONDS", "2.5") + await configured_plugin.setup() + + assert _timeout_values(default_plugin) == (5.0, 5.0, 5.0, 5.0) + assert _timeout_values(configured_plugin) == (2.5, 2.5, 2.5, 2.5) + + await default_plugin.teardown() + await configured_plugin.teardown() + + +@pytest.mark.asyncio +async def test_teardown_is_safe_after_failed_setup( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Teardown should be a no-op when setup failed before client creation.""" + + monkeypatch.delenv("CAO_DISCORD_WEBHOOK_URL", raising=False) + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + + plugin = DiscordPlugin() + + with pytest.raises(RuntimeError, match="CAO_DISCORD_WEBHOOK_URL"): + await plugin.setup() + + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_on_post_send_message_posts_webhook_payload_with_tmux_window_name( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """The hook should send a webhook payload with the resolved display name.""" + + requests: list[dict[str, object]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests.append( + { + "method": request.method, + "url": str(request.url), + "json": json.loads(request.content.decode("utf-8")), + } + ) + return httpx.Response(204) + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/happy-path") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + monkeypatch.setattr( + "cao_discord.plugin.get_terminal_metadata", + lambda terminal_id: {"tmux_window": "coder-a1b2", "id": terminal_id}, + ) + + plugin = DiscordPlugin() + await plugin.setup() + await _replace_client_with_mock_transport(plugin, httpx.MockTransport(handler)) + + result = await plugin.on_post_send_message( + PostSendMessageEvent( + sender="abc12345", + receiver="def67890", + message="hello", + orchestration_type="send_message", + ) + ) + + assert result is None + assert requests == [ + { + "method": "POST", + "url": "https://discord.example/happy-path", + "json": {"username": "coder-a1b2", "content": "hello"}, + } + ] + + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_on_post_send_message_falls_back_to_terminal_id_when_metadata_is_missing_or_empty( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Missing or empty tmux window metadata should fall back to the sender id.""" + + requests: list[dict[str, str]] = [] + metadata_values = iter([None, {}, {"tmux_window": "", "id": "abc12345"}]) + + def handler(request: httpx.Request) -> httpx.Response: + requests.append(json.loads(request.content.decode("utf-8"))) + return httpx.Response(204) + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/fallback") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + monkeypatch.setattr( + "cao_discord.plugin.get_terminal_metadata", + lambda terminal_id: next(metadata_values), + ) + + plugin = DiscordPlugin() + await plugin.setup() + await _replace_client_with_mock_transport(plugin, httpx.MockTransport(handler)) + + for message in ("first", "second", "third"): + result = await plugin.on_post_send_message( + PostSendMessageEvent( + sender="abc12345", + receiver="def67890", + message=message, + orchestration_type="send_message", + ) + ) + assert result is None + + assert requests == [ + {"username": "abc12345", "content": "first"}, + {"username": "abc12345", "content": "second"}, + {"username": "abc12345", "content": "third"}, + ] + + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_on_post_send_message_logs_warning_for_http_500_without_raising( + monkeypatch: pytest.MonkeyPatch, tmp_path, caplog: pytest.LogCaptureFixture +) -> None: + """HTTP 5xx responses should log a warning and not escape the hook.""" + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="webhook temporarily broken") + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/server-error") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + monkeypatch.setattr( + "cao_discord.plugin.get_terminal_metadata", + lambda terminal_id: {"tmux_window": "coder-a1b2", "id": terminal_id}, + ) + + plugin = DiscordPlugin() + await plugin.setup() + await _replace_client_with_mock_transport(plugin, httpx.MockTransport(handler)) + + with caplog.at_level("WARNING", logger="cao_discord.plugin"): + result = await plugin.on_post_send_message( + PostSendMessageEvent( + sender="abc12345", + receiver="def67890", + message="hello", + orchestration_type="send_message", + ) + ) + + assert result is None + assert "Discord webhook POST failed: 500 webhook temporarily broken" in caplog.text + + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_on_post_send_message_logs_warning_for_httpx_error_without_raising( + monkeypatch: pytest.MonkeyPatch, tmp_path, caplog: pytest.LogCaptureFixture +) -> None: + """Transport errors should log a warning and not escape the hook.""" + + def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("boom", request=request) + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/connect-error") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + monkeypatch.setattr( + "cao_discord.plugin.get_terminal_metadata", + lambda terminal_id: {"tmux_window": "coder-a1b2", "id": terminal_id}, + ) + + plugin = DiscordPlugin() + await plugin.setup() + await _replace_client_with_mock_transport(plugin, httpx.MockTransport(handler)) + + with caplog.at_level("WARNING", logger="cao_discord.plugin"): + result = await plugin.on_post_send_message( + PostSendMessageEvent( + sender="abc12345", + receiver="def67890", + message="hello", + orchestration_type="send_message", + ) + ) + + assert result is None + assert "Discord webhook POST raised: boom" in caplog.text + + await plugin.teardown() + + +@pytest.mark.asyncio +async def test_teardown_closes_real_client_after_successful_setup( + monkeypatch: pytest.MonkeyPatch, tmp_path +) -> None: + """Teardown should close a real AsyncClient instance.""" + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(204) + + monkeypatch.setenv("CAO_DISCORD_WEBHOOK_URL", "https://discord.example/teardown") + monkeypatch.delenv("CAO_DISCORD_TIMEOUT_SECONDS", raising=False) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("cao_discord.plugin.find_dotenv", lambda usecwd=True: "") + + plugin = DiscordPlugin() + await plugin.setup() + await _replace_client_with_mock_transport(plugin, httpx.MockTransport(handler)) + + await plugin.teardown() + + assert plugin._client.is_closed is True diff --git a/src/cli_agent_orchestrator/api/main.py b/src/cli_agent_orchestrator/api/main.py index e6867db2..a3477596 100644 --- a/src/cli_agent_orchestrator/api/main.py +++ b/src/cli_agent_orchestrator/api/main.py @@ -12,9 +12,9 @@ import termios from contextlib import asynccontextmanager from pathlib import Path -from typing import Annotated, Dict, List, Optional +from typing import Annotated, Dict, List, Optional, cast -from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect, status +from fastapi import FastAPI, HTTPException, Query, Request, WebSocket, WebSocketDisconnect, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from pydantic import BaseModel, Field, field_validator @@ -37,8 +37,9 @@ TERMINAL_LOG_DIR, ) from cli_agent_orchestrator.models.flow import Flow -from cli_agent_orchestrator.models.inbox import MessageStatus +from cli_agent_orchestrator.models.inbox import MessageStatus, OrchestrationType from cli_agent_orchestrator.models.terminal import Terminal, TerminalId +from cli_agent_orchestrator.plugins import PluginRegistry from cli_agent_orchestrator.providers.manager import provider_manager from cli_agent_orchestrator.services import ( flow_service, @@ -127,6 +128,9 @@ async def lifespan(app: FastAPI): logger.info("Starting CLI Agent Orchestrator server...") setup_logging() init_db() + registry = PluginRegistry() + await registry.load() + app.state.plugin_registry = registry # Run cleanup in background asyncio.create_task(asyncio.to_thread(cleanup_old_data)) @@ -136,7 +140,7 @@ async def lifespan(app: FastAPI): # Start inbox watcher inbox_observer = PollingObserver(timeout=INBOX_POLLING_INTERVAL) - inbox_observer.schedule(LogFileHandler(), str(TERMINAL_LOG_DIR), recursive=False) + inbox_observer.schedule(LogFileHandler(registry), str(TERMINAL_LOG_DIR), recursive=False) inbox_observer.start() logger.info("Inbox watcher started (PollingObserver)") @@ -154,9 +158,16 @@ async def lifespan(app: FastAPI): except asyncio.CancelledError: pass + await registry.teardown() logger.info("Shutting down CLI Agent Orchestrator server...") +def get_plugin_registry(request: Request) -> PluginRegistry: + """Return the plugin registry stored on the FastAPI application state.""" + + return cast(PluginRegistry, request.app.state.plugin_registry) + + app = FastAPI( title="CLI Agent Orchestrator", description="Simplified CLI Agent Orchestrator API", @@ -289,6 +300,7 @@ async def get_skill_content(name: str) -> SkillContentResponse: @app.post("/sessions", response_model=Terminal, status_code=status.HTTP_201_CREATED) async def create_session( + request: Request, provider: str, agent_profile: str, session_name: Optional[str] = None, @@ -300,13 +312,13 @@ async def create_session( # Parse comma-separated allowed_tools string into list allowed_tools_list = allowed_tools.split(",") if allowed_tools else None - result = terminal_service.create_terminal( + result = session_service.create_session( provider=provider, agent_profile=agent_profile, session_name=session_name, - new_session=True, working_directory=working_directory, allowed_tools=allowed_tools_list, + registry=get_plugin_registry(request), ) return result @@ -344,9 +356,9 @@ async def get_session(session_name: str) -> Dict: @app.delete("/sessions/{session_name}") -async def delete_session(session_name: str) -> Dict: +async def delete_session(request: Request, session_name: str) -> Dict: try: - result = session_service.delete_session(session_name) + result = session_service.delete_session(session_name, registry=get_plugin_registry(request)) return {"success": True, **result} except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @@ -363,6 +375,7 @@ async def delete_session(session_name: str) -> Dict: status_code=status.HTTP_201_CREATED, ) async def create_terminal_in_session( + request: Request, session_name: str, provider: str, agent_profile: str, @@ -383,6 +396,7 @@ async def create_terminal_in_session( new_session=False, working_directory=working_directory, allowed_tools=allowed_tools_list, + registry=get_plugin_registry(request), ) return result except ValueError as e: @@ -438,9 +452,24 @@ async def get_terminal_working_directory(terminal_id: TerminalId) -> WorkingDire @app.post("/terminals/{terminal_id}/input") -async def send_terminal_input(terminal_id: TerminalId, message: str) -> Dict: +async def send_terminal_input( + request: Request, + terminal_id: TerminalId, + message: str, + sender_id: Optional[str] = None, + orchestration_type: Optional[OrchestrationType] = None, +) -> Dict: try: - success = terminal_service.send_input(terminal_id, message) + if sender_id is None or orchestration_type is None: + success = terminal_service.send_input(terminal_id, message) + else: + success = terminal_service.send_input( + terminal_id, + message, + registry=get_plugin_registry(request), + sender_id=sender_id, + orchestration_type=orchestration_type, + ) return {"success": success} except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @@ -493,10 +522,12 @@ async def exit_terminal(terminal_id: TerminalId) -> Dict: @app.delete("/terminals/{terminal_id}") -async def delete_terminal(terminal_id: TerminalId) -> Dict: +async def delete_terminal(request: Request, terminal_id: TerminalId) -> Dict: """Delete a terminal.""" try: - success = terminal_service.delete_terminal(terminal_id) + success = terminal_service.delete_terminal( + terminal_id, registry=get_plugin_registry(request) + ) return {"success": success} except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @@ -509,11 +540,18 @@ async def delete_terminal(terminal_id: TerminalId) -> Dict: @app.post("/terminals/{receiver_id}/inbox/messages") async def create_inbox_message_endpoint( - receiver_id: TerminalId, sender_id: str, message: str + request: Request, + receiver_id: TerminalId, + sender_id: str, + message: str, ) -> Dict: """Create inbox message and attempt immediate delivery.""" try: - inbox_msg = create_inbox_message(sender_id, receiver_id, message) + inbox_msg = create_inbox_message( + sender_id, + receiver_id, + message, + ) except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: @@ -527,7 +565,9 @@ async def create_inbox_message_endpoint( # the terminal becomes idle. Delivery failures must not cause the API # to report an error — the message was already persisted above. try: - inbox_service.check_and_send_pending_messages(receiver_id) + inbox_service.check_and_send_pending_messages( + receiver_id, registry=get_plugin_registry(request) + ) except Exception as e: logger.warning(f"Immediate delivery attempt failed for {receiver_id}: {e}") diff --git a/src/cli_agent_orchestrator/mcp_server/server.py b/src/cli_agent_orchestrator/mcp_server/server.py index 505735b9..fb3b78cb 100644 --- a/src/cli_agent_orchestrator/mcp_server/server.py +++ b/src/cli_agent_orchestrator/mcp_server/server.py @@ -12,6 +12,7 @@ from cli_agent_orchestrator.constants import API_BASE_URL, DEFAULT_PROVIDER from cli_agent_orchestrator.mcp_server.models import HandoffResult +from cli_agent_orchestrator.models.inbox import OrchestrationType from cli_agent_orchestrator.models.terminal import TerminalStatus from cli_agent_orchestrator.utils.terminal import generate_session_name, wait_until_terminal_status @@ -177,18 +178,26 @@ def _create_terminal( return terminal["id"], provider -def _send_direct_input(terminal_id: str, message: str) -> None: +def _send_direct_input( + terminal_id: str, message: str, orchestration_type: OrchestrationType +) -> None: """Send input directly to a terminal (bypasses inbox). Args: terminal_id: Terminal ID message: Message to send + orchestration_type: Orchestration mode for plugin event emission Raises: Exception: If sending fails """ response = requests.post( - f"{API_BASE_URL}/terminals/{terminal_id}/input", params={"message": message} + f"{API_BASE_URL}/terminals/{terminal_id}/input", + params={ + "message": message, + "sender_id": os.environ.get("CAO_TERMINAL_ID", "supervisor"), + "orchestration_type": orchestration_type, + }, ) response.raise_for_status() @@ -211,7 +220,7 @@ def _send_direct_input_handoff(terminal_id: str, provider: str, message: str) -> else: handoff_message = message - _send_direct_input(terminal_id, handoff_message) + _send_direct_input(terminal_id, handoff_message, OrchestrationType.HANDOFF) def _send_direct_input_assign(terminal_id: str, message: str) -> None: @@ -224,7 +233,7 @@ def _send_direct_input_assign(terminal_id: str, message: str) -> None: f"When done, send results back to terminal {sender_id} using send_message]" ) - _send_direct_input(terminal_id, message) + _send_direct_input(terminal_id, message, OrchestrationType.ASSIGN) def _send_to_inbox(receiver_id: str, message: str) -> Dict[str, Any]: @@ -247,7 +256,10 @@ def _send_to_inbox(receiver_id: str, message: str) -> Dict[str, Any]: response = requests.post( f"{API_BASE_URL}/terminals/{receiver_id}/inbox/messages", - params={"sender_id": sender_id, "message": message}, + params={ + "sender_id": sender_id, + "message": message, + }, ) response.raise_for_status() return response.json() diff --git a/src/cli_agent_orchestrator/models/inbox.py b/src/cli_agent_orchestrator/models/inbox.py index 91996ba8..d8ce7a69 100644 --- a/src/cli_agent_orchestrator/models/inbox.py +++ b/src/cli_agent_orchestrator/models/inbox.py @@ -6,6 +6,14 @@ from pydantic import BaseModel, Field +class OrchestrationType(str, Enum): + """Orchestration mode for a message delivery.""" + + SEND_MESSAGE = "send_message" + HANDOFF = "handoff" + ASSIGN = "assign" + + class MessageStatus(str, Enum): """Message status enumeration.""" diff --git a/src/cli_agent_orchestrator/plugins/__init__.py b/src/cli_agent_orchestrator/plugins/__init__.py new file mode 100644 index 00000000..c8661cf9 --- /dev/null +++ b/src/cli_agent_orchestrator/plugins/__init__.py @@ -0,0 +1,24 @@ +"""Public API for the CAO plugin system.""" + +from cli_agent_orchestrator.plugins.base import CaoPlugin, hook +from cli_agent_orchestrator.plugins.events import ( + CaoEvent, + PostCreateSessionEvent, + PostCreateTerminalEvent, + PostKillSessionEvent, + PostKillTerminalEvent, + PostSendMessageEvent, +) +from cli_agent_orchestrator.plugins.registry import PluginRegistry + +__all__ = [ + "CaoPlugin", + "hook", + "CaoEvent", + "PostSendMessageEvent", + "PostCreateSessionEvent", + "PostKillSessionEvent", + "PostCreateTerminalEvent", + "PostKillTerminalEvent", + "PluginRegistry", +] diff --git a/src/cli_agent_orchestrator/plugins/base.py b/src/cli_agent_orchestrator/plugins/base.py new file mode 100644 index 00000000..a3036a4a --- /dev/null +++ b/src/cli_agent_orchestrator/plugins/base.py @@ -0,0 +1,52 @@ +"""Plugin base class and hook decorator for CAO plugins. + +This module defines the marker base class plugin authors subclass and the +decorator used to associate async plugin methods with CAO event types. +""" + +from typing import Awaitable, Callable, ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") +AsyncMethodT = Callable[P, Awaitable[R]] + +_HOOK_EVENT_ATTR = "_cao_hook_event" + + +class CaoPlugin: + """Base class for CAO plugins. + + Subclass this and declare hooks with the @hook decorator. + Register the subclass via the `cao.plugins` entry point group. + """ + + async def setup(self) -> None: + """Called once after instantiation on server startup. + + Override to open connections, load config, or initialize state. + """ + + async def teardown(self) -> None: + """Called once on server shutdown. + + Override to close connections or flush buffers. + """ + + +def hook(event_type: str) -> Callable[[AsyncMethodT[P, R]], AsyncMethodT[P, R]]: + """Decorator that registers a plugin method as a hook for a CAO event. + + Args: + event_type: The CAO event type to listen for (e.g. "post_send_message"). + + Example: + @hook("post_send_message") + async def notify(self, event: PostSendMessageEvent) -> None: + ... + """ + + def decorator(fn: AsyncMethodT[P, R]) -> AsyncMethodT[P, R]: + setattr(fn, _HOOK_EVENT_ATTR, event_type) + return fn + + return decorator diff --git a/src/cli_agent_orchestrator/plugins/events.py b/src/cli_agent_orchestrator/plugins/events.py new file mode 100644 index 00000000..6fc415d2 --- /dev/null +++ b/src/cli_agent_orchestrator/plugins/events.py @@ -0,0 +1,75 @@ +"""Typed plugin event dataclasses for CAO lifecycle and messaging hooks.""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone + + +def _utc_now() -> datetime: + """Return the current UTC time as a timezone-aware datetime.""" + + return datetime.now(timezone.utc) + + +@dataclass +class CaoEvent: + """Base class for all CAO plugin events.""" + + # Empty by default so the base dataclass is zero-arg constructible for Phase 1 tests. + event_type: str = "" + timestamp: datetime = field(default_factory=_utc_now) + session_id: str | None = None + + +@dataclass +class PostSendMessageEvent(CaoEvent): + """Emitted after a message is dispatched to an agent's inbox. + + Fired for all three orchestration methods: + - send_message: direct message to an existing terminal + - handoff: message sent as part of a synchronous handoff + - assign: message sent as part of an asynchronous assign + + Orchestration methods like assign span multiple steps and may therefore + emit more than one PostSendMessageEvent across their lifecycle. + """ + + event_type: str = "post_send_message" + sender: str = "" + receiver: str = "" + message: str = "" + orchestration_type: str = "" + + +@dataclass +class PostCreateSessionEvent(CaoEvent): + """Emitted after a CAO session is created.""" + + event_type: str = "post_create_session" + session_name: str = "" + + +@dataclass +class PostKillSessionEvent(CaoEvent): + """Emitted after a CAO session is killed.""" + + event_type: str = "post_kill_session" + session_name: str = "" + + +@dataclass +class PostCreateTerminalEvent(CaoEvent): + """Emitted after a CAO terminal is created.""" + + event_type: str = "post_create_terminal" + terminal_id: str = "" + agent_name: str | None = None + provider: str = "" + + +@dataclass +class PostKillTerminalEvent(CaoEvent): + """Emitted after a CAO terminal is killed.""" + + event_type: str = "post_kill_terminal" + terminal_id: str = "" + agent_name: str | None = None diff --git a/src/cli_agent_orchestrator/plugins/registry.py b/src/cli_agent_orchestrator/plugins/registry.py new file mode 100644 index 00000000..95cf0e37 --- /dev/null +++ b/src/cli_agent_orchestrator/plugins/registry.py @@ -0,0 +1,87 @@ +"""Plugin discovery, registration, dispatch, and lifecycle management.""" + +import importlib.metadata +import inspect +import logging +from typing import Any + +from cli_agent_orchestrator.plugins.base import _HOOK_EVENT_ATTR, CaoPlugin +from cli_agent_orchestrator.plugins.events import CaoEvent + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "cao.plugins" + + +class PluginRegistry: + """Registry for discovered CAO plugins and their hook handlers.""" + + def __init__(self) -> None: + """Initialize an empty plugin registry.""" + + self._plugins: list[CaoPlugin] = [] + self._dispatch: dict[str, list[Any]] = {} + + async def load(self) -> None: + """Discover, instantiate, and set up all registered CAO plugins.""" + + entry_points = importlib.metadata.entry_points(group=ENTRY_POINT_GROUP) + for entry_point in entry_points: + try: + plugin_class = entry_point.load() + if not (isinstance(plugin_class, type) and issubclass(plugin_class, CaoPlugin)): + logger.warning( + "Plugin entry point '%s' is not a CaoPlugin subclass, skipping", + entry_point.name, + ) + continue + + plugin = plugin_class() + await plugin.setup() + self._register(plugin) + logger.info("Loaded CAO plugin: %s", entry_point.name) + except Exception: + logger.warning( + "Failed to load plugin '%s'", + entry_point.name, + exc_info=True, + ) + + if not self._plugins: + logger.info("No CAO plugins registered (cao.plugins entry point group is empty)") + + def _register(self, plugin: CaoPlugin) -> None: + """Register a plugin instance and index any decorated hook methods.""" + + self._plugins.append(plugin) + for _, method in inspect.getmembers(plugin, predicate=inspect.ismethod): + event_type = getattr(method, _HOOK_EVENT_ATTR, None) + if event_type is not None: + self._dispatch.setdefault(event_type, []).append(method) + + async def dispatch(self, event_type: str, event: CaoEvent) -> None: + """Dispatch an event to all matching plugin hook handlers.""" + + for handler in self._dispatch.get(event_type, []): + try: + await handler(event) + except Exception: + logger.warning( + "Hook '%s' raised an error for event '%s'", + handler.__qualname__, + event_type, + exc_info=True, + ) + + async def teardown(self) -> None: + """Call teardown() on every loaded plugin, continuing after failures.""" + + for plugin in self._plugins: + try: + await plugin.teardown() + except Exception: + logger.warning( + "Plugin teardown failed for %s", + type(plugin).__name__, + exc_info=True, + ) diff --git a/src/cli_agent_orchestrator/services/inbox_service.py b/src/cli_agent_orchestrator/services/inbox_service.py index 6761518e..3c2a55cf 100644 --- a/src/cli_agent_orchestrator/services/inbox_service.py +++ b/src/cli_agent_orchestrator/services/inbox_service.py @@ -30,8 +30,9 @@ from cli_agent_orchestrator.clients.database import get_pending_messages, update_message_status from cli_agent_orchestrator.constants import TERMINAL_LOG_DIR -from cli_agent_orchestrator.models.inbox import MessageStatus +from cli_agent_orchestrator.models.inbox import MessageStatus, OrchestrationType from cli_agent_orchestrator.models.terminal import TerminalStatus +from cli_agent_orchestrator.plugins import PluginRegistry from cli_agent_orchestrator.providers.manager import provider_manager from cli_agent_orchestrator.services import terminal_service @@ -71,7 +72,9 @@ def _has_idle_pattern(terminal_id: str) -> bool: return False -def check_and_send_pending_messages(terminal_id: str) -> bool: +def check_and_send_pending_messages( + terminal_id: str, registry: PluginRegistry | None = None +) -> bool: """Check for pending messages and send if terminal is ready. Args: @@ -105,9 +108,21 @@ def check_and_send_pending_messages(terminal_id: str) -> bool: logger.debug(f"Terminal {terminal_id} not ready (status={status})") return False - # Send message + # Send message. Inbox-queued delivery is only reached via the send_message + # MCP tool, so the orchestration_type is always "send_message" here — the + # synchronous handoff/assign paths bypass the inbox and pass their own + # orchestration_type directly to send_input(). try: - terminal_service.send_input(terminal_id, message.message) + if registry is None: + terminal_service.send_input(terminal_id, message.message) + else: + terminal_service.send_input( + terminal_id, + message.message, + registry=registry, + sender_id=message.sender_id, + orchestration_type=OrchestrationType.SEND_MESSAGE, + ) update_message_status(message.id, MessageStatus.DELIVERED) logger.info(f"Delivered message {message.id} to terminal {terminal_id}") return True @@ -120,6 +135,12 @@ def check_and_send_pending_messages(terminal_id: str) -> bool: class LogFileHandler(FileSystemEventHandler): """Handler for terminal log file changes.""" + def __init__(self, registry: PluginRegistry | None = None) -> None: + """Initialize the log file handler with an optional plugin registry.""" + + super().__init__() + self._registry = registry + def on_modified(self, event): """Handle file modification events.""" if isinstance(event, FileModifiedEvent) and event.src_path.endswith(".log"): @@ -145,7 +166,7 @@ def _handle_log_change(self, terminal_id: str): return # Attempt delivery - check_and_send_pending_messages(terminal_id) + check_and_send_pending_messages(terminal_id, registry=self._registry) except Exception as e: logger.error(f"Error handling log change for {terminal_id}: {e}") diff --git a/src/cli_agent_orchestrator/services/plugin_dispatch.py b/src/cli_agent_orchestrator/services/plugin_dispatch.py new file mode 100644 index 00000000..9cded2ff --- /dev/null +++ b/src/cli_agent_orchestrator/services/plugin_dispatch.py @@ -0,0 +1,42 @@ +"""Helpers for emitting plugin events from synchronous service functions.""" + +import asyncio +import logging + +from cli_agent_orchestrator.plugins import CaoEvent, PluginRegistry + +logger = logging.getLogger(__name__) + + +async def _dispatch_with_logging( + registry: PluginRegistry, event_type: str, event: CaoEvent +) -> None: + """Run registry dispatch with local error isolation at the adapter boundary.""" + + try: + await registry.dispatch(event_type, event) + except Exception: + logger.warning("Plugin event dispatch failed for %s", event_type, exc_info=True) + + +def dispatch_plugin_event( + registry: PluginRegistry | None, event_type: str, event: CaoEvent +) -> None: + """Dispatch a plugin event without forcing a broad async refactor. + + If called inside a running event loop (the common FastAPI path), the + dispatch coroutine is scheduled as a background task. If no loop is + running (for synchronous code paths and unit tests), the dispatch runs to + completion via ``asyncio.run``. + """ + + if registry is None: + return + + coroutine = _dispatch_with_logging(registry, event_type, event) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(coroutine) + else: + loop.create_task(coroutine) diff --git a/src/cli_agent_orchestrator/services/session_service.py b/src/cli_agent_orchestrator/services/session_service.py index 47c1dea4..3d24a9e1 100644 --- a/src/cli_agent_orchestrator/services/session_service.py +++ b/src/cli_agent_orchestrator/services/session_service.py @@ -28,11 +28,49 @@ ) from cli_agent_orchestrator.clients.tmux import tmux_client from cli_agent_orchestrator.constants import SESSION_PREFIX +from cli_agent_orchestrator.models.terminal import Terminal +from cli_agent_orchestrator.plugins import ( + PluginRegistry, + PostCreateSessionEvent, + PostKillSessionEvent, +) from cli_agent_orchestrator.providers.manager import provider_manager +from cli_agent_orchestrator.services.plugin_dispatch import dispatch_plugin_event +from cli_agent_orchestrator.services.terminal_service import create_terminal logger = logging.getLogger(__name__) +def create_session( + provider: str, + agent_profile: str, + session_name: str | None = None, + working_directory: str | None = None, + allowed_tools: list[str] | None = None, + registry: PluginRegistry | None = None, +) -> Terminal: + """Create a new session by creating its initial terminal.""" + + terminal = create_terminal( + provider=provider, + agent_profile=agent_profile, + session_name=session_name, + new_session=True, + working_directory=working_directory, + allowed_tools=allowed_tools, + registry=registry, + ) + dispatch_plugin_event( + registry, + "post_create_session", + PostCreateSessionEvent( + session_id=terminal.session_name, + session_name=terminal.session_name, + ), + ) + return terminal + + def list_sessions() -> List[Dict]: """List all sessions from tmux.""" try: @@ -63,7 +101,7 @@ def get_session(session_name: str) -> Dict: raise -def delete_session(session_name: str) -> Dict: +def delete_session(session_name: str, registry: PluginRegistry | None = None) -> Dict: """Delete session and cleanup. Returns: @@ -91,6 +129,11 @@ def delete_session(session_name: str) -> Dict: result["deleted"].append(session_name) logger.info(f"Deleted session: {session_name}") + dispatch_plugin_event( + registry, + "post_kill_session", + PostKillSessionEvent(session_id=session_name, session_name=session_name), + ) return result except Exception as e: diff --git a/src/cli_agent_orchestrator/services/terminal_service.py b/src/cli_agent_orchestrator/services/terminal_service.py index 207086c9..fbb3c247 100644 --- a/src/cli_agent_orchestrator/services/terminal_service.py +++ b/src/cli_agent_orchestrator/services/terminal_service.py @@ -31,9 +31,17 @@ ) from cli_agent_orchestrator.clients.tmux import tmux_client from cli_agent_orchestrator.constants import SESSION_PREFIX, TERMINAL_LOG_DIR +from cli_agent_orchestrator.models.inbox import OrchestrationType from cli_agent_orchestrator.models.provider import ProviderType from cli_agent_orchestrator.models.terminal import Terminal, TerminalStatus +from cli_agent_orchestrator.plugins import ( + PluginRegistry, + PostCreateTerminalEvent, + PostKillTerminalEvent, + PostSendMessageEvent, +) from cli_agent_orchestrator.providers.manager import provider_manager +from cli_agent_orchestrator.services.plugin_dispatch import dispatch_plugin_event from cli_agent_orchestrator.utils.agent_profiles import load_agent_profile from cli_agent_orchestrator.utils.skills import build_skill_catalog from cli_agent_orchestrator.utils.terminal import ( @@ -73,7 +81,8 @@ def create_terminal( session_name: Optional[str] = None, new_session: bool = False, working_directory: Optional[str] = None, - allowed_tools: Optional[list] = None, + allowed_tools: Optional[list[str]] = None, + registry: PluginRegistry | None = None, ) -> Terminal: """Create a new terminal with an initialized CLI agent. @@ -186,6 +195,16 @@ def create_terminal( logger.info( f"Created terminal: {terminal_id} in session: {session_name} (new_session={new_session})" ) + dispatch_plugin_event( + registry, + "post_create_terminal", + PostCreateTerminalEvent( + session_id=terminal.session_name, + terminal_id=terminal.id, + agent_name=terminal.agent_profile, + provider=provider, + ), + ) return terminal except Exception as e: @@ -260,7 +279,13 @@ def get_working_directory(terminal_id: str) -> Optional[str]: raise -def send_input(terminal_id: str, message: str) -> bool: +def send_input( + terminal_id: str, + message: str, + registry: PluginRegistry | None = None, + sender_id: str | None = None, + orchestration_type: OrchestrationType | None = None, +) -> bool: """Send input to terminal via tmux paste buffer. Uses bracketed paste mode (-p) to bypass TUI hotkey handling. The number @@ -290,6 +315,18 @@ def send_input(terminal_id: str, message: str) -> bool: update_last_active(terminal_id) logger.info(f"Sent input to terminal: {terminal_id}") + if registry is not None and sender_id is not None and orchestration_type is not None: + dispatch_plugin_event( + registry, + "post_send_message", + PostSendMessageEvent( + session_id=metadata["tmux_session"], + sender=sender_id, + receiver=terminal_id, + message=message, + orchestration_type=orchestration_type, + ), + ) return True except Exception as e: @@ -377,7 +414,7 @@ def get_output(terminal_id: str, mode: OutputMode = OutputMode.FULL) -> str: raise -def delete_terminal(terminal_id: str) -> bool: +def delete_terminal(terminal_id: str, registry: PluginRegistry | None = None) -> bool: """Delete terminal and kill its tmux window.""" try: # Get metadata before deletion @@ -400,6 +437,16 @@ def delete_terminal(terminal_id: str) -> bool: provider_manager.cleanup_provider(terminal_id) deleted = db_delete_terminal(terminal_id) logger.info(f"Deleted terminal: {terminal_id}") + if deleted and metadata: + dispatch_plugin_event( + registry, + "post_kill_terminal", + PostKillTerminalEvent( + session_id=metadata["tmux_session"], + terminal_id=terminal_id, + agent_name=metadata.get("agent_profile"), + ), + ) return deleted except Exception as e: diff --git a/test/api/conftest.py b/test/api/conftest.py index 79002217..4d1c822f 100644 --- a/test/api/conftest.py +++ b/test/api/conftest.py @@ -4,6 +4,7 @@ from fastapi.testclient import TestClient from cli_agent_orchestrator.api.main import app +from cli_agent_orchestrator.plugins import PluginRegistry class TestClientWithHost(TestClient): @@ -27,4 +28,5 @@ def request(self, method, url, **kwargs): @pytest.fixture def client(): """Test client with proper Host header for security middleware.""" + app.state.plugin_registry = PluginRegistry() return TestClientWithHost(app) diff --git a/test/api/test_api_endpoints.py b/test/api/test_api_endpoints.py index 09d43ef5..ea1d35bf 100644 --- a/test/api/test_api_endpoints.py +++ b/test/api/test_api_endpoints.py @@ -7,7 +7,7 @@ import asyncio from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest @@ -236,8 +236,8 @@ def test_create_session_success(self, client): provider="kiro_cli", agent_profile="developer", ) - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.return_value = mock_terminal + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.return_value = mock_terminal response = client.post( "/sessions", @@ -252,6 +252,14 @@ def test_create_session_success(self, client): assert data["id"] == "abcd1234" assert data["provider"] == "kiro_cli" assert data["agent_profile"] == "developer" + mock_svc.create_session.assert_called_once_with( + provider="kiro_cli", + agent_profile="developer", + session_name=None, + working_directory=None, + allowed_tools=None, + registry=ANY, + ) def test_create_session_with_session_name(self, client): """POST /sessions with explicit session_name.""" @@ -262,8 +270,8 @@ def test_create_session_with_session_name(self, client): provider="q_cli", agent_profile="developer", ) - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.return_value = mock_terminal + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.return_value = mock_terminal response = client.post( "/sessions", @@ -275,14 +283,14 @@ def test_create_session_with_session_name(self, client): ) assert response.status_code == 201 - call_kwargs = mock_svc.create_terminal.call_args.kwargs + call_kwargs = mock_svc.create_session.call_args.kwargs assert call_kwargs["session_name"] == "my-custom-session" - assert call_kwargs["new_session"] is True + assert call_kwargs["registry"] is not None def test_create_session_value_error(self, client): """POST /sessions returns 400 on ValueError.""" - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.side_effect = ValueError("Invalid provider") + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.side_effect = ValueError("Invalid provider") response = client.post( "/sessions", @@ -297,8 +305,8 @@ def test_create_session_value_error(self, client): def test_create_session_server_error(self, client): """POST /sessions returns 500 on unexpected error.""" - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.side_effect = Exception("TMux crashed") + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.side_effect = Exception("TMux crashed") response = client.post( "/sessions", @@ -408,7 +416,7 @@ def test_delete_session_success(self, client): data = response.json() assert data["success"] is True assert data["deleted"] == ["test-session"] - mock_svc.delete_session.assert_called_once_with("test-session") + mock_svc.delete_session.assert_called_once_with("test-session", registry=ANY) def test_delete_session_not_found(self, client): """DELETE /sessions/{name} returns 404 for nonexistent session.""" @@ -610,6 +618,29 @@ def test_send_input_success(self, client): assert data["success"] is True mock_svc.send_input.assert_called_once_with("abcd1234", "hello world") + def test_send_input_with_orchestration_context(self, client): + """POST /terminals/{id}/input forwards registry and orchestration metadata when provided.""" + with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: + mock_svc.send_input.return_value = True + + response = client.post( + "/terminals/abcd1234/input", + params={ + "message": "hello world", + "sender_id": "supervisor-1", + "orchestration_type": "assign", + }, + ) + + assert response.status_code == 200 + mock_svc.send_input.assert_called_once_with( + "abcd1234", + "hello world", + registry=ANY, + sender_id="supervisor-1", + orchestration_type="assign", + ) + def test_send_input_terminal_not_found(self, client): """POST /terminals/{id}/input returns 404 for nonexistent terminal.""" with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: @@ -698,7 +729,7 @@ def test_delete_terminal_success(self, client): assert response.status_code == 200 data = response.json() assert data["success"] is True - mock_svc.delete_terminal.assert_called_once_with("abcd1234") + mock_svc.delete_terminal.assert_called_once_with("abcd1234", registry=ANY) def test_delete_terminal_not_found(self, client): """DELETE /terminals/{id} returns 404 for nonexistent terminal.""" diff --git a/test/api/test_plugin_lifespan.py b/test/api/test_plugin_lifespan.py new file mode 100644 index 00000000..a8c435ec --- /dev/null +++ b/test/api/test_plugin_lifespan.py @@ -0,0 +1,124 @@ +"""Integration tests for plugin registry FastAPI lifespan wiring.""" + +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import Request + +from cli_agent_orchestrator.api.main import app, get_plugin_registry, lifespan +from cli_agent_orchestrator.plugins import CaoPlugin, PluginRegistry, hook +from cli_agent_orchestrator.plugins.events import PostSendMessageEvent + + +async def fake_flow_daemon() -> None: + """Minimal async flow daemon stub for lifespan tests.""" + + +class TestPluginRegistryLifespan: + """Tests for plugin registry startup, app state wiring, and teardown.""" + + @pytest.mark.asyncio + async def test_lifespan_stores_registry_and_tears_it_down(self) -> None: + """The lifespan should create, store, expose, and tear down the registry.""" + + mock_observer = MagicMock() + ordering: list[str] = [] + mock_load = AsyncMock() + mock_teardown = AsyncMock() + mock_load.side_effect = lambda: ordering.append("registry_load") + mock_observer.schedule.side_effect = lambda *args, **kwargs: ordering.append( + "observer_schedule" + ) + + request_scope = {"type": "http", "app": app, "headers": []} + + with ( + patch("cli_agent_orchestrator.api.main.setup_logging"), + patch("cli_agent_orchestrator.api.main.init_db"), + patch("cli_agent_orchestrator.api.main.cleanup_old_data"), + patch( + "cli_agent_orchestrator.api.main.PollingObserver", + return_value=mock_observer, + ), + patch("cli_agent_orchestrator.api.main.flow_daemon", fake_flow_daemon), + patch.object(PluginRegistry, "load", mock_load), + patch.object(PluginRegistry, "teardown", mock_teardown), + ): + async with lifespan(app): + registry = app.state.plugin_registry + + assert isinstance(registry, PluginRegistry) + assert get_plugin_registry(Request(request_scope)) is registry + assert get_plugin_registry(Request(dict(request_scope))) is registry + mock_load.assert_awaited_once() + mock_observer.schedule.assert_called_once() + mock_observer.start.assert_called_once() + assert ordering == ["registry_load", "observer_schedule"] + + mock_teardown.assert_awaited_once() + mock_observer.stop.assert_called_once() + mock_observer.join.assert_called_once() + + @pytest.mark.asyncio + async def test_lifespan_logs_no_plugins_registered_when_entry_points_are_empty( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """The lifespan should surface the empty-plugin INFO log from the registry.""" + + mock_observer = MagicMock() + + with ( + patch("cli_agent_orchestrator.api.main.setup_logging"), + patch("cli_agent_orchestrator.api.main.init_db"), + patch("cli_agent_orchestrator.api.main.cleanup_old_data"), + patch( + "cli_agent_orchestrator.api.main.PollingObserver", + return_value=mock_observer, + ), + patch("cli_agent_orchestrator.api.main.flow_daemon", fake_flow_daemon), + patch("importlib.metadata.entry_points", return_value=[]), + ): + with caplog.at_level(logging.INFO, logger="cli_agent_orchestrator.plugins.registry"): + async with lifespan(app): + assert isinstance(app.state.plugin_registry, PluginRegistry) + + assert "No CAO plugins registered" in caplog.text + + @pytest.mark.asyncio + async def test_lifespan_tolerates_plugin_setup_failure(self) -> None: + """The lifespan should still start when one plugin fails during setup.""" + + mock_observer = MagicMock() + + class FailingPlugin(CaoPlugin): + async def setup(self) -> None: + raise RuntimeError("setup failed") + + class HealthyPlugin(CaoPlugin): + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + del event + + with ( + patch("cli_agent_orchestrator.api.main.setup_logging"), + patch("cli_agent_orchestrator.api.main.init_db"), + patch("cli_agent_orchestrator.api.main.cleanup_old_data"), + patch( + "cli_agent_orchestrator.api.main.PollingObserver", + return_value=mock_observer, + ), + patch("cli_agent_orchestrator.api.main.flow_daemon", fake_flow_daemon), + patch( + "importlib.metadata.entry_points", + return_value=[ + type("EP", (), {"name": "failing", "load": lambda self: FailingPlugin})(), + type("EP", (), {"name": "healthy", "load": lambda self: HealthyPlugin})(), + ], + ), + ): + async with lifespan(app): + registry = app.state.plugin_registry + + assert isinstance(registry, PluginRegistry) + assert len(registry._plugins) == 1 diff --git a/test/api/test_terminals.py b/test/api/test_terminals.py index 3fa4d388..d00c3c88 100644 --- a/test/api/test_terminals.py +++ b/test/api/test_terminals.py @@ -1,6 +1,6 @@ """Tests for terminal-related API endpoints including working directory and exit.""" -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -70,8 +70,8 @@ class TestSessionCreationWithWorkingDirectory: def test_create_session_passes_working_directory(self, client, tmp_path): """Test that working_directory parameter is passed to service.""" - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.return_value = Terminal( + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.return_value = Terminal( id="abcd1234", name="test-window", session_name="test-session", @@ -90,13 +90,14 @@ def test_create_session_passes_working_directory(self, client, tmp_path): assert response.status_code == 201 # Verify working_directory was passed - call_kwargs = mock_svc.create_terminal.call_args.kwargs + call_kwargs = mock_svc.create_session.call_args.kwargs assert call_kwargs.get("working_directory") == str(tmp_path) + assert call_kwargs.get("registry") is not None def test_create_session_with_working_directory(self, client): """Test POST /sessions with working_directory parameter.""" - with patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc: - mock_svc.create_terminal.return_value = Terminal( + with patch("cli_agent_orchestrator.api.main.session_service") as mock_svc: + mock_svc.create_session.return_value = Terminal( id="abcd1234", name="test-window", session_name="test-session", @@ -114,7 +115,7 @@ def test_create_session_with_working_directory(self, client): ) assert response.status_code == 201 - call_kwargs = mock_svc.create_terminal.call_args.kwargs + call_kwargs = mock_svc.create_session.call_args.kwargs assert call_kwargs.get("working_directory") == "/custom/path" @@ -291,7 +292,7 @@ def test_delete_terminal_success(self, client): assert response.status_code == 200 assert response.json() == {"success": True} - mock_svc.delete_terminal.assert_called_once_with("abcd1234") + mock_svc.delete_terminal.assert_called_once_with("abcd1234", registry=ANY) def test_delete_terminal_not_found(self, client): """DELETE /terminals/{terminal_id} returns 404 for missing terminal.""" @@ -340,6 +341,14 @@ def test_create_inbox_message_success(self, client): assert data["success"] is True assert data["message_id"] == 1 assert data["sender_id"] == "sender1" + mock_create.assert_called_once_with( + "sender1", + "abcd1234", + "hello", + ) + mock_inbox.check_and_send_pending_messages.assert_called_once_with( + "abcd1234", registry=ANY + ) def test_create_inbox_message_delivery_failure_still_succeeds(self, client): """Immediate delivery failure should not fail the API response.""" @@ -467,9 +476,9 @@ def test_create_session_does_not_resolve_provider(self, client): """create_session should NOT call resolve_provider — CLI flag is the override.""" with ( patch("cli_agent_orchestrator.api.main.resolve_provider") as mock_resolve, - patch("cli_agent_orchestrator.api.main.terminal_service") as mock_svc, + patch("cli_agent_orchestrator.api.main.session_service") as mock_svc, ): - mock_svc.create_terminal.return_value = Terminal( + mock_svc.create_session.return_value = Terminal( id="abcd1234", name="test-window", session_name="test-session", @@ -488,8 +497,8 @@ def test_create_session_does_not_resolve_provider(self, client): assert response.status_code == 201 # resolve_provider should NOT have been called mock_resolve.assert_not_called() - # terminal_service should get the raw provider param - call_kwargs = mock_svc.create_terminal.call_args.kwargs + # session_service should get the raw provider param + call_kwargs = mock_svc.create_session.call_args.kwargs assert call_kwargs["provider"] == "kiro_cli" def test_create_terminal_returns_500_on_resolve_error(self, client): diff --git a/test/mcp_server/test_assign.py b/test/mcp_server/test_assign.py index 9601a814..d2bc4619 100644 --- a/test/mcp_server/test_assign.py +++ b/test/mcp_server/test_assign.py @@ -26,6 +26,7 @@ def test_assign_appends_sender_id_when_injection_enabled(self, mock_create, mock assert result["success"] is True sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "assign" assert sent_message.startswith("Analyze the logs") assert "[Assigned by terminal supervisor-abc123" in sent_message assert "send results back to terminal supervisor-abc123 using send_message]" in sent_message @@ -45,6 +46,7 @@ def test_assign_no_suffix_when_injection_disabled(self, mock_create, mock_send): assert result["success"] is True sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "assign" assert sent_message == "Analyze the logs" @patch("cli_agent_orchestrator.mcp_server.server.ENABLE_SENDER_ID_INJECTION", True) @@ -61,6 +63,7 @@ def test_assign_sender_id_fallback_unknown(self, mock_create, mock_send): result = _assign_impl("developer", "Build feature X") sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "assign" assert "[Assigned by terminal unknown" in sent_message @patch("cli_agent_orchestrator.mcp_server.server.ENABLE_SENDER_ID_INJECTION", True) @@ -78,6 +81,7 @@ def test_assign_suffix_is_appended_not_prepended(self, mock_create, mock_send): _assign_impl("developer", original) sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "assign" assert sent_message.startswith(original) assert sent_message.index("[Assigned by terminal") > len(original) diff --git a/test/mcp_server/test_handoff.py b/test/mcp_server/test_handoff.py index 1804ccbb..4f8f331c 100644 --- a/test/mcp_server/test_handoff.py +++ b/test/mcp_server/test_handoff.py @@ -30,13 +30,12 @@ def test_codex_provider_prepends_handoff_context(self, mock_create, mock_wait, m mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - result = asyncio.get_event_loop().run_until_complete( - _handoff_impl("developer", "Implement hello world") - ) + result = asyncio.run(_handoff_impl("developer", "Implement hello world")) # Verify _send_direct_input was called with the handoff prefix mock_send.assert_called_once() sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert sent_message.startswith("[CAO Handoff]") assert "supervisor-abc123" in sent_message assert "Implement hello world" in sent_message @@ -58,13 +57,12 @@ def test_claude_code_provider_no_handoff_context(self, mock_create, mock_wait, m mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - result = asyncio.get_event_loop().run_until_complete( - _handoff_impl("developer", "Implement hello world") - ) + result = asyncio.run(_handoff_impl("developer", "Implement hello world")) # Verify message was sent unchanged mock_send.assert_called_once() sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert sent_message == "Implement hello world" @patch("cli_agent_orchestrator.mcp_server.server._send_direct_input") @@ -83,12 +81,11 @@ def test_kiro_cli_provider_no_handoff_context(self, mock_create, mock_wait, mock mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - result = asyncio.get_event_loop().run_until_complete( - _handoff_impl("developer", "Implement hello world") - ) + result = asyncio.run(_handoff_impl("developer", "Implement hello world")) mock_send.assert_called_once() sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert sent_message == "Implement hello world" @patch("cli_agent_orchestrator.mcp_server.server._send_direct_input") @@ -110,11 +107,10 @@ def test_codex_handoff_context_includes_supervisor_id_from_env( mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - asyncio.get_event_loop().run_until_complete( - _handoff_impl("developer", "Build feature X") - ) + asyncio.run(_handoff_impl("developer", "Build feature X")) sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert "sup-xyz789" in sent_message assert "Build feature X" in sent_message @@ -135,9 +131,10 @@ def test_codex_handoff_context_fallback_when_no_env(self, mock_create, mock_wait mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - asyncio.get_event_loop().run_until_complete(_handoff_impl("developer", "Do task")) + asyncio.run(_handoff_impl("developer", "Do task")) sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert "unknown" in sent_message assert "[CAO Handoff]" in sent_message assert "Do task" in sent_message @@ -160,7 +157,8 @@ def test_codex_handoff_original_message_preserved(self, mock_create, mock_wait, mock_requests.get.return_value = mock_response mock_requests.post.return_value = mock_response - asyncio.get_event_loop().run_until_complete(_handoff_impl("developer", original)) + asyncio.run(_handoff_impl("developer", original)) sent_message = mock_send.call_args[0][1] + assert mock_send.call_args[0][2] == "handoff" assert sent_message.endswith(original) diff --git a/test/plugins/__init__.py b/test/plugins/__init__.py new file mode 100644 index 00000000..c46c332f --- /dev/null +++ b/test/plugins/__init__.py @@ -0,0 +1 @@ +"""Tests for CAO plugin primitives.""" diff --git a/test/plugins/test_base.py b/test/plugins/test_base.py new file mode 100644 index 00000000..088c8e5e --- /dev/null +++ b/test/plugins/test_base.py @@ -0,0 +1,67 @@ +"""Tests for CAO plugin base primitives.""" + +import pytest + +from cli_agent_orchestrator.plugins.base import _HOOK_EVENT_ATTR, CaoPlugin, hook +from cli_agent_orchestrator.plugins.events import PostCreateSessionEvent, PostSendMessageEvent + + +class ExamplePlugin(CaoPlugin): + """Simple plugin used to verify hook registration.""" + + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + """Handle a message event.""" + + @hook("post_create_session") + async def on_session_created(self, event: PostCreateSessionEvent) -> None: + """Handle a session creation event.""" + + +class TestCaoPlugin: + """Tests for the CaoPlugin base class.""" + + @pytest.mark.asyncio + async def test_setup_is_awaitable_no_op(self) -> None: + """Default setup() is awaitable and returns None.""" + + plugin = CaoPlugin() + + result = await plugin.setup() + + assert result is None + + @pytest.mark.asyncio + async def test_teardown_is_awaitable_no_op(self) -> None: + """Default teardown() is awaitable and returns None.""" + + plugin = CaoPlugin() + + result = await plugin.teardown() + + assert result is None + + +class TestHookDecorator: + """Tests for the @hook decorator.""" + + def test_hook_sets_event_attribute(self) -> None: + """Decorator attaches the configured event type to the method.""" + + assert getattr(ExamplePlugin.on_message, _HOOK_EVENT_ATTR) == "post_send_message" + + def test_hook_preserves_original_callable_reference(self) -> None: + """Decorator returns the same callable instead of wrapping it.""" + + async def handler(event: PostSendMessageEvent) -> None: + """Standalone handler used for identity checks.""" + + decorated = hook("post_create_terminal")(handler) + + assert decorated is handler + + def test_multiple_methods_can_register_distinct_events(self) -> None: + """Each decorated method retains its own hook event attribute.""" + + assert getattr(ExamplePlugin.on_message, _HOOK_EVENT_ATTR) == "post_send_message" + assert getattr(ExamplePlugin.on_session_created, _HOOK_EVENT_ATTR) == "post_create_session" diff --git a/test/plugins/test_events.py b/test/plugins/test_events.py new file mode 100644 index 00000000..b4de99f8 --- /dev/null +++ b/test/plugins/test_events.py @@ -0,0 +1,122 @@ +"""Tests for CAO plugin event dataclasses.""" + +from datetime import timedelta + +from cli_agent_orchestrator.plugins.events import ( + CaoEvent, + PostCreateSessionEvent, + PostCreateTerminalEvent, + PostKillSessionEvent, + PostKillTerminalEvent, + PostSendMessageEvent, +) + + +class TestEventDefaults: + """Tests for default plugin event values.""" + + def test_post_send_message_event_defaults(self) -> None: + """PostSendMessageEvent defaults to the post_send_message type.""" + + event = PostSendMessageEvent() + + assert event.event_type == "post_send_message" + assert event.session_id is None + assert isinstance(event, CaoEvent) + + def test_post_create_session_event_defaults(self) -> None: + """PostCreateSessionEvent defaults to the post_create_session type.""" + + event = PostCreateSessionEvent() + + assert event.event_type == "post_create_session" + assert event.session_id is None + + def test_post_kill_session_event_defaults(self) -> None: + """PostKillSessionEvent defaults to the post_kill_session type.""" + + event = PostKillSessionEvent() + + assert event.event_type == "post_kill_session" + assert event.session_id is None + + def test_post_create_terminal_event_defaults(self) -> None: + """PostCreateTerminalEvent defaults to the post_create_terminal type.""" + + event = PostCreateTerminalEvent() + + assert event.event_type == "post_create_terminal" + assert event.session_id is None + + def test_post_kill_terminal_event_defaults(self) -> None: + """PostKillTerminalEvent defaults to the post_kill_terminal type.""" + + event = PostKillTerminalEvent() + + assert event.event_type == "post_kill_terminal" + assert event.session_id is None + + def test_base_event_has_utc_timestamp(self) -> None: + """CaoEvent auto-populates a timezone-aware UTC timestamp.""" + + event = CaoEvent() + + assert event.timestamp.tzinfo is not None + assert event.timestamp.utcoffset() == timedelta(0) + assert event.event_type == "" + assert event.session_id is None + + +class TestEventFields: + """Tests for event-specific payload fields.""" + + def test_post_send_message_event_accepts_orchestration_fields(self) -> None: + """PostSendMessageEvent accepts all messaging payload fields.""" + + event = PostSendMessageEvent( + session_id="session-123", + sender="supervisor", + receiver="worker-1", + message="Process this task", + orchestration_type="assign", + ) + + assert event.session_id == "session-123" + assert event.sender == "supervisor" + assert event.receiver == "worker-1" + assert event.message == "Process this task" + assert event.orchestration_type == "assign" + + def test_session_events_carry_session_identifier_fields(self) -> None: + """Session lifecycle events carry their session name payload.""" + + created_event = PostCreateSessionEvent(session_id="session-1", session_name="Build") + killed_event = PostKillSessionEvent(session_id="session-1", session_name="Build") + + assert created_event.session_id == "session-1" + assert created_event.session_name == "Build" + assert killed_event.session_id == "session-1" + assert killed_event.session_name == "Build" + + def test_terminal_events_carry_terminal_identifier_fields(self) -> None: + """Terminal lifecycle events carry terminal-specific identifiers.""" + + created_event = PostCreateTerminalEvent( + session_id="session-2", + terminal_id="term-1", + agent_name="worker", + provider="codex", + ) + killed_event = PostKillTerminalEvent( + session_id="session-2", + terminal_id="term-1", + agent_name="worker", + ) + + assert created_event.session_id == "session-2" + assert created_event.terminal_id == "term-1" + assert created_event.agent_name == "worker" + assert created_event.provider == "codex" + assert killed_event.session_id == "session-2" + assert killed_event.terminal_id == "term-1" + assert killed_event.agent_name == "worker" diff --git a/test/plugins/test_package.py b/test/plugins/test_package.py new file mode 100644 index 00000000..c957d543 --- /dev/null +++ b/test/plugins/test_package.py @@ -0,0 +1,59 @@ +"""Smoke tests for the public CAO plugin package API.""" + +from cli_agent_orchestrator.plugins import ( + CaoEvent, + CaoPlugin, + PluginRegistry, + PostCreateSessionEvent, + PostCreateTerminalEvent, + PostKillSessionEvent, + PostKillTerminalEvent, + PostSendMessageEvent, + __all__, + hook, +) +from cli_agent_orchestrator.plugins.base import CaoPlugin as BaseCaoPlugin +from cli_agent_orchestrator.plugins.base import hook as base_hook +from cli_agent_orchestrator.plugins.events import CaoEvent as BaseCaoEvent +from cli_agent_orchestrator.plugins.events import ( + PostCreateSessionEvent as BasePostCreateSessionEvent, +) +from cli_agent_orchestrator.plugins.events import ( + PostCreateTerminalEvent as BasePostCreateTerminalEvent, +) +from cli_agent_orchestrator.plugins.events import PostKillSessionEvent as BasePostKillSessionEvent +from cli_agent_orchestrator.plugins.events import PostKillTerminalEvent as BasePostKillTerminalEvent +from cli_agent_orchestrator.plugins.events import PostSendMessageEvent as BasePostSendMessageEvent +from cli_agent_orchestrator.plugins.registry import PluginRegistry as BasePluginRegistry + + +class TestPluginPackageAPI: + """Tests for the plugin package's public exports.""" + + def test_public_imports_resolve_to_expected_symbols(self) -> None: + """Importing from the package should resolve to the concrete implementation objects.""" + + assert CaoPlugin is BaseCaoPlugin + assert hook is base_hook + assert CaoEvent is BaseCaoEvent + assert PostSendMessageEvent is BasePostSendMessageEvent + assert PostCreateSessionEvent is BasePostCreateSessionEvent + assert PostKillSessionEvent is BasePostKillSessionEvent + assert PostCreateTerminalEvent is BasePostCreateTerminalEvent + assert PostKillTerminalEvent is BasePostKillTerminalEvent + assert PluginRegistry is BasePluginRegistry + + def test___all___contains_exactly_the_phase_two_public_api(self) -> None: + """The package __all__ should expose exactly the documented public symbols.""" + + assert __all__ == [ + "CaoPlugin", + "hook", + "CaoEvent", + "PostSendMessageEvent", + "PostCreateSessionEvent", + "PostKillSessionEvent", + "PostCreateTerminalEvent", + "PostKillTerminalEvent", + "PluginRegistry", + ] diff --git a/test/plugins/test_registry.py b/test/plugins/test_registry.py new file mode 100644 index 00000000..a5bbc141 --- /dev/null +++ b/test/plugins/test_registry.py @@ -0,0 +1,292 @@ +"""Tests for plugin registry discovery, dispatch, and lifecycle behavior.""" + +import logging +from dataclasses import dataclass +from unittest.mock import patch + +import pytest + +from cli_agent_orchestrator.plugins import CaoPlugin, PluginRegistry, hook +from cli_agent_orchestrator.plugins.events import PostSendMessageEvent + + +@dataclass +class FakeEntryPoint: + """Simple synthetic entry point for registry tests.""" + + name: str + loaded: object + + def load(self) -> object: + """Return the configured object for this fake entry point.""" + + return self.loaded + + +def make_entry_point(name: str, loaded: object) -> FakeEntryPoint: + """Construct a fake entry point for a test plugin class or object.""" + + return FakeEntryPoint(name=name, loaded=loaded) + + +class TestPluginRegistryLoad: + """Tests for plugin discovery and registration.""" + + @pytest.mark.asyncio + async def test_load_with_no_entry_points_emits_info_and_keeps_dispatch_empty( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """No registered plugins should leave the registry empty and log INFO.""" + + registry = PluginRegistry() + + with patch("importlib.metadata.entry_points", return_value=[]): + with caplog.at_level(logging.INFO, logger="cli_agent_orchestrator.plugins.registry"): + await registry.load() + + assert registry._plugins == [] + assert registry._dispatch == {} + assert "No CAO plugins registered" in caplog.text + + @pytest.mark.asyncio + async def test_load_single_plugin_with_one_hook_dispatches_matching_event( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """A single registered hook should receive matching dispatched events.""" + + received: list[str] = [] + + class SingleHookPlugin(CaoPlugin): + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + received.append(event.message) + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[make_entry_point("single-hook", SingleHookPlugin)], + ): + with caplog.at_level(logging.INFO, logger="cli_agent_orchestrator.plugins.registry"): + await registry.load() + + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert received == ["hello"] + assert len(registry._plugins) == 1 + assert len(registry._dispatch["post_send_message"]) == 1 + assert "Loaded CAO plugin: single-hook" in caplog.text + + @pytest.mark.asyncio + async def test_load_single_plugin_with_two_hooks_for_same_event_invokes_both(self) -> None: + """Two hooks on the same plugin should both be registered and called.""" + + received: list[str] = [] + + class DoubleHookPlugin(CaoPlugin): + @hook("post_send_message") + async def first(self, event: PostSendMessageEvent) -> None: + received.append(f"first:{event.message}") + + @hook("post_send_message") + async def second(self, event: PostSendMessageEvent) -> None: + received.append(f"second:{event.message}") + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[make_entry_point("double-hook", DoubleHookPlugin)], + ): + await registry.load() + + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert set(received) == {"first:hello", "second:hello"} + assert len(received) == 2 + assert len(registry._dispatch["post_send_message"]) == 2 + + @pytest.mark.asyncio + async def test_load_multiple_plugins_for_same_event_invokes_all(self) -> None: + """Hooks from multiple plugins should all receive the event.""" + + received: list[str] = [] + + class FirstPlugin(CaoPlugin): + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + received.append(f"first:{event.message}") + + class SecondPlugin(CaoPlugin): + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + received.append(f"second:{event.message}") + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[ + make_entry_point("first", FirstPlugin), + make_entry_point("second", SecondPlugin), + ], + ): + await registry.load() + + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert set(received) == {"first:hello", "second:hello"} + assert len(registry._plugins) == 2 + assert len(registry._dispatch["post_send_message"]) == 2 + + @pytest.mark.asyncio + async def test_load_skips_plugin_when_setup_raises_and_loads_remaining( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """A setup failure should log a warning and not block later plugins.""" + + received: list[str] = [] + + class FailingSetupPlugin(CaoPlugin): + async def setup(self) -> None: + raise RuntimeError("setup failed") + + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + received.append(f"failing:{event.message}") + + class HealthyPlugin(CaoPlugin): + @hook("post_send_message") + async def on_message(self, event: PostSendMessageEvent) -> None: + received.append(f"healthy:{event.message}") + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[ + make_entry_point("failing-setup", FailingSetupPlugin), + make_entry_point("healthy", HealthyPlugin), + ], + ): + with caplog.at_level(logging.WARNING, logger="cli_agent_orchestrator.plugins.registry"): + await registry.load() + + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert received == ["healthy:hello"] + assert len(registry._plugins) == 1 + assert "Failed to load plugin 'failing-setup'" in caplog.text + assert caplog.records[-1].exc_info is not None + + @pytest.mark.asyncio + async def test_load_skips_non_plugin_entry_point_with_warning( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """A non-CaoPlugin entry point should be skipped and logged.""" + + class NotAPlugin: + pass + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[make_entry_point("not-a-plugin", NotAPlugin)], + ): + with caplog.at_level(logging.WARNING, logger="cli_agent_orchestrator.plugins.registry"): + await registry.load() + + assert registry._plugins == [] + assert registry._dispatch == {} + assert "not a CaoPlugin subclass, skipping" in caplog.text + + +class TestPluginRegistryDispatch: + """Tests for dispatch-time behavior and error isolation.""" + + @pytest.mark.asyncio + async def test_dispatch_logs_warning_and_continues_when_hook_raises( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """A failing hook should not prevent other matching hooks from running.""" + + received: list[str] = [] + + class FailingHookPlugin(CaoPlugin): + @hook("post_send_message") + async def broken(self, event: PostSendMessageEvent) -> None: + received.append("broken") + raise RuntimeError("dispatch failed") + + @hook("post_send_message") + async def healthy(self, event: PostSendMessageEvent) -> None: + received.append("healthy") + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[make_entry_point("failing-hook", FailingHookPlugin)], + ): + await registry.load() + + with caplog.at_level(logging.WARNING, logger="cli_agent_orchestrator.plugins.registry"): + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert set(received) == {"broken", "healthy"} + assert "raised an error for event 'post_send_message'" in caplog.text + assert caplog.records[-1].exc_info is not None + + @pytest.mark.asyncio + async def test_dispatch_with_no_registered_handlers_is_no_op( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Dispatching an unhandled event should do nothing and not error.""" + + registry = PluginRegistry() + + with caplog.at_level(logging.WARNING, logger="cli_agent_orchestrator.plugins.registry"): + await registry.dispatch("post_send_message", PostSendMessageEvent(message="hello")) + + assert registry._dispatch == {} + assert caplog.records == [] + + +class TestPluginRegistryTeardown: + """Tests for plugin teardown behavior.""" + + @pytest.mark.asyncio + async def test_teardown_logs_warning_and_continues_after_failure( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """A teardown failure should not prevent later plugins from tearing down.""" + + torn_down: list[str] = [] + + class FailingTeardownPlugin(CaoPlugin): + async def teardown(self) -> None: + torn_down.append("failing") + raise RuntimeError("teardown failed") + + class HealthyTeardownPlugin(CaoPlugin): + async def teardown(self) -> None: + torn_down.append("healthy") + + registry = PluginRegistry() + + with patch( + "importlib.metadata.entry_points", + return_value=[ + make_entry_point("failing", FailingTeardownPlugin), + make_entry_point("healthy", HealthyTeardownPlugin), + ], + ): + await registry.load() + + with caplog.at_level(logging.WARNING, logger="cli_agent_orchestrator.plugins.registry"): + await registry.teardown() + + assert set(torn_down) == {"failing", "healthy"} + assert "Plugin teardown failed for FailingTeardownPlugin" in caplog.text + assert caplog.records[-1].exc_info is not None diff --git a/test/services/test_inbox_service.py b/test/services/test_inbox_service.py index b242f83f..7944eec1 100644 --- a/test/services/test_inbox_service.py +++ b/test/services/test_inbox_service.py @@ -192,7 +192,7 @@ def test_on_modified_triggers_delivery(self, mock_get_messages, mock_has_idle, m handler.on_modified(event) - mock_check_send.assert_called_once_with("test-terminal") + mock_check_send.assert_called_once_with("test-terminal", registry=None) @patch("cli_agent_orchestrator.services.inbox_service.get_pending_messages") def test_handle_log_change_no_pending_messages(self, mock_get_messages): diff --git a/test/services/test_plugin_dispatch.py b/test/services/test_plugin_dispatch.py new file mode 100644 index 00000000..4f5f3cac --- /dev/null +++ b/test/services/test_plugin_dispatch.py @@ -0,0 +1,64 @@ +"""Tests for plugin dispatch adapter behavior.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from cli_agent_orchestrator.plugins import PostSendMessageEvent +from cli_agent_orchestrator.services.plugin_dispatch import dispatch_plugin_event + + +def test_dispatch_plugin_event_noops_when_registry_missing(): + """Missing registry should be a silent no-op.""" + + event = PostSendMessageEvent( + session_id="cao-demo", + sender="supervisor-1", + receiver="worker-1", + message="Hello", + orchestration_type="send_message", + ) + + dispatch_plugin_event(None, "post_send_message", event) + + +def test_dispatch_plugin_event_logs_and_swallows_registry_errors(caplog): + """Adapter-level failures should be logged and must not propagate.""" + + registry = MagicMock() + registry.dispatch = AsyncMock(side_effect=RuntimeError("dispatch failed")) + event = PostSendMessageEvent( + session_id="cao-demo", + sender="supervisor-1", + receiver="worker-1", + message="Hello", + orchestration_type="send_message", + ) + + with caplog.at_level("WARNING"): + dispatch_plugin_event(registry, "post_send_message", event) + + registry.dispatch.assert_awaited_once_with("post_send_message", event) + assert caplog.records[-1].message == "Plugin event dispatch failed for post_send_message" + assert caplog.records[-1].exc_info is not None + + +@pytest.mark.asyncio +async def test_dispatch_plugin_event_schedules_dispatch_in_running_loop(): + """A running event loop should use create_task and still complete dispatch.""" + + registry = MagicMock() + registry.dispatch = AsyncMock() + event = PostSendMessageEvent( + session_id="cao-demo", + sender="supervisor-1", + receiver="worker-1", + message="Hello", + orchestration_type="assign", + ) + + dispatch_plugin_event(registry, "post_send_message", event) + await asyncio.sleep(0) + + registry.dispatch.assert_awaited_once_with("post_send_message", event) diff --git a/test/services/test_plugin_event_emission.py b/test/services/test_plugin_event_emission.py new file mode 100644 index 00000000..76358f29 --- /dev/null +++ b/test/services/test_plugin_event_emission.py @@ -0,0 +1,393 @@ +"""Tests for plugin event emission from service-layer operations.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cli_agent_orchestrator.models.agent_profile import AgentProfile +from cli_agent_orchestrator.models.inbox import MessageStatus +from cli_agent_orchestrator.models.terminal import Terminal, TerminalStatus +from cli_agent_orchestrator.plugins import ( + PostCreateSessionEvent, + PostCreateTerminalEvent, + PostKillSessionEvent, + PostKillTerminalEvent, + PostSendMessageEvent, +) +from cli_agent_orchestrator.services.inbox_service import check_and_send_pending_messages +from cli_agent_orchestrator.services.session_service import create_session, delete_session +from cli_agent_orchestrator.services.terminal_service import ( + create_terminal, + delete_terminal, + send_input, +) + + +def _registry_mock() -> MagicMock: + """Build a registry double whose async dispatch can be asserted directly.""" + + registry = MagicMock() + registry.dispatch = AsyncMock() + return registry + + +class TestSessionPluginEvents: + """Verify session lifecycle events are emitted correctly.""" + + @patch("cli_agent_orchestrator.services.session_service.create_terminal") + def test_create_session_dispatches_post_create_session_event(self, mock_create_terminal): + """Successful session creation should emit exactly one post_create_session event.""" + registry = _registry_mock() + mock_create_terminal.return_value = Terminal( + id="abcd1234", + name="developer-abcd", + session_name="cao-demo", + provider="kiro_cli", + agent_profile="developer", + ) + + result = create_session( + provider="kiro_cli", + agent_profile="developer", + session_name="cao-demo", + registry=registry, + ) + + assert result.session_name == "cao-demo" + registry.dispatch.assert_awaited_once() + event_type, event = registry.dispatch.await_args.args + assert event_type == "post_create_session" + assert isinstance(event, PostCreateSessionEvent) + assert event.session_id == "cao-demo" + assert event.session_name == "cao-demo" + + @patch("cli_agent_orchestrator.services.session_service.create_terminal") + def test_create_session_does_not_dispatch_on_failure(self, mock_create_terminal): + """Session creation failures must not emit plugin events.""" + registry = _registry_mock() + mock_create_terminal.side_effect = RuntimeError("tmux failed") + + with pytest.raises(RuntimeError, match="tmux failed"): + create_session(provider="kiro_cli", agent_profile="developer", registry=registry) + + registry.dispatch.assert_not_awaited() + + @patch("cli_agent_orchestrator.services.session_service.delete_terminals_by_session") + @patch("cli_agent_orchestrator.services.session_service.list_terminals_by_session") + @patch("cli_agent_orchestrator.services.session_service.tmux_client") + def test_delete_session_dispatches_post_kill_session_event_after_cleanup( + self, mock_tmux, mock_list_terminals, mock_delete_terminals + ): + """Session kill should emit after the tmux kill and DB cleanup succeed.""" + registry = _registry_mock() + call_order: list[str] = [] + + async def record_dispatch(*_args): + call_order.append("dispatch") + + mock_tmux.session_exists.return_value = True + mock_tmux.kill_session.side_effect = lambda *_: call_order.append("kill_session") + mock_list_terminals.return_value = [] + mock_delete_terminals.side_effect = lambda *_: call_order.append("delete_terminals") + registry.dispatch.side_effect = record_dispatch + + result = delete_session("cao-demo", registry=registry) + + assert result == {"deleted": ["cao-demo"], "errors": []} + assert call_order == ["kill_session", "delete_terminals", "dispatch"] + event_type, event = registry.dispatch.await_args.args + assert event_type == "post_kill_session" + assert isinstance(event, PostKillSessionEvent) + assert event.session_id == "cao-demo" + assert event.session_name == "cao-demo" + + @patch("cli_agent_orchestrator.services.session_service.tmux_client") + def test_delete_session_does_not_dispatch_on_failure(self, mock_tmux): + """Missing sessions should raise without emitting events.""" + registry = _registry_mock() + mock_tmux.session_exists.return_value = False + + with pytest.raises(ValueError, match="Session 'cao-missing' not found"): + delete_session("cao-missing", registry=registry) + + registry.dispatch.assert_not_awaited() + + +class TestTerminalPluginEvents: + """Verify terminal lifecycle events are emitted correctly.""" + + @patch("cli_agent_orchestrator.services.terminal_service.TERMINAL_LOG_DIR") + @patch("cli_agent_orchestrator.services.terminal_service.build_skill_catalog", return_value="") + @patch("cli_agent_orchestrator.services.terminal_service.load_agent_profile") + @patch("cli_agent_orchestrator.services.terminal_service.generate_terminal_id") + @patch("cli_agent_orchestrator.services.terminal_service.generate_window_name") + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.db_create_terminal") + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + def test_create_terminal_dispatches_post_create_terminal_event_after_setup( + self, + mock_provider_manager, + mock_db_create_terminal, + mock_tmux, + mock_generate_window_name, + mock_generate_terminal_id, + mock_load_agent_profile, + mock_build_skill_catalog, + mock_log_dir, + ): + """Terminal creation should emit only after persistence and startup complete.""" + registry = _registry_mock() + call_order: list[str] = [] + + async def record_dispatch(*_args): + call_order.append("dispatch") + + mock_generate_terminal_id.return_value = "abcd1234" + mock_generate_window_name.return_value = "developer-abcd" + mock_tmux.session_exists.return_value = False + mock_db_create_terminal.side_effect = lambda *_: call_order.append("db_create") + mock_load_agent_profile.return_value = AgentProfile(name="developer", description="Dev") + + provider = MagicMock() + provider.initialize.side_effect = lambda: call_order.append("provider_initialize") + mock_provider_manager.create_provider.return_value = provider + + log_path = MagicMock() + mock_log_dir.__truediv__.return_value = log_path + mock_tmux.pipe_pane.side_effect = lambda *_: call_order.append("pipe_pane") + registry.dispatch.side_effect = record_dispatch + + terminal = create_terminal( + provider="kiro_cli", + agent_profile="developer", + session_name="demo", + new_session=True, + allowed_tools=["*"], + registry=registry, + ) + + assert terminal.id == "abcd1234" + assert call_order == ["db_create", "provider_initialize", "pipe_pane", "dispatch"] + event_type, event = registry.dispatch.await_args.args + assert event_type == "post_create_terminal" + assert isinstance(event, PostCreateTerminalEvent) + assert event.session_id == "cao-demo" + assert event.terminal_id == "abcd1234" + assert event.agent_name == "developer" + assert event.provider == "kiro_cli" + + @patch("cli_agent_orchestrator.services.terminal_service.TERMINAL_LOG_DIR") + @patch("cli_agent_orchestrator.services.terminal_service.build_skill_catalog", return_value="") + @patch("cli_agent_orchestrator.services.terminal_service.load_agent_profile") + @patch("cli_agent_orchestrator.services.terminal_service.generate_terminal_id") + @patch("cli_agent_orchestrator.services.terminal_service.generate_window_name") + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.db_create_terminal") + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + def test_create_terminal_does_not_dispatch_on_failure( + self, + mock_provider_manager, + mock_db_create_terminal, + mock_tmux, + mock_generate_window_name, + mock_generate_terminal_id, + mock_load_agent_profile, + mock_build_skill_catalog, + mock_log_dir, + ): + """Terminal creation failures must not emit post_create_terminal.""" + registry = _registry_mock() + mock_generate_terminal_id.return_value = "abcd1234" + mock_generate_window_name.return_value = "developer-abcd" + mock_tmux.session_exists.return_value = False + mock_load_agent_profile.return_value = AgentProfile(name="developer", description="Dev") + + provider = MagicMock() + provider.initialize.side_effect = RuntimeError("provider init failed") + mock_provider_manager.create_provider.return_value = provider + mock_log_dir.__truediv__.return_value = MagicMock() + + with pytest.raises(RuntimeError, match="provider init failed"): + create_terminal( + provider="kiro_cli", + agent_profile="developer", + session_name="demo", + new_session=True, + allowed_tools=["*"], + registry=registry, + ) + + registry.dispatch.assert_not_awaited() + + @patch("cli_agent_orchestrator.services.terminal_service.db_delete_terminal", return_value=True) + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.get_terminal_metadata") + def test_delete_terminal_dispatches_post_kill_terminal_event_after_delete( + self, mock_get_metadata, mock_tmux, mock_provider_manager, mock_db_delete_terminal + ): + """Terminal kill should emit only after deletion succeeds.""" + registry = _registry_mock() + call_order: list[str] = [] + + async def record_dispatch(*_args): + call_order.append("dispatch") + + mock_get_metadata.return_value = { + "tmux_session": "cao-demo", + "tmux_window": "developer-abcd", + "agent_profile": "developer", + } + mock_provider_manager.cleanup_provider.side_effect = lambda *_: call_order.append("cleanup") + mock_db_delete_terminal.side_effect = lambda *_: call_order.append("db_delete") or True + registry.dispatch.side_effect = record_dispatch + + deleted = delete_terminal("abcd1234", registry=registry) + + assert deleted is True + assert call_order[-2:] == ["db_delete", "dispatch"] + event_type, event = registry.dispatch.await_args.args + assert event_type == "post_kill_terminal" + assert isinstance(event, PostKillTerminalEvent) + assert event.session_id == "cao-demo" + assert event.terminal_id == "abcd1234" + assert event.agent_name == "developer" + + @patch("cli_agent_orchestrator.services.terminal_service.db_delete_terminal") + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.get_terminal_metadata") + def test_delete_terminal_does_not_dispatch_on_failure( + self, mock_get_metadata, mock_tmux, mock_provider_manager, mock_db_delete_terminal + ): + """Deletion failures must not emit post_kill_terminal.""" + registry = _registry_mock() + mock_get_metadata.return_value = { + "tmux_session": "cao-demo", + "tmux_window": "developer-abcd", + "agent_profile": "developer", + } + mock_db_delete_terminal.side_effect = RuntimeError("db delete failed") + + with pytest.raises(RuntimeError, match="db delete failed"): + delete_terminal("abcd1234", registry=registry) + + registry.dispatch.assert_not_awaited() + + +class TestMessagePluginEvents: + """Verify message delivery emits the correct event payloads.""" + + @pytest.mark.parametrize("orchestration_type", ["send_message", "assign", "handoff"]) + @patch("cli_agent_orchestrator.services.terminal_service.update_last_active") + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + @patch("cli_agent_orchestrator.services.terminal_service.get_terminal_metadata") + def test_send_input_dispatches_post_send_message_event_for_each_orchestration_mode( + self, + mock_get_metadata, + mock_provider_manager, + mock_tmux, + mock_update_last_active, + orchestration_type, + ): + """Every successful delivery should emit one post_send_message event.""" + registry = _registry_mock() + call_order: list[str] = [] + + async def record_dispatch(*_args): + call_order.append("dispatch") + + mock_get_metadata.return_value = { + "tmux_session": "cao-demo", + "tmux_window": "developer-abcd", + } + provider = MagicMock() + provider.paste_enter_count = 2 + provider.mark_input_received.side_effect = lambda: call_order.append("mark_input_received") + mock_provider_manager.get_provider.return_value = provider + mock_tmux.send_keys.side_effect = lambda *_args, **_kwargs: call_order.append("send_keys") + mock_update_last_active.side_effect = lambda *_: call_order.append("update_last_active") + registry.dispatch.side_effect = record_dispatch + + delivered = send_input( + "abcd1234", + "Hello from supervisor", + registry=registry, + sender_id="supervisor-1", + orchestration_type=orchestration_type, + ) + + assert delivered is True + assert call_order[-1] == "dispatch" + event_type, event = registry.dispatch.await_args.args + assert event_type == "post_send_message" + assert isinstance(event, PostSendMessageEvent) + assert event.session_id == "cao-demo" + assert event.sender == "supervisor-1" + assert event.receiver == "abcd1234" + assert event.message == "Hello from supervisor" + assert event.orchestration_type == orchestration_type + + @patch("cli_agent_orchestrator.services.terminal_service.tmux_client") + @patch("cli_agent_orchestrator.services.terminal_service.provider_manager") + @patch("cli_agent_orchestrator.services.terminal_service.get_terminal_metadata") + def test_send_input_does_not_dispatch_on_failure( + self, mock_get_metadata, mock_provider_manager, mock_tmux + ): + """Message delivery failures must not emit post_send_message.""" + registry = _registry_mock() + mock_get_metadata.return_value = { + "tmux_session": "cao-demo", + "tmux_window": "developer-abcd", + } + provider = MagicMock() + provider.paste_enter_count = 1 + mock_provider_manager.get_provider.return_value = provider + mock_tmux.send_keys.side_effect = RuntimeError("send failed") + + with pytest.raises(RuntimeError, match="send failed"): + send_input( + "abcd1234", + "Hello from supervisor", + registry=registry, + sender_id="supervisor-1", + orchestration_type="assign", + ) + + registry.dispatch.assert_not_awaited() + + @patch("cli_agent_orchestrator.services.inbox_service.update_message_status") + @patch("cli_agent_orchestrator.services.inbox_service.terminal_service") + @patch("cli_agent_orchestrator.services.inbox_service.provider_manager") + @patch("cli_agent_orchestrator.services.inbox_service.get_pending_messages") + def test_inbox_delivery_threads_send_message_context_to_terminal_service( + self, + mock_get_pending_messages, + mock_provider_manager, + mock_terminal_service, + mock_update_message_status, + ): + """Queued inbox delivery should forward sender context and hardcode send_message.""" + registry = _registry_mock() + message = MagicMock() + message.id = 17 + message.sender_id = "supervisor-1" + message.message = "Please review this" + mock_get_pending_messages.return_value = [message] + + provider = MagicMock() + provider.get_status.return_value = TerminalStatus.IDLE + mock_provider_manager.get_provider.return_value = provider + + delivered = check_and_send_pending_messages("abcd1234", registry=registry) + + assert delivered is True + mock_terminal_service.send_input.assert_called_once_with( + "abcd1234", + "Please review this", + registry=registry, + sender_id="supervisor-1", + orchestration_type="send_message", + ) + mock_update_message_status.assert_called_once_with(17, MessageStatus.DELIVERED)