Skip to content

Commit 18f4793

Browse files
committed
Addressing comments, adding unit tests upto format_frame
1 parent 680b7e4 commit 18f4793

2 files changed

Lines changed: 313 additions & 19 deletions

File tree

src/MoBI_View/web/broadcaster.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import time
1212
from typing import Any, Dict, List, Optional, Set
1313

14-
from websockets.asyncio.server import ServerConnection
14+
from websockets.asyncio import server
1515

1616
from MoBI_View.core import config
1717
from MoBI_View.presenters import main_app_presenter
@@ -46,16 +46,15 @@ def __init__(
4646
Config.TIMER_INTERVAL converted to seconds.
4747
"""
4848
self.presenter = presenter
49-
self.clients: Set[ServerConnection] = set()
49+
self.clients: Set[server.ServerConnection] = set()
5050
self._clients_lock = threading.Lock()
5151
self._running = False
5252
self._thread: Optional[threading.Thread] = None
5353
self._loop: Optional[asyncio.AbstractEventLoop] = None
5454

55-
if broadcast_interval is None:
56-
self.broadcast_interval = config.Config.TIMER_INTERVAL / 1000.0
57-
else:
58-
self.broadcast_interval = broadcast_interval
55+
self.broadcast_interval = (
56+
broadcast_interval or config.Config.TIMER_INTERVAL / 1000
57+
)
5958

6059
def start(self) -> None:
6160
"""Starts the broadcast loop in a background thread.
@@ -83,20 +82,22 @@ def stop(self) -> None:
8382
return
8483

8584
self._running = False
86-
if self._thread is not None:
87-
with self._clients_lock:
88-
client_count = len(self.clients)
89-
timeout = (
90-
self.broadcast_interval
91-
+ (client_count * self.CLIENT_SEND_TIMEOUT)
92-
+ self.CLIENT_SEND_TIMEOUT
93-
)
94-
self._thread.join(timeout=timeout)
95-
self._thread = None
85+
if self._thread is None:
86+
return
87+
88+
with self._clients_lock:
89+
client_count = len(self.clients)
90+
timeout = (
91+
self.broadcast_interval
92+
+ (client_count * self.CLIENT_SEND_TIMEOUT)
93+
+ self.CLIENT_SEND_TIMEOUT
94+
)
95+
self._thread.join(timeout=timeout)
96+
self._thread = None
9697
self._loop = None
9798
logger.info("Broadcaster stopped")
9899

99-
def add_client(self, client: ServerConnection) -> None:
100+
def add_client(self, client: server.ServerConnection) -> None:
100101
"""Adds a WebSocket client to the broadcast set.
101102
102103
Args:
@@ -106,7 +107,7 @@ def add_client(self, client: ServerConnection) -> None:
106107
self.clients.add(client)
107108
logger.info("Client added, total clients: %d", len(self.clients))
108109

109-
def remove_client(self, client: ServerConnection) -> None:
110+
def remove_client(self, client: server.ServerConnection) -> None:
110111
"""Removes a WebSocket client from the broadcast set.
111112
112113
Args:
@@ -173,7 +174,7 @@ def _broadcast_to_clients(self, message: str) -> None:
173174
with self._clients_lock:
174175
clients_snapshot = set(self.clients)
175176

176-
disconnected: List[ServerConnection] = []
177+
disconnected: List[server.ServerConnection] = []
177178

178179
for client in clients_snapshot:
179180
try:

tests/unit/test_broadcaster.py

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""Unit tests for the Broadcaster class."""
2+
3+
import json
4+
import time
5+
from unittest.mock import MagicMock
6+
7+
import pytest
8+
9+
from MoBI_View.core import config
10+
from MoBI_View.presenters import main_app_presenter
11+
from MoBI_View.web import broadcaster
12+
13+
14+
@pytest.fixture
15+
def mock_presenter() -> MagicMock:
16+
"""Creates a mock MainAppPresenter."""
17+
mock = MagicMock(spec=main_app_presenter.MainAppPresenter)
18+
mock.poll_data.return_value = []
19+
return mock
20+
21+
22+
@pytest.fixture
23+
def broadcaster_instance(mock_presenter: MagicMock) -> broadcaster.Broadcaster:
24+
"""Creates an unstarted Broadcaster instance."""
25+
return broadcaster.Broadcaster(presenter=mock_presenter, broadcast_interval=0.01)
26+
27+
28+
def test_init_sets_presenter_and_defaults(
29+
broadcaster_instance: broadcaster.Broadcaster,
30+
mock_presenter: MagicMock,
31+
) -> None:
32+
"""Tests __init__ sets presenter and initializes default state."""
33+
assert broadcaster_instance.presenter is mock_presenter
34+
assert broadcaster_instance.clients == set()
35+
assert broadcaster_instance._running is False
36+
assert broadcaster_instance._thread is None
37+
assert broadcaster_instance._loop is None
38+
39+
40+
def test_init_uses_default_broadcast_interval_from_config(
41+
mock_presenter: MagicMock,
42+
) -> None:
43+
"""Tests __init__ uses Config.TIMER_INTERVAL when no interval provided."""
44+
expected_interval = config.Config.TIMER_INTERVAL / 1000
45+
46+
bc = broadcaster.Broadcaster(presenter=mock_presenter)
47+
48+
assert bc.broadcast_interval == expected_interval
49+
50+
51+
def test_init_uses_custom_broadcast_interval_when_provided(
52+
mock_presenter: MagicMock,
53+
) -> None:
54+
"""Tests __init__ uses custom interval when provided."""
55+
custom_interval = 0.1
56+
57+
bc = broadcaster.Broadcaster(
58+
presenter=mock_presenter,
59+
broadcast_interval=custom_interval,
60+
)
61+
62+
assert bc.broadcast_interval == custom_interval
63+
64+
65+
def test_start_sets_running_and_creates_thread(
66+
broadcaster_instance: broadcaster.Broadcaster,
67+
) -> None:
68+
"""Tests start() sets _running=True and creates a daemon thread."""
69+
broadcaster_instance.start()
70+
71+
assert broadcaster_instance._running is True
72+
assert broadcaster_instance._thread is not None
73+
assert broadcaster_instance._thread.is_alive()
74+
assert broadcaster_instance._thread.daemon is True
75+
76+
broadcaster_instance.stop()
77+
78+
79+
def test_start_when_already_running_logs_warning_and_does_nothing(
80+
broadcaster_instance: broadcaster.Broadcaster,
81+
caplog: pytest.LogCaptureFixture,
82+
) -> None:
83+
"""Tests start(), when already running, logs warning and keeps same thread."""
84+
broadcaster_instance.start()
85+
original_thread = broadcaster_instance._thread
86+
87+
broadcaster_instance.start()
88+
89+
assert broadcaster_instance._thread is original_thread
90+
assert "already running" in caplog.text
91+
92+
broadcaster_instance.stop()
93+
94+
95+
def test_stop_when_not_running_logs_warning_and_returns(
96+
broadcaster_instance: broadcaster.Broadcaster,
97+
caplog: pytest.LogCaptureFixture,
98+
) -> None:
99+
"""Tests stop(), when _running=False, logs warning and returns early."""
100+
broadcaster_instance.stop()
101+
102+
assert "not running" in caplog.text
103+
104+
105+
def test_stop_when_thread_is_none_returns_early(
106+
broadcaster_instance: broadcaster.Broadcaster,
107+
) -> None:
108+
"""Tests stop() returns early when _thread is None."""
109+
broadcaster_instance._running = True
110+
broadcaster_instance._thread = None
111+
broadcaster_instance._loop = MagicMock()
112+
113+
broadcaster_instance.stop()
114+
115+
assert broadcaster_instance._running is False
116+
assert broadcaster_instance._loop is not None
117+
118+
119+
def test_stop_joins_thread_and_cleans_up(
120+
broadcaster_instance: broadcaster.Broadcaster,
121+
) -> None:
122+
"""Tests stop() joins thread and sets _thread and _loop to None."""
123+
broadcaster_instance.start()
124+
time.sleep(0.02)
125+
126+
broadcaster_instance.stop()
127+
128+
assert broadcaster_instance._running is False
129+
assert broadcaster_instance._thread is None
130+
assert broadcaster_instance._loop is None
131+
132+
133+
def test_stop_calculates_timeout_from_clients(
134+
broadcaster_instance: broadcaster.Broadcaster,
135+
) -> None:
136+
"""Tests stop() calculates join timeout based on client count."""
137+
broadcaster_instance.start()
138+
time.sleep(0.02)
139+
mock_client1 = MagicMock()
140+
mock_client2 = MagicMock()
141+
broadcaster_instance.add_client(mock_client1)
142+
broadcaster_instance.add_client(mock_client2)
143+
144+
broadcaster_instance.stop()
145+
146+
assert broadcaster_instance._thread is None
147+
148+
149+
def test_add_client_adds_to_clients_set(
150+
broadcaster_instance: broadcaster.Broadcaster,
151+
) -> None:
152+
"""Tests add_client() adds client to the clients set."""
153+
mock_client = MagicMock()
154+
155+
broadcaster_instance.add_client(mock_client)
156+
157+
assert mock_client in broadcaster_instance.clients
158+
assert len(broadcaster_instance.clients) == 1
159+
160+
161+
def test_add_client_logs_client_count(
162+
broadcaster_instance: broadcaster.Broadcaster,
163+
caplog: pytest.LogCaptureFixture,
164+
) -> None:
165+
"""Tests add_client() logs the total client count."""
166+
mock_client = MagicMock()
167+
168+
with caplog.at_level("INFO"):
169+
broadcaster_instance.add_client(mock_client)
170+
171+
assert "total clients: 1" in caplog.text
172+
173+
174+
def test_add_client_multiple_times_stores_once(
175+
broadcaster_instance: broadcaster.Broadcaster,
176+
) -> None:
177+
"""Tests adding the same client twice only stores it once."""
178+
mock_client = MagicMock()
179+
180+
broadcaster_instance.add_client(mock_client)
181+
broadcaster_instance.add_client(mock_client)
182+
183+
assert len(broadcaster_instance.clients) == 1
184+
185+
186+
def test_add_client_multiple_different_clients(
187+
broadcaster_instance: broadcaster.Broadcaster,
188+
) -> None:
189+
"""Tests adding multiple different clients."""
190+
client1 = MagicMock()
191+
client2 = MagicMock()
192+
193+
broadcaster_instance.add_client(client1)
194+
broadcaster_instance.add_client(client2)
195+
196+
assert len(broadcaster_instance.clients) == 2
197+
assert client1 in broadcaster_instance.clients
198+
assert client2 in broadcaster_instance.clients
199+
200+
201+
def test_remove_client_removes_from_clients_set(
202+
broadcaster_instance: broadcaster.Broadcaster,
203+
) -> None:
204+
"""Tests remove_client() removes client from the clients set."""
205+
mock_client = MagicMock()
206+
broadcaster_instance.add_client(mock_client)
207+
208+
broadcaster_instance.remove_client(mock_client)
209+
210+
assert mock_client not in broadcaster_instance.clients
211+
assert len(broadcaster_instance.clients) == 0
212+
213+
214+
def test_remove_client_logs_client_count(
215+
broadcaster_instance: broadcaster.Broadcaster,
216+
caplog: pytest.LogCaptureFixture,
217+
) -> None:
218+
"""Tests remove_client() logs the total client count."""
219+
mock_client = MagicMock()
220+
broadcaster_instance.add_client(mock_client)
221+
222+
with caplog.at_level("INFO"):
223+
broadcaster_instance.remove_client(mock_client)
224+
225+
assert "total clients: 0" in caplog.text
226+
227+
228+
def test_remove_client_nonexistent_does_not_raise(
229+
broadcaster_instance: broadcaster.Broadcaster,
230+
) -> None:
231+
"""Tests remove_client() with nonexistent client does not raise (discard)."""
232+
mock_client = MagicMock()
233+
234+
broadcaster_instance.remove_client(mock_client)
235+
236+
assert len(broadcaster_instance.clients) == 0
237+
238+
239+
def test_format_frame_empty_streams_returns_valid_json(
240+
broadcaster_instance: broadcaster.Broadcaster,
241+
) -> None:
242+
"""Tests format_frame() with empty list returns valid JSON with streams key."""
243+
result = broadcaster_instance.format_frame([])
244+
parsed = json.loads(result)
245+
246+
assert "streams" in parsed
247+
assert parsed["streams"] == []
248+
249+
250+
def test_format_frame_single_stream(
251+
broadcaster_instance: broadcaster.Broadcaster,
252+
) -> None:
253+
"""Tests format_frame() with single stream data."""
254+
streams_data = [
255+
{
256+
"stream_name": "EEG",
257+
"data": [1.0, 2.0, 3.0],
258+
"channel_labels": ["Fp1", "Fp2", "Fz"],
259+
}
260+
]
261+
262+
result = broadcaster_instance.format_frame(streams_data)
263+
parsed = json.loads(result)
264+
265+
assert len(parsed["streams"]) == 1
266+
assert parsed["streams"][0]["stream_name"] == "EEG"
267+
assert parsed["streams"][0]["data"] == [1.0, 2.0, 3.0]
268+
assert parsed["streams"][0]["channel_labels"] == ["Fp1", "Fp2", "Fz"]
269+
270+
271+
def test_format_frame_multiple_streams(
272+
broadcaster_instance: broadcaster.Broadcaster,
273+
) -> None:
274+
"""Tests format_frame() with multiple streams."""
275+
streams_data = [
276+
{
277+
"stream_name": "EEG",
278+
"data": [1.0, 2.0, 3.0],
279+
"channel_labels": ["Fp1", "Fp2", "Fz"],
280+
},
281+
{
282+
"stream_name": "Accelerometer",
283+
"data": [0.1, 0.2, 9.8],
284+
"channel_labels": ["X", "Y", "Z"],
285+
},
286+
]
287+
288+
result = broadcaster_instance.format_frame(streams_data)
289+
290+
parsed = json.loads(result)
291+
assert len(parsed["streams"]) == 2
292+
assert parsed["streams"][0]["stream_name"] == "EEG"
293+
assert parsed["streams"][1]["stream_name"] == "Accelerometer"

0 commit comments

Comments
 (0)