Skip to content

Commit 0c00c00

Browse files
authored
Typing fixes for the config controller (#2570)
1 parent af320bf commit 0c00c00

File tree

2 files changed

+94
-55
lines changed

2 files changed

+94
-55
lines changed

music_assistant/controllers/config.py

Lines changed: 94 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import base64
66
import logging
77
import os
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, Any, cast
99
from uuid import uuid4
1010

1111
import aiofiles
@@ -69,6 +69,7 @@
6969
from music_assistant.helpers.json import JSON_DECODE_EXCEPTIONS, async_json_dumps, async_json_loads
7070
from music_assistant.helpers.util import load_provider_module
7171
from music_assistant.models import ProviderModuleType
72+
from music_assistant.models.music_provider import MusicProvider
7273

7374
if TYPE_CHECKING:
7475
import asyncio
@@ -117,7 +118,7 @@ async def setup(self) -> None:
117118
@property
118119
def onboard_done(self) -> bool:
119120
"""Return True if onboarding is done."""
120-
return self.get(CONF_ONBOARD_DONE, False)
121+
return bool(self.get(CONF_ONBOARD_DONE, False))
121122

122123
async def close(self) -> None:
123124
"""Handle logic on server stop."""
@@ -196,12 +197,12 @@ async def get_provider_configs(
196197
include_values: bool = False,
197198
) -> list[ProviderConfig]:
198199
"""Return all known provider configurations, optionally filtered by ProviderType."""
199-
raw_values: dict[str, dict] = self.get(CONF_PROVIDERS, {})
200+
raw_values = self.get(CONF_PROVIDERS, {})
200201
prov_entries = {x.domain for x in self.mass.get_provider_manifests()}
201202
return [
202203
await self.get_provider_config(prov_conf["instance_id"])
203204
if include_values
204-
else ProviderConfig.parse([], prov_conf)
205+
else cast("ProviderConfig", ProviderConfig.parse([], prov_conf))
205206
for prov_conf in raw_values.values()
206207
if (provider_type is None or prov_conf["type"] == provider_type)
207208
and (provider_domain is None or prov_conf["domain"] == provider_domain)
@@ -224,7 +225,7 @@ async def get_provider_config(self, instance_id: str) -> ProviderConfig:
224225
else:
225226
msg = f"Unknown provider domain: {raw_conf['domain']}"
226227
raise KeyError(msg)
227-
return ProviderConfig.parse(config_entries, raw_conf)
228+
return cast("ProviderConfig", ProviderConfig.parse(config_entries, raw_conf))
228229
msg = f"No config found for provider id {instance_id}"
229230
raise KeyError(msg)
230231

@@ -284,23 +285,29 @@ async def get_provider_config_entries( # noqa: PLR0915
284285
supported_features = provider.supported_features
285286
else:
286287
provider = None
287-
supported_features: set[ProviderFeature] = getattr(
288-
prov_mod, "SUPPORTED_FEATURES", set()
289-
)
288+
supported_features = getattr(prov_mod, "SUPPORTED_FEATURES", set())
290289
extra_entries: list[ConfigEntry] = []
291290
if manifest.type == ProviderType.MUSIC:
292291
# library sync settings
293292
if ProviderFeature.LIBRARY_ARTISTS in supported_features:
294293
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ARTISTS)
295294
if ProviderFeature.LIBRARY_ALBUMS in supported_features:
296295
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUMS)
297-
if provider and provider.is_streaming_provider:
296+
if (
297+
provider
298+
and isinstance(provider, MusicProvider)
299+
and provider.is_streaming_provider
300+
):
298301
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUM_TRACKS)
299302
if ProviderFeature.LIBRARY_TRACKS in supported_features:
300303
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_TRACKS)
301304
if ProviderFeature.LIBRARY_PLAYLISTS in supported_features:
302305
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLISTS)
303-
if provider and provider.is_streaming_provider:
306+
if (
307+
provider
308+
and isinstance(provider, MusicProvider)
309+
and provider.is_streaming_provider
310+
):
304311
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLIST_TRACKS)
305312
if ProviderFeature.LIBRARY_AUDIOBOOKS in supported_features:
306313
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_AUDIOBOOKS)
@@ -413,7 +420,7 @@ async def get_player_configs(
413420
return [
414421
await self.get_player_config(raw_conf["player_id"])
415422
if include_values
416-
else PlayerConfig.parse([], raw_conf)
423+
else cast("PlayerConfig", PlayerConfig.parse([], raw_conf))
417424
for raw_conf in list(self.get(CONF_PLAYERS, {}).values())
418425
# filter out unavailable providers (only if we requested the full info)
419426
if (
@@ -447,7 +454,7 @@ async def get_player_config(
447454
raw_conf["available"] = False
448455
raw_conf["name"] = raw_conf.get("name")
449456
raw_conf["default_name"] = raw_conf.get("default_name") or raw_conf["player_id"]
450-
return PlayerConfig.parse(conf_entries, raw_conf)
457+
return cast("PlayerConfig", PlayerConfig.parse(conf_entries, raw_conf))
451458
msg = f"No config found for player id {player_id}"
452459
raise KeyError(msg)
453460

@@ -480,7 +487,7 @@ async def get_player_config_value(
480487
player_id: str,
481488
key: str,
482489
unpack_splitted_values: bool = False,
483-
) -> ConfigValueType:
490+
) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]:
484491
"""Return single configentry value for a player."""
485492
conf = await self.get_player_config(player_id)
486493
if unpack_splitted_values:
@@ -499,9 +506,12 @@ def get_raw_player_config_value(
499506
500507
Note that this only returns the stored value without any validation or default.
501508
"""
502-
return self.get(
503-
f"{CONF_PLAYERS}/{player_id}/values/{key}",
504-
self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
509+
return cast(
510+
"ConfigValueType",
511+
self.get(
512+
f"{CONF_PLAYERS}/{player_id}/values/{key}",
513+
self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
514+
),
505515
)
506516

507517
def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig:
@@ -516,7 +526,7 @@ def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig:
516526
"player_id": player_id,
517527
"provider": provider,
518528
}
519-
return PlayerConfig.parse([], raw_conf)
529+
return cast("PlayerConfig", PlayerConfig.parse([], raw_conf))
520530

521531
@api_command("config/players/save")
522532
async def save_player_config(
@@ -527,7 +537,7 @@ async def save_player_config(
527537
changed_keys = config.update(values)
528538
if not changed_keys:
529539
# no changes
530-
return None
540+
return config
531541
# validate/handle the update in the player manager
532542
await self.mass.players.on_player_config_change(config, changed_keys)
533543
# actually store changes (if the above did not raise)
@@ -602,9 +612,15 @@ def get_player_dsp_config(self, player_id: str) -> DSPConfig:
602612
dsp_config.filters.append(
603613
ToneControlFilter(
604614
enabled=True,
605-
bass_level=deprecated_eq_bass,
606-
mid_level=deprecated_eq_mid,
607-
treble_level=deprecated_eq_treble,
615+
bass_level=float(deprecated_eq_bass)
616+
if isinstance(deprecated_eq_bass, (int, float, str))
617+
else 0.0,
618+
mid_level=float(deprecated_eq_mid)
619+
if isinstance(deprecated_eq_mid, (int, float, str))
620+
else 0.0,
621+
treble_level=float(deprecated_eq_treble)
622+
if isinstance(deprecated_eq_treble, (int, float, str))
623+
else 0.0,
608624
)
609625
)
610626

@@ -748,17 +764,20 @@ async def create_builtin_provider_config(self, provider_domain: str) -> None:
748764
instance_id = f"{manifest.domain}--{shortuuid.random(8)}"
749765
else:
750766
instance_id = manifest.domain
751-
default_config: ProviderConfig = ProviderConfig.parse(
752-
config_entries,
753-
{
754-
"type": manifest.type.value,
755-
"domain": manifest.domain,
756-
"instance_id": instance_id,
757-
"name": manifest.name,
758-
# note: this will only work for providers that do
759-
# not have any required config entries or provide defaults
760-
"values": {},
761-
},
767+
default_config = cast(
768+
"ProviderConfig",
769+
ProviderConfig.parse(
770+
config_entries,
771+
{
772+
"type": manifest.type.value,
773+
"domain": manifest.domain,
774+
"instance_id": instance_id,
775+
"name": manifest.name,
776+
# note: this will only work for providers that do
777+
# not have any required config entries or provide defaults
778+
"values": {},
779+
},
780+
),
762781
)
763782
default_config.validate()
764783
conf_key = f"{CONF_PROVIDERS}/{default_config.instance_id}"
@@ -770,9 +789,12 @@ async def get_core_configs(self, include_values: bool = False) -> list[CoreConfi
770789
return [
771790
await self.get_core_config(core_controller)
772791
if include_values
773-
else CoreConfig.parse(
774-
[],
775-
self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
792+
else cast(
793+
"CoreConfig",
794+
CoreConfig.parse(
795+
[],
796+
self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
797+
),
776798
)
777799
for core_controller in CONFIGURABLE_CORE_CONTROLLERS
778800
]
@@ -782,7 +804,7 @@ async def get_core_config(self, domain: str) -> CoreConfig:
782804
"""Return configuration for a single core controller."""
783805
raw_conf = self.get(f"{CONF_CORE}/{domain}", {"domain": domain})
784806
config_entries = await self.get_core_config_entries(domain)
785-
return CoreConfig.parse(config_entries, raw_conf)
807+
return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf))
786808

787809
@api_command("config/core/get_value")
788810
async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType:
@@ -848,9 +870,12 @@ def get_raw_core_config_value(
848870
849871
Note that this only returns the stored value without any validation or default.
850872
"""
851-
return self.get(
852-
f"{CONF_CORE}/{core_module}/values/{key}",
853-
self.get(f"{CONF_CORE}/{core_module}/{key}", default),
873+
return cast(
874+
"ConfigValueType",
875+
self.get(
876+
f"{CONF_CORE}/{core_module}/values/{key}",
877+
self.get(f"{CONF_CORE}/{core_module}/{key}", default),
878+
),
854879
)
855880

856881
def get_raw_provider_config_value(
@@ -861,9 +886,12 @@ def get_raw_provider_config_value(
861886
862887
Note that this only returns the stored value without any validation or default.
863888
"""
864-
return self.get(
865-
f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
866-
self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
889+
return cast(
890+
"ConfigValueType",
891+
self.get(
892+
f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
893+
self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
894+
),
867895
)
868896

869897
def set_raw_provider_config_value(
@@ -883,6 +911,9 @@ def set_raw_provider_config_value(
883911
msg = f"Invalid provider_instance: {provider_instance}"
884912
raise KeyError(msg)
885913
if encrypted:
914+
if not isinstance(value, str):
915+
msg = f"Cannot encrypt non-string value for key {key}"
916+
raise ValueError(msg)
886917
value = self.encrypt_string(value)
887918
if key in BASE_KEYS:
888919
self.set(f"{CONF_PROVIDERS}/{provider_instance}/{key}", value)
@@ -934,6 +965,7 @@ def encrypt_string(self, str_value: str) -> str:
934965
"""Encrypt a (password)string with Fernet."""
935966
if str_value.startswith(ENCRYPT_SUFFIX):
936967
return str_value
968+
assert self._fernet is not None
937969
return ENCRYPT_SUFFIX + self._fernet.encrypt(str_value.encode()).decode()
938970

939971
def decrypt_string(self, encrypted_str: str) -> str:
@@ -942,6 +974,7 @@ def decrypt_string(self, encrypted_str: str) -> str:
942974
return encrypted_str
943975
if not encrypted_str.startswith(ENCRYPT_SUFFIX):
944976
return encrypted_str
977+
assert self._fernet is not None
945978
try:
946979
return self._fernet.decrypt(encrypted_str.replace(ENCRYPT_SUFFIX, "").encode()).decode()
947980
except InvalidToken as err:
@@ -972,7 +1005,6 @@ async def _migrate(self) -> None: # noqa: PLR0915
9721005
instance_id: str
9731006
provider_config: dict[str, Any]
9741007
player_config: dict[str, Any]
975-
values: dict[str, ConfigValueType]
9761008

9771009
# Older versions of MA can create corrupt entries with no domain if retrying
9781010
# logic runs after a provider has been removed. Remove those corrupt entries.
@@ -1020,7 +1052,12 @@ async def _migrate(self) -> None: # noqa: PLR0915
10201052
# migrate player_group entries
10211053
ugp_found = False
10221054
for player_config in self._data.get(CONF_PLAYERS, {}).values():
1023-
if not player_config.get("provider").startswith("player_group"):
1055+
provider = player_config.get("provider")
1056+
if (
1057+
not provider
1058+
or not isinstance(provider, str)
1059+
or not provider.startswith("player_group")
1060+
):
10241061
continue
10251062
if not (values := player_config.get("values")):
10261063
continue
@@ -1144,7 +1181,7 @@ async def _add_provider_config(
11441181
self,
11451182
provider_domain: str,
11461183
values: dict[str, ConfigValueType],
1147-
) -> list[ConfigEntry] | ProviderConfig:
1184+
) -> ProviderConfig:
11481185
"""
11491186
Add new Provider (instance).
11501187
@@ -1181,15 +1218,18 @@ async def _add_provider_config(
11811218
config_entries = await self.get_provider_config_entries(
11821219
provider_domain=provider_domain, instance_id=instance_id, values=values
11831220
)
1184-
config: ProviderConfig = ProviderConfig.parse(
1185-
config_entries,
1186-
{
1187-
"type": manifest.type.value,
1188-
"domain": manifest.domain,
1189-
"instance_id": instance_id,
1190-
"default_name": manifest.name,
1191-
"values": values,
1192-
},
1221+
config = cast(
1222+
"ProviderConfig",
1223+
ProviderConfig.parse(
1224+
config_entries,
1225+
{
1226+
"type": manifest.type.value,
1227+
"domain": manifest.domain,
1228+
"instance_id": instance_id,
1229+
"default_name": manifest.name,
1230+
"values": values,
1231+
},
1232+
),
11931233
)
11941234
# validate the new config
11951235
config.validate()

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ enable_error_code = [
132132
]
133133
exclude = [
134134
'^music_assistant/controllers/cache.py$',
135-
'^music_assistant/controllers/config.py$',
136135
'^music_assistant/controllers/media/albums.py*$',
137136
'^music_assistant/controllers/media/artists.py*$',
138137
'^music_assistant/controllers/media/audiobooks.py*$',

0 commit comments

Comments
 (0)