Skip to content

Commit 49ae2be

Browse files
committed
SDK-2983: Split entry methods from feature detection
1 parent c9e1907 commit 49ae2be

File tree

3 files changed

+93
-78
lines changed

3 files changed

+93
-78
lines changed

Diff for: src/galaxy/api/plugin.py

+59-73
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
import asyncio
2+
import dataclasses
23
import json
34
import logging
45
import logging.handlers
5-
import dataclasses
6-
from enum import Enum
6+
import sys
77
from collections import OrderedDict
8+
from enum import Enum
89
from itertools import count
9-
import sys
10+
from typing import Any, Dict, List, Optional, Set, Union
1011

11-
from typing import Any, List, Dict, Optional, Union
12-
13-
from galaxy.api.types import Achievement, Game, LocalGame, FriendInfo, GameTime
14-
15-
from galaxy.api.jsonrpc import Server, NotificationClient, ApplicationError
1612
from galaxy.api.consts import Feature
17-
from galaxy.api.errors import UnknownError, ImportInProgress
18-
from galaxy.api.types import Authentication, NextStep
13+
from galaxy.api.errors import ImportInProgress, UnknownError
14+
from galaxy.api.jsonrpc import ApplicationError, NotificationClient, Server
15+
from galaxy.api.types import Achievement, Authentication, FriendInfo, Game, GameTime, LocalGame, NextStep
1916

2017

2118
class JSONEncoder(json.JSONEncoder):
@@ -24,6 +21,7 @@ def default(self, o): # pylint: disable=method-hidden
2421
# filter None values
2522
def dict_factory(elements):
2623
return {k: v for k, v in elements if v is not None}
24+
2725
return dataclasses.asdict(o, dict_factory=dict_factory)
2826
if isinstance(o, Enum):
2927
return o.value
@@ -32,12 +30,13 @@ def dict_factory(elements):
3230

3331
class Plugin:
3432
"""Use and override methods of this class to create a new platform integration."""
33+
3534
def __init__(self, platform, version, reader, writer, handshake_token):
3635
logging.info("Creating plugin for platform %s, version %s", platform.value, version)
3736
self._platform = platform
3837
self._version = version
3938

40-
self._feature_methods = OrderedDict()
39+
self._features: Set[Feature] = set()
4140
self._active = True
4241
self._pass_control_task = None
4342

@@ -50,6 +49,7 @@ def __init__(self, platform, version, reader, writer, handshake_token):
5049

5150
def eof_handler():
5251
self._shutdown()
52+
5353
self._server.register_eof(eof_handler)
5454

5555
self._achievements_import_in_progress = False
@@ -85,77 +85,65 @@ def eof_handler():
8585
self._register_method(
8686
"import_owned_games",
8787
self.get_owned_games,
88-
result_name="owned_games",
89-
feature=Feature.ImportOwnedGames
88+
result_name="owned_games"
9089
)
90+
self._detect_feature(Feature.ImportOwnedGames, ["get_owned_games"])
91+
9192
self._register_method(
9293
"import_unlocked_achievements",
9394
self.get_unlocked_achievements,
94-
result_name="unlocked_achievements",
95-
feature=Feature.ImportAchievements
96-
)
97-
self._register_method(
98-
"start_achievements_import",
99-
self.start_achievements_import,
100-
)
101-
self._register_method(
102-
"import_local_games",
103-
self.get_local_games,
104-
result_name="local_games",
105-
feature=Feature.ImportInstalledGames
106-
)
107-
self._register_notification("launch_game", self.launch_game, feature=Feature.LaunchGame)
108-
self._register_notification("install_game", self.install_game, feature=Feature.InstallGame)
109-
self._register_notification(
110-
"uninstall_game",
111-
self.uninstall_game,
112-
feature=Feature.UninstallGame
113-
)
114-
self._register_notification(
115-
"shutdown_platform_client",
116-
self.shutdown_platform_client,
117-
feature=Feature.ShutdownPlatformClient
118-
)
119-
self._register_method(
120-
"import_friends",
121-
self.get_friends,
122-
result_name="friend_info_list",
123-
feature=Feature.ImportFriends
124-
)
125-
self._register_method(
126-
"import_game_times",
127-
self.get_game_times,
128-
result_name="game_times",
129-
feature=Feature.ImportGameTime
130-
)
131-
self._register_method(
132-
"start_game_times_import",
133-
self.start_game_times_import,
95+
result_name="unlocked_achievements"
13496
)
97+
self._detect_feature(Feature.ImportAchievements, ["get_unlocked_achievements"])
13598

