Skip to content

Commit 2b6f6a3

Browse files
authored
Typing fixes for the Webserver controller (#2586)
* Fix merge conflicts * more mypy fixes * Drafting post upstream changes * fix run_handler * PR review comment * Merge conflict fixes * Fix WebserverController __init__ signature to match CoreController
1 parent 487a327 commit 2b6f6a3

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

music_assistant/controllers/webserver.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import os
1414
import urllib.parse
15+
from collections.abc import Awaitable, Callable
1516
from concurrent import futures
1617
from contextlib import suppress
1718
from functools import partial
@@ -45,11 +46,11 @@
4546
from music_assistant.models.core_controller import CoreController
4647

4748
if TYPE_CHECKING:
48-
from collections.abc import Awaitable
49-
5049
from music_assistant_models.config_entries import ConfigValueType, CoreConfig
5150
from music_assistant_models.event import MassEvent
5251

52+
from music_assistant import MusicAssistant
53+
5354
DEFAULT_SERVER_PORT = 8095
5455
INGRESS_SERVER_PORT = 8094
5556
CONF_BASE_URL = "base_url"
@@ -62,9 +63,9 @@ class WebserverController(CoreController):
6263

6364
domain: str = "webserver"
6465

65-
def __init__(self, *args, **kwargs) -> None:
66+
def __init__(self, mass: MusicAssistant) -> None:
6667
"""Initialize instance."""
67-
super().__init__(*args, **kwargs)
68+
super().__init__(mass)
6869
self._server = Webserver(self.logger, enable_dynamic_routes=True)
6970
self.register_dynamic_route = self._server.register_dynamic_route
7071
self.unregister_dynamic_route = self._server.unregister_dynamic_route
@@ -134,7 +135,7 @@ async def get_config_entries(
134135
async def setup(self, config: CoreConfig) -> None:
135136
"""Async initialize of module."""
136137
# work out all routes
137-
routes: list[tuple[str, str, Awaitable]] = []
138+
routes: list[tuple[str, str, Callable[[web.Request], Awaitable[web.StreamResponse]]]] = []
138139
# frontend routes
139140
frontend_dir = locate_frontend()
140141
for filename in next(os.walk(frontend_dir))[2]:
@@ -182,9 +183,12 @@ async def setup(self, config: CoreConfig) -> None:
182183
else:
183184
ingress_tcp_site_params = None
184185
base_url = str(config.get_value(CONF_BASE_URL))
185-
self.publish_port = int(config.get_value(CONF_BIND_PORT))
186+
port_value = config.get_value(CONF_BIND_PORT)
187+
assert isinstance(port_value, int)
188+
self.publish_port = port_value
186189
self.publish_ip = default_publish_ip
187190
bind_ip = config.get_value(CONF_BIND_IP)
191+
assert isinstance(bind_ip, str)
188192
# print a big fat message in the log where the webserver is running
189193
# because this is a common source of issues for people with more complex setups
190194
if not self.mass.config.onboard_done:
@@ -221,7 +225,7 @@ async def close(self) -> None:
221225
await client.disconnect()
222226
await self._server.close()
223227

224-
async def serve_preview_stream(self, request: web.Request):
228+
async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse:
225229
"""Serve short preview sample."""
226230
provider_instance_id_or_domain = request.query["provider"]
227231
item_id = urllib.parse.unquote(request.query["item_id"])
@@ -254,7 +258,7 @@ async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Respons
254258
try:
255259
command_msg = CommandMessage.from_json(cmd_data)
256260
except ValueError:
257-
error = f"Invalid JSON: {cmd_data}"
261+
error = f"Invalid JSON: {cmd_data.decode()}"
258262
self.logger.error("Unhandled JSONRPC API error: %s", error)
259263
return web.Response(status=400, text=error)
260264
except MissingField as e:
@@ -274,10 +278,11 @@ async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Respons
274278
error = f"Invalid Command: {command_msg.command}"
275279
self.logger.error("Unhandled JSONRPC API error: %s", error)
276280
return web.Response(status=400, text=error)
277-
278281
try:
279282
args = parse_arguments(handler.signature, handler.type_hints, command_msg.args)
280-
result = handler.target(**args)
283+
result: Any = handler.target(**args)
284+
if asyncio.iscoroutine(result):
285+
result = await result
281286
if hasattr(result, "__anext__"):
282287
# handle async generator (for really large listings)
283288
result = [item async for item in result]
@@ -330,7 +335,7 @@ async def _handle_schemas_reference(self, request: web.Request) -> web.Response:
330335
html = generate_schemas_reference(self.mass.command_handlers)
331336
return web.Response(text=html, content_type="text/html")
332337

333-
async def _handle_swagger_ui(self, request: web.Request) -> web.Response:
338+
async def _handle_swagger_ui(self, request: web.Request) -> web.FileResponse:
334339
"""Handle request for Swagger UI."""
335340
swagger_html_path = os.path.join(
336341
os.path.dirname(__file__), "..", "helpers", "resources", "swagger_ui.html"
@@ -346,9 +351,9 @@ def __init__(self, webserver: WebserverController, request: web.Request) -> None
346351
self.mass = webserver.mass
347352
self.request = request
348353
self.wsock = web.WebSocketResponse(heartbeat=55)
349-
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
350-
self._handle_task: asyncio.Task | None = None
351-
self._writer_task: asyncio.Task | None = None
354+
self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG)
355+
self._handle_task: asyncio.Task[Any] | None = None
356+
self._writer_task: asyncio.Task[None] | None = None
352357
self._logger = webserver.logger
353358
# try to dynamically detect the base_url of a client if proxied or behind Ingress
354359
self.base_url: str | None = None
@@ -461,18 +466,18 @@ async def _handle_command(self, msg: CommandMessage) -> None:
461466
async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
462467
try:
463468
args = parse_arguments(handler.signature, handler.type_hints, msg.args)
464-
result = handler.target(**args)
469+
result: Any = handler.target(**args)
465470
if hasattr(result, "__anext__"):
466471
# handle async generator (for really large listings)
467-
iterator = result
468-
result: list[Any] = []
469-
async for item in iterator:
470-
result.append(item)
471-
if len(result) >= 500:
472+
items: list[Any] = []
473+
async for item in result:
474+
items.append(item)
475+
if len(items) >= 500:
472476
await self._send_message(
473-
SuccessResultMessage(msg.message_id, result, partial=True)
477+
SuccessResultMessage(msg.message_id, items, partial=True)
474478
)
475-
result = []
479+
items = []
480+
result = items
476481
elif asyncio.iscoroutine(result):
477482
result = await result
478483
await self._send_message(SuccessResultMessage(msg.message_id, result))
@@ -493,8 +498,10 @@ async def _writer(self) -> None:
493498
while not self.wsock.closed:
494499
if (process := await self._to_write.get()) is None:
495500
break
501+
self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", process)
502+
await self.wsock.send_str(process)
496503

497-
if not isinstance(process, str):
504+
if callable(process):
498505
message: str = process()
499506
else:
500507
message = process

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ exclude = [
143143
'^music_assistant/controllers/music.py$',
144144
'^music_assistant/controllers/player_queues.py$',
145145
'^music_assistant/controllers/streams.py$',
146-
'^music_assistant/controllers/webserver.py',
147146
'^music_assistant/helpers/app_vars.py',
148147
'^music_assistant/providers/apple_music/.*$',
149148
'^music_assistant/providers/bluesound/.*$',

0 commit comments

Comments
 (0)