diff --git a/src/MoBI_View/web/server.py b/src/MoBI_View/web/server.py new file mode 100644 index 0000000..014eafd --- /dev/null +++ b/src/MoBI_View/web/server.py @@ -0,0 +1,98 @@ +"""WebSocket server for MoBI-View real-time data streaming. + +This module provides the ws_handler coroutine that manages WebSocket +client connections and handles incoming discover messages. +""" + +import json +import logging + +from websockets.asyncio import server + +from MoBI_View.core import discovery +from MoBI_View.presenters import main_app_presenter +from MoBI_View.web import broadcaster + +logger = logging.getLogger("MoBI-View.web.server") + + +async def ws_handler( + websocket: server.ServerConnection, + active_broadcaster: broadcaster.Broadcaster, + presenter: main_app_presenter.MainAppPresenter, +) -> None: + """Handle a WebSocket client connection. + + Registers the client with the broadcaster, listens for incoming + messages, and removes the client on disconnect. + + Args: + websocket: The WebSocket connection. + active_broadcaster: The Broadcaster instance managing clients. + presenter: The MainAppPresenter providing data and inlets. + """ + active_broadcaster.add_client(websocket) + try: + async for raw_message in websocket: + await _handle_message(raw_message, websocket, presenter) + finally: + active_broadcaster.remove_client(websocket) + + +async def _handle_message( + raw_message: str | bytes, + websocket: server.ServerConnection, + presenter: main_app_presenter.MainAppPresenter, +) -> None: + """Parse and dispatch a single client message. + + Currently supports only the `discover` command. As the number of + supported commands grows, this function should be refactored to use a + dispatch table (e.g. `handlers = {"discover": _handle_discover}`) + instead of explicit `if/else` branching. + + Args: + raw_message: The raw message received from the WebSocket. + websocket: The WebSocket connection for sending responses. + presenter: The MainAppPresenter for stream management. + """ + try: + data = json.loads(raw_message) + except (json.JSONDecodeError, TypeError, UnicodeDecodeError, RecursionError): + logger.warning("Received invalid JSON") + return + + if not isinstance(data, dict): + logger.warning("Expected JSON object, got %s", type(data).__name__) + return + + command = data.get("command") + if command == "discover": + await _handle_discover(websocket, presenter) + else: + logger.warning("Unknown command: %s", command) + + +async def _handle_discover( + websocket: server.ServerConnection, + presenter: main_app_presenter.MainAppPresenter, +) -> None: + """Run stream discovery and send results to the requesting client. + + Args: + websocket: The WebSocket connection to send the result to. + presenter: The MainAppPresenter managing data inlets. + """ + new_inlets = discovery.discover_and_create_inlets( + existing_inlets=presenter.data_inlets, + ) + presenter.data_inlets.extend(new_inlets) + stream_names = [inlet.stream_name for inlet in new_inlets] + response = json.dumps( + { + "type": "discover_result", + "streams": stream_names, + } + ) + await websocket.send(response) + logger.info("Discover: found %d new stream(s)", len(new_inlets)) diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py new file mode 100644 index 0000000..9cfc5c3 --- /dev/null +++ b/tests/unit/test_server.py @@ -0,0 +1,206 @@ +"""Unit tests for the WebSocket server module.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from MoBI_View.presenters import main_app_presenter +from MoBI_View.web import broadcaster, server + + +@pytest.fixture +def mock_presenter() -> MagicMock: + """Creates a mock MainAppPresenter.""" + mock = MagicMock(spec=main_app_presenter.MainAppPresenter) + mock.poll_data.return_value = [] + mock.data_inlets = [] + return mock + + +@pytest.fixture +def mock_broadcaster(mock_presenter: MagicMock) -> MagicMock: + """Creates a mock Broadcaster.""" + mock = MagicMock(spec=broadcaster.Broadcaster) + mock.presenter = mock_presenter + return mock + + +@pytest.fixture +def mock_websocket() -> AsyncMock: + """Creates a mock ServerConnection.""" + return AsyncMock() + + +def test_ws_handler_registers_and_unregisters_client( + mock_websocket: AsyncMock, + mock_broadcaster: MagicMock, + mock_presenter: MagicMock, +) -> None: + """Tests ws_handler adds client on connect and removes on disconnect.""" + mock_websocket.__aiter__.return_value = iter([]) + + asyncio.run(server.ws_handler(mock_websocket, mock_broadcaster, mock_presenter)) + + mock_broadcaster.add_client.assert_called_once_with(mock_websocket) + mock_broadcaster.remove_client.assert_called_once_with(mock_websocket) + + +def test_ws_handler_removes_client_on_exception( + mock_websocket: AsyncMock, + mock_broadcaster: MagicMock, + mock_presenter: MagicMock, +) -> None: + """Tests ws_handler removes client even when iteration raises.""" + mock_websocket.__aiter__.side_effect = RuntimeError("connection lost") + + with pytest.raises(RuntimeError, match="connection lost"): + asyncio.run(server.ws_handler(mock_websocket, mock_broadcaster, mock_presenter)) + + mock_broadcaster.remove_client.assert_called_once_with(mock_websocket) + + +def test_ws_handler_dispatches_discover_command( + mock_websocket: AsyncMock, + mock_broadcaster: MagicMock, + mock_presenter: MagicMock, +) -> None: + """Tests ws_handler forwards a discover command to _handle_discover.""" + message = json.dumps({"command": "discover"}) + mock_websocket.__aiter__.return_value = iter([message]) + + with patch.object(server, "_handle_discover", new_callable=AsyncMock) as mock_hd: + asyncio.run(server.ws_handler(mock_websocket, mock_broadcaster, mock_presenter)) + + mock_hd.assert_awaited_once_with(mock_websocket, mock_presenter) + + +def test_handle_message_ignores_invalid_json( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning for non-JSON payload.""" + asyncio.run(server._handle_message("not json", mock_websocket, mock_presenter)) + + assert "invalid JSON" in caplog.text + + +def test_handle_message_rejects_non_string_input( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning when input is not str or bytes.""" + asyncio.run( + server._handle_message(None, mock_websocket, mock_presenter) # type: ignore[arg-type] + ) + + assert "invalid JSON" in caplog.text + + +def test_handle_message_rejects_invalid_bytes( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning for bytes with invalid encoding.""" + asyncio.run(server._handle_message(b"\xff\xfe", mock_websocket, mock_presenter)) + + assert "invalid JSON" in caplog.text + + +def test_handle_message_rejects_deeply_nested_json( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning for deeply nested JSON.""" + deeply_nested = "[" * 10000 + "]" * 10000 + + asyncio.run(server._handle_message(deeply_nested, mock_websocket, mock_presenter)) + + assert "invalid JSON" in caplog.text + + +def test_handle_message_rejects_non_dict_json( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning when JSON is not an object.""" + asyncio.run( + server._handle_message('"just a string"', mock_websocket, mock_presenter) + ) + + assert "Expected JSON object" in caplog.text + + +def test_handle_message_logs_unknown_command( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_message logs warning for unrecognised command.""" + msg = json.dumps({"command": "foobar"}) + + asyncio.run(server._handle_message(msg, mock_websocket, mock_presenter)) + + assert "Unknown command: foobar" in caplog.text + + +def test_handle_message_routes_discover_command( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, +) -> None: + """Tests _handle_message calls _handle_discover for discover command.""" + msg = json.dumps({"command": "discover"}) + + with patch.object(server, "_handle_discover", new_callable=AsyncMock) as mock_hd: + asyncio.run(server._handle_message(msg, mock_websocket, mock_presenter)) + + mock_hd.assert_awaited_once_with(mock_websocket, mock_presenter) + + +def test_handle_discover_calls_discovery_and_sends_result( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, + caplog: pytest.LogCaptureFixture, +) -> None: + """Tests _handle_discover discovers new streams and sends them to client.""" + fake_inlet = MagicMock() + fake_inlet.stream_name = "EEG" + + with ( + patch.object( + server.discovery, + "discover_and_create_inlets", + return_value=[fake_inlet], + ), + caplog.at_level("INFO"), + ): + asyncio.run(server._handle_discover(mock_websocket, mock_presenter)) + + mock_websocket.send.assert_awaited_once() + sent = json.loads(mock_websocket.send.call_args[0][0]) + assert sent["type"] == "discover_result" + assert sent["streams"] == ["EEG"] + assert fake_inlet in mock_presenter.data_inlets + assert "found 1 new stream(s)" in caplog.text + + +def test_handle_discover_with_no_new_streams( + mock_websocket: AsyncMock, + mock_presenter: MagicMock, +) -> None: + """Tests _handle_discover sends empty list when no new streams found.""" + with patch.object( + server.discovery, + "discover_and_create_inlets", + return_value=[], + ): + asyncio.run(server._handle_discover(mock_websocket, mock_presenter)) + + sent = json.loads(mock_websocket.send.call_args[0][0]) + assert sent["streams"] == []