136-
@property
137-
def features(self):
138-
features = []
139-
if self.__class__ != Plugin:
140-
for feature, handlers in self._feature_methods.items():
141-
if self._implements(handlers):
142-
features.append(feature)
99+
self._register_method("start_achievements_import", self.start_achievements_import)
100+
self._detect_feature(Feature.ImportAchievements, ["import_games_achievements"])
101+
102+
self._register_method("import_local_games", self.get_local_games, result_name="local_games")
103+
self._detect_feature(Feature.ImportInstalledGames, ["get_local_games"])
104+
105+
self._register_notification("launch_game", self.launch_game)
106+
self._detect_feature(Feature.LaunchGame, ["launch_game"])
107+
108+
self._register_notification("install_game", self.install_game)
109+
self._detect_feature(Feature.InstallGame, ["install_game"])
110+
111+
self._register_notification("uninstall_game", self.uninstall_game)
112+
self._detect_feature(Feature.UninstallGame, ["uninstall_game"])
113+
114+
self._register_notification("shutdown_platform_client", self.shutdown_platform_client)
115+
self._detect_feature(Feature.ShutdownPlatformClient, ["shutdown_platform_client"])
116+
117+
self._register_method("import_friends", self.get_friends, result_name="friend_info_list")
118+
self._detect_feature(Feature.ImportFriends, ["get_friends"])
143119

144-
return features
120+
self._register_method("import_game_times", self.get_game_times, result_name="game_times")
121+
self._detect_feature(Feature.ImportGameTime, ["get_game_times"])
122+
123+
self._register_method("start_game_times_import", self.start_game_times_import)
124+
self._detect_feature(Feature.ImportGameTime, ["import_game_times"])
125+
126+
@property
127+
def features(self) -> List[Feature]:
128+
return list(self._features)
145129

146130
@property
147131
def persistent_cache(self) -> Dict:
148132
"""The cache is only available after the :meth:`~.handshake_complete()` is called.
149133
"""
150134
return self._persistent_cache
151135

152-
def _implements(self, handlers):
153-
for handler in handlers:
154-
if handler.__name__ not in self.__class__.__dict__:
136+
def _implements(self, methods: List[str]) -> bool:
137+
for method in methods:
138+
if method not in self.__class__.__dict__:
155139
return False
156140
return True
157141

158-
def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False, feature=None):
142+
def _detect_feature(self, feature: Feature, methods: List[str]):
143+
if self._implements(methods):
144+
self._features.add(feature)
145+
146+
def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False):
159147
if internal:
160148
def method(*args, **kwargs):
161149
result = handler(*args, **kwargs)
@@ -164,6 +152,7 @@ def method(*args, **kwargs):
164152
result_name: result
165153
}
166154
return result
155+
167156
self._server.register_method(name, method, True, sensitive_params)
168157
else:
169158
async def method(*args, **kwargs):
@@ -173,17 +162,12 @@ async def method(*args, **kwargs):
173162
result_name: result
174163
}
175164
return result
176-
self._server.register_method(name, method, False, sensitive_params)
177165

178-
if feature is not None:
179-
self._feature_methods.setdefault(feature, []).append(handler)
166+
self._server.register_method(name, method, False, sensitive_params)
180167

181-
def _register_notification(self, name, handler, internal=False, sensitive_params=False, feature=None):
168+
def _register_notification(self, name, handler, internal=False, sensitive_params=False):
182169
self._server.register_notification(name, handler, internal, sensitive_params)
183170

