diff --git a/src/mopidy_mpd/actor.py b/src/mopidy_mpd/actor.py index 66c9799..fb660d5 100644 --- a/src/mopidy_mpd/actor.py +++ b/src/mopidy_mpd/actor.py @@ -1,4 +1,6 @@ +import asyncio import logging +import threading from typing import Any import pykka @@ -37,7 +39,11 @@ def __init__(self, config: types.Config, core: CoreProxy) -> None: self.zeroconf_service = None self.uri_map = uri_mapper.MpdUriMapper(core) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.server = self._setup_server(config, core) + self.server_thread = threading.Thread(target=self._run_server) def _setup_server(self, config: types.Config, core: CoreProxy) -> network.Server: try: @@ -54,11 +60,16 @@ def _setup_server(self, config: types.Config, core: CoreProxy) -> network.Server except OSError as exc: raise exceptions.FrontendError(f"MPD server startup failed: {exc}") from exc - logger.info(f"MPD server running at {network.format_address(server.address)}") + logger.info("MPD server running at %s", network.format_address(server.address)) return server + def _run_server(self) -> None: + self.loop.run_until_complete(self.server.run()) + def on_start(self) -> None: + self.server_thread.start() + if self.zeroconf_name and not network.is_unix_socket(self.server.server_socket): self.zeroconf_service = zeroconf.Zeroconf( name=self.zeroconf_name, stype="_mpd._tcp", port=self.port @@ -73,7 +84,13 @@ def on_stop(self) -> None: for session_actor in session_actors: session_actor.stop() - self.server.stop() + if not self.server_thread.is_alive(): + logger.warning("MPD server already stopped") + return + + self.loop.call_soon_threadsafe(self.server.stop) + logger.debug("Waiting for MPD server thread to terminate") + self.server_thread.join() def on_event(self, event: str, **kwargs: Any) -> None: if event not in _CORE_EVENTS_TO_IDLE_SUBSYSTEMS: diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 698e8ac..25b20f1 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging from collections.abc import Callable from typing import ( @@ -100,7 +101,10 @@ def handle_idle(self, subsystem: str) -> None: response = [*[f"changed: {s}" for s in subsystems], "OK"] self.subsystem_events = set() self.subsystem_subscriptions = set() - self.session.send_lines(response) + asyncio.run_coroutine_threadsafe( + self.session.send_lines(response), + self.session.loop, + ) def _call_next_filter( self, request: Request, response: Response, filter_chain: list[Filter] diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index ba55364..b461284 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import contextlib import errno import logging @@ -8,10 +9,9 @@ import socket import sys import threading -from typing import TYPE_CHECKING, Any, Never +from typing import TYPE_CHECKING, Any, Never, Optional import pykka -from gi.repository import GLib # pyright: ignore[reportMissingModuleSource] logger = logging.getLogger(__name__) @@ -119,7 +119,7 @@ def format_hostname(hostname: str) -> str: class Server: - """Setup listener and register it with GLib's event loop.""" + """Setup listener and register it with the asyncio loop.""" def __init__( # noqa: PLR0913 self, @@ -142,8 +142,7 @@ def __init__( # noqa: PLR0913 self.timeout = timeout self.server_socket = self.create_server_socket(host, port) self.address = get_socket_address(host, port) - - self.watcher = self.register_server_socket(self.server_socket.fileno()) + self._should_stop = asyncio.Event() def create_server_socket(self, host: str, port: int) -> socket.socket: sock = get_systemd_socket() @@ -156,7 +155,7 @@ def create_server_socket(self, host: str, port: int) -> socket.socket: sock.bind(socket_path) else: # ensure the port is supplied - if not isinstance(port, int): + if not (port and isinstance(port, int)): raise TypeError(f"Expected an integer, not {port!r}") sock = create_tcp_socket() sock.bind((host, port)) @@ -166,7 +165,6 @@ def create_server_socket(self, host: str, port: int) -> socket.socket: return sock def stop(self) -> None: - GLib.source_remove(self.watcher) if is_unix_socket(self.server_socket): unix_socket_path = self.server_socket.getsockname() else: @@ -179,36 +177,22 @@ def stop(self) -> None: if unix_socket_path is not None: os.unlink(unix_socket_path) # noqa: PTH108 - def register_server_socket(self, fileno: int) -> int: - return GLib.io_add_watch(fileno, GLib.IO_IN, self.handle_connection) - - def handle_connection(self, _fd: int, _flags: int) -> bool: - try: - sock, addr = self.accept_connection() - except ShouldRetrySocketCallError: - return True + self._should_stop.set() + def handle_connection( + self, + client: socket.socket, + addr: SocketAddress, + loop: asyncio.AbstractEventLoop, + ) -> bool: + if is_unix_socket(client): + addr = (client.getsockname(), None) if self.maximum_connections_exceeded(): - self.reject_connection(sock, addr) + self.reject_connection(client, addr, reason="Maximum connections exceeded") else: - self.init_connection(sock, addr) + self.init_connection(client, addr, loop) return True - def accept_connection(self) -> tuple[socket.socket, SocketAddress]: - try: - sock, addr = self.server_socket.accept() - if is_unix_socket(sock): - addr = (sock.getsockname(), None) - except OSError as exc: - if exc.errno in (errno.EAGAIN, errno.EINTR): - raise ShouldRetrySocketCallError from None - raise - else: - return ( - sock, - addr[:2], # addr is a two-tuple for IPv4 and four-tuple for IPv6 - ) - def maximum_connections_exceeded(self) -> bool: return ( self.max_connections is not None @@ -218,14 +202,17 @@ def maximum_connections_exceeded(self) -> bool: def number_of_connections(self) -> int: return len(pykka.ActorRegistry.get_by_class(self.protocol)) - def reject_connection(self, sock: socket.socket, addr: SocketAddress) -> None: - # TODO: provide more context in logging? - logger.warning("Rejected connection from %s", format_address(addr)) + def reject_connection( + self, sock: socket.socket, addr: SocketAddress, reason: str = "" + ) -> None: + logger.warning("Rejected connection from %s: %s", format_address(addr), reason) with contextlib.suppress(OSError): sock.close() - def init_connection(self, sock: socket.socket, addr: SocketAddress) -> None: - Connection( + def init_connection( + self, sock: socket.socket, addr: SocketAddress, loop: asyncio.AbstractEventLoop + ) -> None: + conn = Connection( config=self.config, core=self.core, uri_map=self.uri_map, @@ -233,18 +220,49 @@ def init_connection(self, sock: socket.socket, addr: SocketAddress) -> None: sock=sock, addr=addr, timeout=self.timeout, + loop=loop, ) + asyncio.create_task(conn.serve()) -class Connection: - # NOTE: the callback code is _not_ run in the actor's thread, but in the - # same one as the event loop. If code in the callbacks blocks, the rest of - # GLib code will likely be blocked as well... - # - # Also note that source_remove() return values are ignored on purpose, a - # false return value would only tell us that what we thought was registered - # is already gone, there is really nothing more we can do. + async def wait_stop(self, timeout: Optional[float] = None) -> None: + await asyncio.wait_for(self._should_stop.wait(), timeout=timeout) + + def should_stop(self) -> bool: + return self._should_stop.is_set() + + async def run(self) -> None: + loop = asyncio.get_event_loop() + self._should_stop.clear() + wait_stop = loop.create_task(self.wait_stop()) + + while not self.should_stop(): + try: + tasks = [ + loop.create_task(loop.sock_accept(self.server_socket)), + wait_stop, + ] + + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + if tasks[1].done(): + tasks[1].result() + break + + try: + client, addr = tasks[0].result() + self.handle_connection(client, addr, loop) + except OSError as exc: + if exc.errno in (errno.EBADF, errno.ENOTSOCK): + continue + raise exc + except OSError as exc: + if exc.errno in (errno.EAGAIN, errno.EINTR): + continue + raise exc + + +class Connection: host: str port: int | None @@ -258,12 +276,14 @@ def __init__( # noqa: PLR0913 sock: socket.socket, addr: SocketAddress, timeout: int, + loop: asyncio.AbstractEventLoop, ) -> None: sock.setblocking(False) # noqa: FBT003 self.host, self.port = addr[:2] self._sock = sock + self._loop = loop self.protocol = protocol self.timeout = timeout @@ -272,24 +292,54 @@ def __init__( # noqa: PLR0913 self.stopping = False - self.recv_id: int | None = None - self.send_id: int | None = None - self.timeout_id: int | None = None - protocol_kwargs: MpdSessionKwargs = { "config": config, "core": core, "uri_map": uri_map, "connection": self, + "loop": loop, } self.actor_ref = self.protocol.start(**protocol_kwargs) - self.enable_recv() - self.enable_timeout() + async def recv(self) -> bool: + try: + task = asyncio.create_task(self._loop.sock_recv(self._sock, 4096)) + tasks, _ = await asyncio.wait( + [task], timeout=self.timeout, return_when=asyncio.FIRST_COMPLETED + ) + + if not (tasks and task in tasks): + self.stop(f"Client inactive for {self.timeout:d}s; closing connection") + return False + + data = task.result() + except OSError as exc: + if exc.errno in (errno.EWOULDBLOCK, errno.EINTR): + return True + self.stop(f"Unexpected client error: {exc}") + return False + + if not data: + self.actor_ref.tell({"close": True}) + return False + + try: + self.actor_ref.tell({"received": data}) + except pykka.ActorDeadError: + self.stop("Actor is dead.") + return False + + return True + + async def serve(self) -> None: + while not self.stopping: + should_continue = await self.recv() + if not should_continue: + break def stop(self, reason: str, level: int = logging.DEBUG) -> None: if self.stopping: - logger.log(level, f"Already stopping: {reason}") + logger.log(level, "Already stopping: %s", reason) return self.stopping = True @@ -299,128 +349,24 @@ def stop(self, reason: str, level: int = logging.DEBUG) -> None: with contextlib.suppress(pykka.ActorDeadError): self.actor_ref.stop(block=False) - self.disable_timeout() - self.disable_recv() - self.disable_send() - with contextlib.suppress(OSError): self._sock.close() - def queue_send(self, data: bytes) -> None: + async def queue_send(self, data: bytes) -> None: """Try to send data to client exactly as is and queue rest.""" - self.send_lock.acquire(blocking=True) - self.send_buffer = self.send(self.send_buffer + data) - self.send_lock.release() - if self.send_buffer: - self.enable_send() - - def send(self, data: bytes) -> bytes: - """Send data to client, return any unsent data.""" + with self.send_lock: + task = asyncio.create_task(self.send(self.send_buffer + data)) + await asyncio.wait({task}, timeout=self.timeout) + self.send_buffer = b"" + + async def send(self, data: bytes) -> None: + """Send data to client.""" try: - sent = self._sock.send(data) - return data[sent:] + await self._loop.sock_sendall(self._sock, data) except OSError as exc: if exc.errno in (errno.EWOULDBLOCK, errno.EINTR): - return data + return self.stop(f"Unexpected client error: {exc}") - return b"" - - def enable_timeout(self) -> None: - """Reactivate timeout mechanism.""" - if self.timeout <= 0: - return - - self.disable_timeout() - self.timeout_id = GLib.timeout_add_seconds(self.timeout, self.timeout_callback) - - def disable_timeout(self) -> None: - """Deactivate timeout mechanism.""" - if self.timeout_id is None: - return - GLib.source_remove(self.timeout_id) - self.timeout_id = None - - def enable_recv(self) -> None: - if self.recv_id is not None: - return - - try: - self.recv_id = GLib.io_add_watch( - self._sock.fileno(), - GLib.IO_IN | GLib.IO_ERR | GLib.IO_HUP, - self.recv_callback, - ) - except OSError as exc: - self.stop(f"Problem with connection: {exc}") - - def disable_recv(self) -> None: - if self.recv_id is None: - return - GLib.source_remove(self.recv_id) - self.recv_id = None - - def enable_send(self) -> None: - if self.send_id is not None: - return - - try: - self.send_id = GLib.io_add_watch( - self._sock.fileno(), - GLib.IO_OUT | GLib.IO_ERR | GLib.IO_HUP, - self.send_callback, - ) - except OSError as exc: - self.stop(f"Problem with connection: {exc}") - - def disable_send(self) -> None: - if self.send_id is None: - return - - GLib.source_remove(self.send_id) - self.send_id = None - - def recv_callback(self, fd: int, flags: int) -> bool: # noqa: ARG002 - if flags & (GLib.IO_ERR | GLib.IO_HUP): - self.stop(f"Bad client flags: {flags}") - return True - - try: - data = self._sock.recv(4096) - except OSError as exc: - if exc.errno not in (errno.EWOULDBLOCK, errno.EINTR): - self.stop(f"Unexpected client error: {exc}") - return True - - if not data: - self.disable_recv() - self.actor_ref.tell({"close": True}) - return True - - try: - self.actor_ref.tell({"received": data}) - except pykka.ActorDeadError: - self.stop("Actor is dead.") - - return True - - def send_callback(self, fd: int, flags: int) -> bool: # noqa: ARG002 - if flags & (GLib.IO_ERR | GLib.IO_HUP): - self.stop(f"Bad client flags: {flags}") - return True - - # If with can't get the lock, simply try again next time socket is - # ready for sending. - if not self.send_lock.acquire(blocking=False): - return True - - try: - self.send_buffer = self.send(self.send_buffer) - if not self.send_buffer: - self.disable_send() - finally: - self.send_lock.release() - - return True def timeout_callback(self) -> bool: self.stop(f"Client inactive for {self.timeout:d}s; closing connection") @@ -478,7 +424,6 @@ def on_receive(self, message: dict[str, Any]) -> None: if "received" not in message: return - self.connection.disable_timeout() self.recv_buffer += message["received"] for line in self.parse_lines(): @@ -486,17 +431,16 @@ def on_receive(self, message: dict[str, Any]) -> None: if decoded_line is not None: self.on_line_received(decoded_line) - if not self.prevent_timeout: - self.connection.enable_timeout() - def on_failure( self, - exception_type: type[BaseException] | None, # noqa: ARG002 - exception_value: BaseException | None, # noqa: ARG002 - traceback: TracebackType | None, # noqa: ARG002 + exception_type: type[BaseException] | None, + exception_value: BaseException | None, + traceback: TracebackType | None, ) -> None: """Clean up connection resouces when actor fails.""" + super().on_failure(exception_type, exception_value, traceback) self.connection.stop("Actor failed.") + logger.exception("Actor failed.", exc_info=exception_value) def on_stop(self) -> None: """Clean up connection resouces when actor stops.""" @@ -548,7 +492,7 @@ def join_lines(self, lines: list[str]) -> str: line_terminator = self.decode(self.terminator) return line_terminator.join(lines) + line_terminator - def send_lines(self, lines: list[str]) -> None: + async def send_lines(self, lines: list[str]) -> None: """ Send array of lines to client via connection. @@ -562,4 +506,4 @@ def send_lines(self, lines: list[str]) -> None: lines = [line.translate(CONTROL_CHARS) for line in lines] data = self.join_lines(lines) - self.connection.queue_send(self.encode(data)) + await self.connection.queue_send(self.encode(data)) diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index 515947c..35ab0ae 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging from typing import TYPE_CHECKING, Never, TypedDict @@ -19,6 +20,7 @@ class MpdSessionKwargs(TypedDict): core: CoreProxy uri_map: MpdUriMapper connection: network.Connection + loop: asyncio.AbstractEventLoop class MpdSession(network.LineProtocol): @@ -37,6 +39,7 @@ def __init__( core: CoreProxy, uri_map: MpdUriMapper, connection: network.Connection, + loop: asyncio.AbstractEventLoop, ) -> None: super().__init__(connection) self.dispatcher = dispatcher.MpdDispatcher( @@ -46,10 +49,14 @@ def __init__( session=self, ) self.tagtypes = tagtype_list.TAGTYPE_LIST.copy() + self.loop = loop def on_start(self) -> None: logger.info("New MPD connection from %s", self.connection) - self.send_lines([f"OK MPD {protocol.VERSION}"]) + asyncio.run_coroutine_threadsafe( + self.send_lines([f"OK MPD {protocol.VERSION}"]), + self.loop, + ) def on_line_received(self, line: str) -> None: logger.debug("Request from %s: %s", self.connection, line) @@ -71,7 +78,10 @@ def on_line_received(self, line: str) -> None: formatting.indent(self.decode(self.terminator).join(response)), ) - self.send_lines(response) + asyncio.run_coroutine_threadsafe( + self.send_lines(response), + self.loop, + ) def on_event(self, subsystem: str) -> None: self.dispatcher.handle_idle(subsystem) diff --git a/tests/network/test_connection.py b/tests/network/test_connection.py index 9398472..a319476 100644 --- a/tests/network/test_connection.py +++ b/tests/network/test_connection.py @@ -1,78 +1,93 @@ +import asyncio import errno import logging import socket +from typing import Type import unittest -from unittest.mock import Mock, call, patch, sentinel +from unittest.mock import Mock, patch, sentinel import pykka -from gi.repository import GLib -from mopidy_mpd import network, uri_mapper -from tests import any_int, any_unicode +from mopidy_mpd import network, types, uri_mapper +from mopidy_mpd.session import MpdSession +from tests import IsA, any_int, any_unicode class ConnectionTest(unittest.TestCase): + _empty_config = types.Config({}) # type: ignore + + @property + def _mock_protocol(self) -> Type[MpdSession]: + return Mock(spec=network.LineProtocol) # type: ignore + def setUp(self): self.mock = Mock(spec=network.Connection) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def create_connection(self): + conn = network.Connection( + config=self._empty_config, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=self._mock_protocol, + sock=Mock(spec=socket.SocketType), + addr=(sentinel.host, sentinel.port), + timeout=1, + loop=self.loop, + ) + + conn._sock = Mock(spec=socket.SocketType) + conn.actor_ref = Mock() + return conn def test_init_ensure_nonblocking_io(self): sock = Mock(spec=socket.SocketType) network.Connection.__init__( self.mock, - config={}, + config=self._empty_config, core=Mock(), uri_map=Mock(spec=uri_mapper.MpdUriMapper), - protocol=Mock(spec=network.LineProtocol), + protocol=self._mock_protocol, sock=sock, addr=(sentinel.host, sentinel.port), timeout=sentinel.timeout, + loop=self.loop, ) sock.setblocking.assert_called_once_with(False) def test_init_starts_actor(self): - protocol = Mock(spec=network.LineProtocol) + protocol = self._mock_protocol network.Connection.__init__( self.mock, - config={}, + config=self._empty_config, core=Mock(), uri_map=Mock(spec=uri_mapper.MpdUriMapper), protocol=protocol, sock=Mock(spec=socket.SocketType), addr=(sentinel.host, sentinel.port), timeout=sentinel.timeout, + loop=self.loop, ) protocol.start.assert_called_once() - def test_init_enables_recv_and_timeout(self): - network.Connection.__init__( - self.mock, - config={}, - core=Mock(), - uri_map=Mock(spec=uri_mapper.MpdUriMapper), - protocol=Mock(spec=network.LineProtocol), - sock=Mock(spec=socket.SocketType), - addr=(sentinel.host, sentinel.port), - timeout=sentinel.timeout, - ) - self.mock.enable_recv.assert_called_once_with() - self.mock.enable_timeout.assert_called_once_with() - def test_init_stores_values_in_attributes(self): addr = (sentinel.host, sentinel.port) - protocol = Mock(spec=network.LineProtocol) + protocol = self._mock_protocol sock = Mock(spec=socket.SocketType) network.Connection.__init__( self.mock, - config={}, + config=self._empty_config, core=Mock(), uri_map=Mock(spec=uri_mapper.MpdUriMapper), protocol=protocol, sock=sock, addr=addr, timeout=sentinel.timeout, + loop=self.loop, ) assert sock == self.mock._sock assert protocol == self.mock.protocol @@ -87,32 +102,22 @@ def test_init_handles_ipv6_addr(self): sentinel.flowinfo, sentinel.scopeid, ) - protocol = Mock(spec=network.LineProtocol) sock = Mock(spec=socket.SocketType) network.Connection.__init__( self.mock, - config={}, + config=self._empty_config, core=Mock(), uri_map=Mock(spec=uri_mapper.MpdUriMapper), - protocol=protocol, + protocol=self._mock_protocol, sock=sock, addr=addr, timeout=sentinel.timeout, + loop=self.loop, ) assert sentinel.host == self.mock.host assert sentinel.port == self.mock.port - def test_stop_disables_recv_send_and_timeout(self): - self.mock.stopping = False - self.mock.actor_ref = Mock() - self.mock._sock = Mock(spec=socket.SocketType) - - network.Connection.stop(self.mock, sentinel.reason) - self.mock.disable_timeout.assert_called_once_with() - self.mock.disable_recv.assert_called_once_with() - self.mock.disable_send.assert_called_once_with() - def test_stop_closes_socket(self): self.mock.stopping = False self.mock.actor_ref = Mock() @@ -191,383 +196,81 @@ def test_stop_logs_that_it_is_calling_itself(self): network.Connection.stop(self.mock, sentinel.reason) network.logger.log(any_int, any_unicode) - @patch.object(GLib, "io_add_watch", new=Mock()) - def test_enable_recv_registers_with_glib(self): - self.mock.recv_id = None - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.fileno.return_value = sentinel.fileno - GLib.io_add_watch.return_value = sentinel.tag - - network.Connection.enable_recv(self.mock) - GLib.io_add_watch.assert_called_once_with( - sentinel.fileno, - GLib.IO_IN | GLib.IO_ERR | GLib.IO_HUP, - self.mock.recv_callback, - ) - assert sentinel.tag == self.mock.recv_id - - @patch.object(GLib, "io_add_watch", new=Mock()) - def test_enable_recv_already_registered(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock.recv_id = sentinel.tag - - network.Connection.enable_recv(self.mock) - assert GLib.io_add_watch.call_count == 0 - - def test_enable_recv_does_not_change_tag(self): - self.mock.recv_id = sentinel.tag - self.mock._sock = Mock(spec=socket.SocketType) - - network.Connection.enable_recv(self.mock) - assert sentinel.tag == self.mock.recv_id - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_recv_deregisters(self): - self.mock.recv_id = sentinel.tag - - network.Connection.disable_recv(self.mock) - GLib.source_remove.assert_called_once_with(sentinel.tag) - assert self.mock.recv_id is None - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_recv_already_deregistered(self): - self.mock.recv_id = None - - network.Connection.disable_recv(self.mock) - assert GLib.source_remove.call_count == 0 - assert self.mock.recv_id is None - - def test_enable_recv_on_closed_socket(self): - self.mock.recv_id = None - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.fileno.side_effect = OSError(errno.EBADF, "") - - network.Connection.enable_recv(self.mock) - self.mock.stop.assert_called_once_with(any_unicode) - assert self.mock.recv_id is None - - @patch.object(GLib, "io_add_watch", new=Mock()) - def test_enable_send_registers_with_glib(self): - self.mock.send_id = None - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.fileno.return_value = sentinel.fileno - GLib.io_add_watch.return_value = sentinel.tag - - network.Connection.enable_send(self.mock) - GLib.io_add_watch.assert_called_once_with( - sentinel.fileno, - GLib.IO_OUT | GLib.IO_ERR | GLib.IO_HUP, - self.mock.send_callback, - ) - assert sentinel.tag == self.mock.send_id - - @patch.object(GLib, "io_add_watch", new=Mock()) - def test_enable_send_already_registered(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock.send_id = sentinel.tag - - network.Connection.enable_send(self.mock) - assert GLib.io_add_watch.call_count == 0 - - def test_enable_send_does_not_change_tag(self): - self.mock.send_id = sentinel.tag - self.mock._sock = Mock(spec=socket.SocketType) - - network.Connection.enable_send(self.mock) - assert sentinel.tag == self.mock.send_id - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_send_deregisters(self): - self.mock.send_id = sentinel.tag - - network.Connection.disable_send(self.mock) - GLib.source_remove.assert_called_once_with(sentinel.tag) - assert self.mock.send_id is None - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_send_already_deregistered(self): - self.mock.send_id = None - - network.Connection.disable_send(self.mock) - assert GLib.source_remove.call_count == 0 - assert self.mock.send_id is None - - def test_enable_send_on_closed_socket(self): - self.mock.send_id = None - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.fileno.side_effect = OSError(errno.EBADF, "") - - network.Connection.enable_send(self.mock) - assert self.mock.send_id is None - - @patch.object(GLib, "timeout_add_seconds", new=Mock()) - def test_enable_timeout_clears_existing_timeouts(self): - self.mock.timeout = 10 - - network.Connection.enable_timeout(self.mock) - self.mock.disable_timeout.assert_called_once_with() - - @patch.object(GLib, "timeout_add_seconds", new=Mock()) - def test_enable_timeout_add_glib_timeout(self): - self.mock.timeout = 10 - GLib.timeout_add_seconds.return_value = sentinel.tag - - network.Connection.enable_timeout(self.mock) - GLib.timeout_add_seconds.assert_called_once_with(10, self.mock.timeout_callback) - assert sentinel.tag == self.mock.timeout_id - - @patch.object(GLib, "timeout_add_seconds", new=Mock()) - def test_enable_timeout_does_not_add_timeout(self): - self.mock.timeout = 0 - network.Connection.enable_timeout(self.mock) - assert GLib.timeout_add_seconds.call_count == 0 - - self.mock.timeout = -1 - network.Connection.enable_timeout(self.mock) - assert GLib.timeout_add_seconds.call_count == 0 - - def test_enable_timeout_does_not_call_disable_for_invalid_timeout(self): - self.mock.timeout = 0 - network.Connection.enable_timeout(self.mock) - assert self.mock.disable_timeout.call_count == 0 - - self.mock.timeout = -1 - network.Connection.enable_timeout(self.mock) - assert self.mock.disable_timeout.call_count == 0 - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_timeout_deregisters(self): - self.mock.timeout_id = sentinel.tag - - network.Connection.disable_timeout(self.mock) - GLib.source_remove.assert_called_once_with(sentinel.tag) - assert self.mock.timeout_id is None - - @patch.object(GLib, "source_remove", new=Mock()) - def test_disable_timeout_already_deregistered(self): - self.mock.timeout_id = None - - network.Connection.disable_timeout(self.mock) - assert GLib.source_remove.call_count == 0 - assert self.mock.timeout_id is None - - def test_queue_send_acquires_and_releases_lock(self): - self.mock.send_lock = Mock() - self.mock.send_buffer = b"" - - network.Connection.queue_send(self.mock, b"data") - self.mock.send_lock.acquire.assert_called_once_with(blocking=True) - self.mock.send_lock.release.assert_called_once_with() - def test_queue_send_calls_send(self): - self.mock.send_buffer = b"" - self.mock.send_lock = Mock() - self.mock.send.return_value = b"" - - network.Connection.queue_send(self.mock, b"data") - self.mock.send.assert_called_once_with(b"data") - assert self.mock.enable_send.call_count == 0 - assert self.mock.send_buffer == b"" - - def test_queue_send_calls_enable_send_for_partial_send(self): - self.mock.send_buffer = b"" - self.mock.send_lock = Mock() - self.mock.send.return_value = b"ta" - - network.Connection.queue_send(self.mock, b"data") - self.mock.send.assert_called_once_with(b"data") - self.mock.enable_send.assert_called_once_with() - assert self.mock.send_buffer == b"ta" - - def test_queue_send_calls_send_with_existing_buffer(self): - self.mock.send_buffer = b"foo" - self.mock.send_lock = Mock() - self.mock.send.return_value = b"" - - network.Connection.queue_send(self.mock, b"bar") - self.mock.send.assert_called_once_with(b"foobar") - assert self.mock.enable_send.call_count == 0 - assert self.mock.send_buffer == b"" - - def test_recv_callback_respects_io_err(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock.actor_ref = Mock() - - assert network.Connection.recv_callback( - self.mock, sentinel.fd, (GLib.IO_IN | GLib.IO_ERR) - ) - self.mock.stop.assert_called_once_with(any_unicode) - - def test_recv_callback_respects_io_hup(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock.actor_ref = Mock() - - assert network.Connection.recv_callback( - self.mock, sentinel.fd, (GLib.IO_IN | GLib.IO_HUP) - ) - self.mock.stop.assert_called_once_with(any_unicode) + conn = self.create_connection() + conn._loop = Mock(spec=asyncio.AbstractEventLoop) + conn.send_buffer = b"" - def test_recv_callback_respects_io_hup_and_io_err(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock.actor_ref = Mock() - - assert network.Connection.recv_callback( - self.mock, sentinel.fd, ((GLib.IO_IN | GLib.IO_HUP) | GLib.IO_ERR) - ) - self.mock.stop.assert_called_once_with(any_unicode) + asyncio.run(conn.queue_send(b"data")) + conn._loop.sock_sendall.assert_called_once_with(IsA(Mock), b"data") + assert conn.send_buffer == b"" def test_recv_callback_sends_data_to_actor(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.recv.return_value = b"data" - self.mock.actor_ref = Mock() + conn = self.create_connection() + conn._sock.recv.return_value = b"data" - assert network.Connection.recv_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.actor_ref.tell.assert_called_once_with({"received": b"data"}) + assert asyncio.run(conn.recv()) + conn.actor_ref.tell.assert_called_once_with({"received": b"data"}) def test_recv_callback_handles_dead_actors(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.recv.return_value = b"data" - self.mock.actor_ref = Mock() - self.mock.actor_ref.tell.side_effect = pykka.ActorDeadError() + conn = self.create_connection() + conn._sock.recv.return_value = b"data" + conn.actor_ref.tell.side_effect = pykka.ActorDeadError() - assert network.Connection.recv_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.stop.assert_called_once_with(any_unicode) + assert not asyncio.run(conn.recv()) + conn.actor_ref.stop.assert_called_once() def test_recv_callback_gets_no_data(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.recv.return_value = b"" - self.mock.actor_ref = Mock() + conn = self.create_connection() + conn._sock.recv.return_value = b"" - assert network.Connection.recv_callback(self.mock, sentinel.fd, GLib.IO_IN) - assert self.mock.mock_calls == [ - call._sock.recv(any_int), - call.disable_recv(), - call.actor_ref.tell({"close": True}), + assert not asyncio.run(conn.recv()) + assert conn.actor_ref.mock_calls == [ + ("tell", ({"close": True},), {}), ] def test_recv_callback_recoverable_error(self): - self.mock._sock = Mock(spec=socket.SocketType) + conn = self.create_connection() + conn._loop = Mock(spec=asyncio.AbstractEventLoop) for error in (errno.EWOULDBLOCK, errno.EINTR): - self.mock._sock.recv.side_effect = OSError(error, "") - assert network.Connection.recv_callback(self.mock, sentinel.fd, GLib.IO_IN) - assert self.mock.stop.call_count == 0 + conn._loop.sock_recv.side_effect = OSError(error, "") + assert asyncio.run(conn.recv()) + assert conn.actor_ref.stop.call_count == 0 def test_recv_callback_unrecoverable_error(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.recv.side_effect = socket.error + conn = self.create_connection() + conn._loop = Mock(spec=asyncio.AbstractEventLoop) + conn._loop.sock_recv.side_effect = socket.error - assert network.Connection.recv_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.stop.assert_called_once_with(any_unicode) - - def test_send_callback_respects_io_err(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 1 - self.mock.send_lock = Mock() - self.mock.actor_ref = Mock() - self.mock.send_buffer = b"" - - assert network.Connection.send_callback( - self.mock, sentinel.fd, (GLib.IO_IN | GLib.IO_ERR) - ) - self.mock.stop.assert_called_once_with(any_unicode) - - def test_send_callback_respects_io_hup(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 1 - self.mock.send_lock = Mock() - self.mock.actor_ref = Mock() - self.mock.send_buffer = b"" - - assert network.Connection.send_callback( - self.mock, sentinel.fd, (GLib.IO_IN | GLib.IO_HUP) - ) - self.mock.stop.assert_called_once_with(any_unicode) - - def test_send_callback_respects_io_hup_and_io_err(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 1 - self.mock.send_lock = Mock() - self.mock.actor_ref = Mock() - self.mock.send_buffer = b"" - - assert network.Connection.send_callback( - self.mock, sentinel.fd, ((GLib.IO_IN | GLib.IO_HUP) | GLib.IO_ERR) - ) - self.mock.stop.assert_called_once_with(any_unicode) - - def test_send_callback_acquires_and_releases_lock(self): - self.mock.send_lock = Mock() - self.mock.send_lock.acquire.return_value = True - self.mock.send_buffer = b"" - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 0 - - assert network.Connection.send_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.send_lock.acquire.assert_called_once_with(blocking=False) - self.mock.send_lock.release.assert_called_once_with() - - def test_send_callback_fails_to_acquire_lock(self): - self.mock.send_lock = Mock() - self.mock.send_lock.acquire.return_value = False - self.mock.send_buffer = b"" - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 0 - - assert network.Connection.send_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.send_lock.acquire.assert_called_once_with(blocking=False) - assert self.mock._sock.send.call_count == 0 + assert not asyncio.run(conn.recv()) + conn.actor_ref.stop.assert_called_once() def test_send_callback_sends_all_data(self): - self.mock.send_lock = Mock() - self.mock.send_lock.acquire.return_value = True - self.mock.send_buffer = b"data" - self.mock.send.return_value = b"" - - assert network.Connection.send_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.disable_send.assert_called_once_with() - self.mock.send.assert_called_once_with(b"data") - assert self.mock.send_buffer == b"" - - def test_send_callback_sends_partial_data(self): - self.mock.send_lock = Mock() - self.mock.send_lock.acquire.return_value = True - self.mock.send_buffer = b"data" - self.mock.send.return_value = b"ta" - - assert network.Connection.send_callback(self.mock, sentinel.fd, GLib.IO_IN) - self.mock.send.assert_called_once_with(b"data") - assert self.mock.send_buffer == b"ta" + conn = self.create_connection() + conn.send_buffer = b"data" + conn._loop = Mock(spec=asyncio.AbstractEventLoop) + conn._loop.sock_sendall.return_value = None + + asyncio.run(conn.send(conn.send_buffer)) + conn._loop.sock_sendall.assert_called_once_with(IsA(Mock), b"data") def test_send_recoverable_error(self): - self.mock._sock = Mock(spec=socket.SocketType) + conn = self.create_connection() + conn._loop = Mock(spec=asyncio.AbstractEventLoop) for error in (errno.EWOULDBLOCK, errno.EINTR): - self.mock._sock.send.side_effect = OSError(error, "") + conn._loop.sock_sendall.side_effect = OSError(error, "") - network.Connection.send(self.mock, b"data") + asyncio.run(conn.send(b"data")) assert self.mock.stop.call_count == 0 def test_send_calls_socket_send(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 4 - - assert network.Connection.send(self.mock, b"data") == b"" - self.mock._sock.send.assert_called_once_with(b"data") - - def test_send_calls_socket_send_partial_send(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.return_value = 2 + conn = self.create_connection() + conn._sock.send.return_value = 4 - assert network.Connection.send(self.mock, b"data") == b"ta" - self.mock._sock.send.assert_called_once_with(b"data") - - def test_send_unrecoverable_error(self): - self.mock._sock = Mock(spec=socket.SocketType) - self.mock._sock.send.side_effect = socket.error - - assert network.Connection.send(self.mock, b"data") == b"" - self.mock.stop.assert_called_once_with(any_unicode) + asyncio.run(conn.send(b"data")) + conn._sock.send.assert_called_once_with(b"data") def test_timeout_callback(self): self.mock.timeout = 10 diff --git a/tests/protocol/__init__.py b/tests/protocol/__init__.py index bd54101..07a476d 100644 --- a/tests/protocol/__init__.py +++ b/tests/protocol/__init__.py @@ -1,3 +1,4 @@ +import asyncio import unittest from typing import cast from unittest import mock @@ -53,12 +54,15 @@ def setUp(self): ).proxy(), ) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) self.connection = MockConnection() self.session = session.MpdSession( config=self.get_config(), core=self.core, uri_map=uri_mapper.MpdUriMapper(self.core), connection=self.connection, + loop=self.loop, ) self.dispatcher = self.session.dispatcher self.context = self.dispatcher.context @@ -76,20 +80,20 @@ def assertNoResponse(self): # noqa: N802 assert self.connection.response == [] def assertInResponse(self, value): # noqa: N802 - assert value in self.connection.response, ( - f"Did not find {value!r} in {self.connection.response!r}" - ) + assert ( + value in self.connection.response + ), f"Did not find {value!r} in {self.connection.response!r}" def assertOnceInResponse(self, value): # noqa: N802 matched = len([r for r in self.connection.response if r == value]) - assert matched == 1, ( - f"Expected to find {value!r} once in {self.connection.response!r}" - ) + assert ( + matched == 1 + ), f"Expected to find {value!r} once in {self.connection.response!r}" def assertNotInResponse(self, value): # noqa: N802 - assert value not in self.connection.response, ( - f"Found {value!r} in {self.connection.response!r}" - ) + assert ( + value not in self.connection.response + ), f"Found {value!r} in {self.connection.response!r}" def assertEqualResponse(self, value): # noqa: N802 assert len(self.connection.response) == 1 diff --git a/tests/test_session.py b/tests/test_session.py index 55512e9..b8632e0 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,8 +1,12 @@ +import asyncio import logging from unittest.mock import Mock, sentinel from mopidy_mpd import dispatcher, network, session +loop = asyncio.new_event_loop() +asyncio.set_event_loop(loop) + def test_on_start_logged(caplog): caplog.set_level(logging.INFO) @@ -13,6 +17,7 @@ def test_on_start_logged(caplog): core=None, uri_map=None, connection=connection, + loop=loop, ).on_start() assert f"New MPD connection from {connection}" in caplog.text @@ -26,6 +31,7 @@ def test_on_line_received_logged(caplog): core=None, uri_map=None, connection=connection, + loop=loop, ) mpd_session.dispatcher = Mock(spec=dispatcher.MpdDispatcher) mpd_session.dispatcher.handle_request.return_value = [str(sentinel.resp)]