1212import logging
1313import os
1414import urllib .parse
15+ from collections .abc import Awaitable , Callable
1516from concurrent import futures
1617from contextlib import suppress
1718from functools import partial
4546from music_assistant .models .core_controller import CoreController
4647
4748if TYPE_CHECKING :
48- from collections .abc import Awaitable
49-
5049 from music_assistant_models .config_entries import ConfigValueType , CoreConfig
5150 from music_assistant_models .event import MassEvent
5251
52+ from music_assistant import MusicAssistant
53+
5354DEFAULT_SERVER_PORT = 8095
5455INGRESS_SERVER_PORT = 8094
5556CONF_BASE_URL = "base_url"
@@ -62,9 +63,9 @@ class WebserverController(CoreController):
6263
6364 domain : str = "webserver"
6465
65- def __init__ (self , * args , ** kwargs ) -> None :
66+ def __init__ (self , mass : MusicAssistant ) -> None :
6667 """Initialize instance."""
67- super ().__init__ (* args , ** kwargs )
68+ super ().__init__ (mass )
6869 self ._server = Webserver (self .logger , enable_dynamic_routes = True )
6970 self .register_dynamic_route = self ._server .register_dynamic_route
7071 self .unregister_dynamic_route = self ._server .unregister_dynamic_route
@@ -134,7 +135,7 @@ async def get_config_entries(
134135 async def setup (self , config : CoreConfig ) -> None :
135136 """Async initialize of module."""
136137 # work out all routes
137- routes : list [tuple [str , str , Awaitable ]] = []
138+ routes : list [tuple [str , str , Callable [[ web . Request ], Awaitable [ web . StreamResponse ]] ]] = []
138139 # frontend routes
139140 frontend_dir = locate_frontend ()
140141 for filename in next (os .walk (frontend_dir ))[2 ]:
@@ -182,9 +183,12 @@ async def setup(self, config: CoreConfig) -> None:
182183 else :
183184 ingress_tcp_site_params = None
184185 base_url = str (config .get_value (CONF_BASE_URL ))
185- self .publish_port = int (config .get_value (CONF_BIND_PORT ))
186+ port_value = config .get_value (CONF_BIND_PORT )
187+ assert isinstance (port_value , int )
188+ self .publish_port = port_value
186189 self .publish_ip = default_publish_ip
187190 bind_ip = config .get_value (CONF_BIND_IP )
191+ assert isinstance (bind_ip , str )
188192 # print a big fat message in the log where the webserver is running
189193 # because this is a common source of issues for people with more complex setups
190194 if not self .mass .config .onboard_done :
@@ -221,7 +225,7 @@ async def close(self) -> None:
221225 await client .disconnect ()
222226 await self ._server .close ()
223227
224- async def serve_preview_stream (self , request : web .Request ):
228+ async def serve_preview_stream (self , request : web .Request ) -> web . StreamResponse :
225229 """Serve short preview sample."""
226230 provider_instance_id_or_domain = request .query ["provider" ]
227231 item_id = urllib .parse .unquote (request .query ["item_id" ])
@@ -254,7 +258,7 @@ async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Respons
254258 try :
255259 command_msg = CommandMessage .from_json (cmd_data )
256260 except ValueError :
257- error = f"Invalid JSON: { cmd_data } "
261+ error = f"Invalid JSON: { cmd_data . decode () } "
258262 self .logger .error ("Unhandled JSONRPC API error: %s" , error )
259263 return web .Response (status = 400 , text = error )
260264 except MissingField as e :
@@ -274,10 +278,11 @@ async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Respons
274278 error = f"Invalid Command: { command_msg .command } "
275279 self .logger .error ("Unhandled JSONRPC API error: %s" , error )
276280 return web .Response (status = 400 , text = error )
277-
278281 try :
279282 args = parse_arguments (handler .signature , handler .type_hints , command_msg .args )
280- result = handler .target (** args )
283+ result : Any = handler .target (** args )
284+ if asyncio .iscoroutine (result ):
285+ result = await result
281286 if hasattr (result , "__anext__" ):
282287 # handle async generator (for really large listings)
283288 result = [item async for item in result ]
@@ -330,7 +335,7 @@ async def _handle_schemas_reference(self, request: web.Request) -> web.Response:
330335 html = generate_schemas_reference (self .mass .command_handlers )
331336 return web .Response (text = html , content_type = "text/html" )
332337
333- async def _handle_swagger_ui (self , request : web .Request ) -> web .Response :
338+ async def _handle_swagger_ui (self , request : web .Request ) -> web .FileResponse :
334339 """Handle request for Swagger UI."""
335340 swagger_html_path = os .path .join (
336341 os .path .dirname (__file__ ), ".." , "helpers" , "resources" , "swagger_ui.html"
@@ -346,9 +351,9 @@ def __init__(self, webserver: WebserverController, request: web.Request) -> None
346351 self .mass = webserver .mass
347352 self .request = request
348353 self .wsock = web .WebSocketResponse (heartbeat = 55 )
349- self ._to_write : asyncio .Queue = asyncio .Queue (maxsize = MAX_PENDING_MSG )
350- self ._handle_task : asyncio .Task | None = None
351- self ._writer_task : asyncio .Task | None = None
354+ self ._to_write : asyncio .Queue [ str | None ] = asyncio .Queue (maxsize = MAX_PENDING_MSG )
355+ self ._handle_task : asyncio .Task [ Any ] | None = None
356+ self ._writer_task : asyncio .Task [ None ] | None = None
352357 self ._logger = webserver .logger
353358 # try to dynamically detect the base_url of a client if proxied or behind Ingress
354359 self .base_url : str | None = None
@@ -461,18 +466,18 @@ async def _handle_command(self, msg: CommandMessage) -> None:
461466 async def _run_handler (self , handler : APICommandHandler , msg : CommandMessage ) -> None :
462467 try :
463468 args = parse_arguments (handler .signature , handler .type_hints , msg .args )
464- result = handler .target (** args )
469+ result : Any = handler .target (** args )
465470 if hasattr (result , "__anext__" ):
466471 # handle async generator (for really large listings)
467- iterator = result
468- result : list [Any ] = []
469- async for item in iterator :
470- result .append (item )
471- if len (result ) >= 500 :
472+ items : list [Any ] = []
473+ async for item in result :
474+ items .append (item )
475+ if len (items ) >= 500 :
472476 await self ._send_message (
473- SuccessResultMessage (msg .message_id , result , partial = True )
477+ SuccessResultMessage (msg .message_id , items , partial = True )
474478 )
475- result = []
479+ items = []
480+ result = items
476481 elif asyncio .iscoroutine (result ):
477482 result = await result
478483 await self ._send_message (SuccessResultMessage (msg .message_id , result ))
@@ -493,8 +498,10 @@ async def _writer(self) -> None:
493498 while not self .wsock .closed :
494499 if (process := await self ._to_write .get ()) is None :
495500 break
501+ self ._logger .log (VERBOSE_LOG_LEVEL , "Writing: %s" , process )
502+ await self .wsock .send_str (process )
496503
497- if not isinstance (process , str ):
504+ if callable (process ):
498505 message : str = process ()
499506 else :
500507 message = process
0 commit comments