184-
if feature is not None:
185-
self._feature_methods.setdefault(feature, []).append(handler)
186-
187171
async def run(self):
188172
"""Plugin's main coroutine."""
189173
await self._server.run()
@@ -192,6 +176,7 @@ async def run(self):
192176

193177
def create_task(self, coro, description):
194178
"""Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown"""
179+
195180
async def task_wrapper(task_id):
196181
try:
197182
return await coro
@@ -524,7 +509,7 @@ async def authenticate(self, stored_credentials=None):
524509
raise NotImplementedError()
525510

526511
async def pass_login_credentials(self, step: str, credentials: Dict[str, str], cookies: List[Dict[str, str]]) \
527-
-> Union[NextStep, Authentication]:
512+
-> Union[NextStep, Authentication]:
528513
"""This method is called if we return galaxy.api.types.NextStep from authenticate or from pass_login_credentials.
529514
This method's parameters provide the data extracted from the web page navigation that previous NextStep finished on.
530515
This method should either return galaxy.api.types.Authentication if the authentication is finished
@@ -607,6 +592,7 @@ async def import_games_achievements(self, game_ids: List[str]) -> None:
607592
608593
:param game_ids: ids of the games for which to import unlocked achievements
609594
"""
595+
610596
async def import_game_achievements(game_id):
611597
try:
612598
achievements = await self.get_unlocked_achievements(game_id)

Diff for: tests/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def plugin(reader, writer):
5858
stack.enter_context(patch.object(Plugin, method))
5959
yield Plugin(Platform.Generic, "0.1", reader, writer, "token")
6060

61+
6162
@pytest.fixture(autouse=True)
6263
def my_caplog(caplog):
6364
caplog.set_level(logging.DEBUG)

Diff for: tests/test_features.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,49 @@
1+
from galaxy.api.consts import Feature, Platform
12
from galaxy.api.plugin import Plugin
2-
from galaxy.api.consts import Platform, Feature
3+
34

45
def test_base_class():
56
plugin = Plugin(Platform.Generic, "0.1", None, None, None)
6-
assert plugin.features == []
7+
assert set(plugin.features) == {
8+
Feature.ImportInstalledGames,
9+
Feature.ImportOwnedGames,
10+
Feature.LaunchGame,
11+
Feature.InstallGame,
12+
Feature.UninstallGame,
13+
Feature.ImportAchievements,
14+
Feature.ImportGameTime,
15+
Feature.ImportFriends,
16+
Feature.ShutdownPlatformClient
17+
}
18+
719

820
def test_no_overloads():
9-
class PluginImpl(Plugin): #pylint: disable=abstract-method
21+
class PluginImpl(Plugin): # pylint: disable=abstract-method
1022
pass
1123

1224
plugin = PluginImpl(Platform.Generic, "0.1", None, None, None)
1325
assert plugin.features == []
1426

27+
1528
def test_one_method_feature():
16-
class PluginImpl(Plugin): #pylint: disable=abstract-method
29+
class PluginImpl(Plugin): # pylint: disable=abstract-method
30+
async def get_owned_games(self):
31+
pass
32+
33+
plugin = PluginImpl(Platform.Generic, "0.1", None, None, None)
34+
assert plugin.features == [Feature.ImportOwnedGames]
35+
36+
37+
def test_multi_features():
38+
class PluginImpl(Plugin): # pylint: disable=abstract-method
1739
async def get_owned_games(self):
1840
pass
1941

42+
async def import_games_achievements(self, game_ids) -> None:
43+
pass
44+
45+
async def start_game_times_import(self, game_ids) -> None:
46+
pass
47+
2048
plugin = PluginImpl(Platform.Generic, "0.1", None, None, None)
21-
assert plugin.features == [Feature.ImportOwnedGames]
49+
assert set(plugin.features) == {Feature.ImportAchievements, Feature.ImportOwnedGames}

0 commit comments

Comments
 (0)