Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/tumf/mcp-shell-server in…
Browse files Browse the repository at this point in the history
…to develop
  • Loading branch information
tumf committed Jan 5, 2025
2 parents 561c208 + fda0adb commit f86e1e0
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 19 deletions.
16 changes: 8 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@ test:
uv run pytest

format:
black .
isort .
ruff check --fix .
uv run isort .
uv run black .
uv run ruff check --fix .


lint:
black --check .
isort --check .
ruff check .
uv run isort --check .
uv run black --check .
uv run ruff check .

typecheck:
mypy src/mcp_shell_server tests
uv run mypy src/mcp_shell_server tests

coverage:
pytest --cov=src/mcp_shell_server tests
uv run pytest --cov=src/mcp_shell_server tests

# Run all checks required before pushing
check: lint typecheck
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ target-version = ['py311']
profile = "black"
line_length = 88

[tool.mypy]
error_summary = false
hide_error_codes = true
disallow_untyped_defs = false
check_untyped_defs = false

[tool.hatch.version]
path = "src/mcp_shell_server/version.py"

Expand Down
21 changes: 16 additions & 5 deletions src/mcp_shell_server/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,31 @@ async def start_process(
return process

async def cleanup_processes(
self, processes: List[asyncio.subprocess.Process]
self, processes: Optional[List[asyncio.subprocess.Process]] = None
) -> None:
"""Clean up processes by killing them if they're still running.
Args:
processes: List of processes to clean up
processes: Optional list of processes to clean up. If None, clean up all tracked processes
"""
if processes is None:
processes = list(self._processes)

cleanup_tasks = []
for process in processes:
if process.returncode is None:
try:
# Force kill immediately as required by tests
process.kill()
cleanup_tasks.append(asyncio.create_task(process.wait()))
# First attempt graceful termination
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=0.5)
except asyncio.TimeoutError:
# Force kill if termination didn't work
process.kill()
cleanup_tasks.append(asyncio.create_task(process.wait()))
except ProcessLookupError:
# Process already terminated
pass
except Exception as e:
logging.warning(f"Error killing process: {e}")

Expand Down
56 changes: 54 additions & 2 deletions src/mcp_shell_server/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import signal
import traceback
from collections.abc import Sequence
from typing import Any
Expand Down Expand Up @@ -139,13 +140,64 @@ async def call_tool(name: str, arguments: Any) -> Sequence[TextContent]:
async def main() -> None:
"""Main entry point for the MCP shell server"""
logger.info(f"Starting MCP shell server v{__version__}")

# Setup signal handling
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()

def handle_signal():
if not stop_event.is_set(): # Prevent duplicate handling
logger.info("Received shutdown signal, starting cleanup...")
stop_event.set()

# Register signal handlers
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, handle_signal)

try:
from mcp.server.stdio import stdio_server

async with stdio_server() as (read_stream, write_stream):
await app.run(
read_stream, write_stream, app.create_initialization_options()
# Run the server until stop_event is set
server_task = asyncio.create_task(
app.run(read_stream, write_stream, app.create_initialization_options())
)

# Create task for stop event
stop_task = asyncio.create_task(stop_event.wait())

# Wait for either server completion or stop signal
done, pending = await asyncio.wait(
[server_task, stop_task], return_when=asyncio.FIRST_COMPLETED
)

# Check for exceptions in completed tasks
for task in done:
try:
await task
except Exception:
raise # Re-raise the exception

# Cancel any pending tasks
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

except Exception as e:
logger.error(f"Server error: {str(e)}")
raise
finally:
# Cleanup signal handlers
for sig in (signal.SIGTERM, signal.SIGINT):
loop.remove_signal_handler(sig)

# Ensure all processes are terminated
if hasattr(tool_handler, "executor") and hasattr(
tool_handler.executor, "process_manager"
):
await tool_handler.executor.process_manager.cleanup_processes()

logger.info("Server shutdown complete")
9 changes: 5 additions & 4 deletions tests/test_process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,18 @@ async def test_cleanup_processes(process_manager):
# Create mock processes with different states
running_proc = create_mock_process()
running_proc.returncode = None
# Mock wait to simulate timeout
running_proc.wait.side_effect = [asyncio.TimeoutError(), None]

completed_proc = create_mock_process()
completed_proc.returncode = 0

# Execute cleanup
await process_manager.cleanup_processes([running_proc, completed_proc])

# Verify running process was killed and waited for
running_proc.kill.assert_called_once()
running_proc.wait.assert_awaited_once()

# Verify running process was terminated first, then killed
running_proc.terminate.assert_called_once()
assert running_proc.wait.await_count == 2 # wait called for both terminate and kill
# Verify completed process was not killed or waited for
completed_proc.kill.assert_not_called()
completed_proc.wait.assert_not_called()
Expand Down
66 changes: 66 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import signal
import tempfile

import pytest
Expand Down Expand Up @@ -434,3 +435,68 @@ async def test_environment_variables(monkeypatch, temp_test_dir):
{"command": ["env"], "directory": temp_test_dir},
)
assert len(result) == 1


@pytest.mark.asyncio
async def test_signal_handling(monkeypatch, mocker):
"""Test signal handling and cleanup during server shutdown"""
from mcp_shell_server.server import main

# Setup mocks
mock_read_stream = mocker.AsyncMock()
mock_write_stream = mocker.AsyncMock()
mock_cleanup_processes = mocker.AsyncMock()

# Mock process manager
class MockExecutor:
def __init__(self):
self.process_manager = mocker.MagicMock()
self.process_manager.cleanup_processes = mock_cleanup_processes

class MockToolHandler:
def __init__(self):
self.executor = MockExecutor()

# Setup server mocks
context_manager = mocker.AsyncMock()
context_manager.__aenter__ = mocker.AsyncMock(
return_value=(mock_read_stream, mock_write_stream)
)
context_manager.__aexit__ = mocker.AsyncMock()
mock_stdio_server = mocker.Mock(return_value=context_manager)
mocker.patch("mcp.server.stdio.stdio_server", mock_stdio_server)

# Mock server run to simulate long-running task
async def mock_run(*args):
# Wait indefinitely or until cancelled
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
pass

mocker.patch("mcp_shell_server.server.app.run", side_effect=mock_run)

# Mock tool handler
tool_handler = MockToolHandler()
mocker.patch("mcp_shell_server.server.tool_handler", tool_handler)

# Run main in a task so we can simulate signal
task = asyncio.create_task(main())

# Give the server a moment to start
await asyncio.sleep(0.1)

# Simulate SIGINT
loop = asyncio.get_running_loop()
loop.call_soon(lambda: [h() for h in loop._signal_handlers.get(signal.SIGINT, [])])

# Wait for main to complete
try:
await asyncio.wait_for(task, timeout=1.0)
except asyncio.TimeoutError:
task.cancel()
await asyncio.sleep(0.1)

# Verify cleanup was called
mock_cleanup_processes.assert_called_once()
context_manager.__aexit__.assert_called_once()

0 comments on commit f86e1e0

Please sign in to comment.