diff --git a/clamor/gateway/__init__.py b/clamor/gateway/__init__.py index 4f525a7..7fa7b70 100644 --- a/clamor/gateway/__init__.py +++ b/clamor/gateway/__init__.py @@ -1,4 +1,3 @@ from .connector import * from .exceptions import * from .opcodes import * -from .emitter import * diff --git a/clamor/gateway/connector.py b/clamor/gateway/connector.py index 8e68ab6..126ced0 100644 --- a/clamor/gateway/connector.py +++ b/clamor/gateway/connector.py @@ -11,7 +11,7 @@ import anyio from .encoding import ENCODERS -from .emitter import Emitter +from clamor.utils import Emitter from .opcodes import opcodes from .exceptions import * @@ -65,9 +65,14 @@ class DiscordWebsocketClient: def __init__(self, url: str, **kwargs): self.url = url self.encoder = ENCODERS[kwargs.get('encoding', 'json')] - self.zlib_compressed = kwargs.get('zlib_compressed', False) + self.zlib_compressed = kwargs.get('zlib_compressed', True) self.emitter = Emitter() + # Compression + if self.zlib_compressed: + self.buffer = bytearray() + self.inflator = zlib.decompressobj() + # Websocket connection self._con = None self._running = False @@ -86,8 +91,9 @@ def __init__(self, url: str, **kwargs): self._session_id = 0 self._token = "" - self.emitter.add_listener('HELLO', self._on_hello) - self.emitter.add_listener('HEARTBEAT_ACK', self._on_heartbeat_ack) + self.emitter.add_listener(opcodes['HELLO'], self._on_hello) + self.emitter.add_listener(opcodes['HEARTBEAT_ACK'], self._on_heartbeat_ack) + self.emitter.add_listener("READY", self._on_ready) self.format_url() @@ -99,22 +105,23 @@ def format_url(self): async def _receive(self): message = await self._con.get_message() logger.debug("Received message '{}'".format(message)) + if self.zlib_compressed: - # handle zlib compression here - pass - else: - # As there are special cases where zlib-compressed payloads also occur, even - # if zlib-stream wasn't specified in the Gateway url, also try to detect them. - is_json = message[0] == '{' - is_etf = message[0] == 131 - if not is_json and not is_etf: - message = zlib.decompress(message, 15, self.TEN_MEGABYTES).decode('utf-8') + self.buffer.extend(message) + if self.buffer.endswith(self.ZLIB_SUFFIX): + message = self.inflator.decompress(self.buffer).decode() + self.buffer.clear() + elif message[0] != '{' and message[0] != 131: + message = zlib.decompress(message, 15, self.TEN_MEGABYTES).decode('utf-8') + try: message = self.encoder.decode(message) except Exception as e: raise EncodingError(str(e)) + if message.get('s'): self._last_sequence = message['s'] + logger.debug("Decoded message to '{}'".format(message)) return message @@ -130,16 +137,15 @@ async def _send(self, opcode: Union[int, str], data): await self._con.send(json.dumps(payload)) async def _on_hello(self, data): - self._interval = int(data["heartbeat_interval"]) + self._interval = data["heartbeat_interval"] logger.debug("Found heartbeat interval: {}".format(self._interval)) await self._tg.spawn(self._heartbeat_task) async def _on_heartbeat_ack(self, data): self._has_ack = True - async def _on_dispatch(self, data, event): - if event == "ready": - self._session_id = int(data["session_id"]) + async def _on_ready(self, data): + self._session_id = data["session_id"] async def _heartbeat(self): """|coro| @@ -164,7 +170,10 @@ async def _receive_task(self): """ while self._running: message = await self._receive() - await self.emitter.emit(message['op'], message['d'], message.get('t')) + if message['op'] == 0: + await self.emitter.emit(message['t'], message['d']) + else: + await self.emitter.emit(message['op'], message['d']) async def _heartbeat_task(self): while self._running: @@ -219,7 +228,10 @@ async def connect(self): await self.on_open() async def resume(self): - await self.close() + if self._running: + await self.close() + + logger.info("Resuming") async with anysocks.open_connection(self.url) as con: self._con = con self._running = True @@ -229,7 +241,6 @@ async def resume(self): 'seq': self._last_sequence } await self._send('RESUME', payload) - logger.info("Resuming") await self.on_open() async def start(self, token: str): diff --git a/clamor/gateway/emitter.py b/clamor/gateway/emitter.py deleted file mode 100644 index 91d4781..0000000 --- a/clamor/gateway/emitter.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Callable, Coroutine, Union, Any - -from .exceptions import InvalidListener -from .opcodes import opcodes - - -class Emitter: - def __init__(self): - self.reg = list() - for i in range(len(opcodes)): - self.reg.append(None) - - def add_listener(self, op: Union[int, str], listener: Callable[..., Coroutine[Any, Any, None]]): - if isinstance(op, str): - op = opcodes[op] - # Todo: check if listener is valid (coroutine) - self.reg[op] = listener - - async def emit(self, op: Union[int, str], data, event: str = None): - if isinstance(op, str): - op = opcodes[op] - if self.reg[op]: - if op == 0 and event: - await self.reg[op](data, event) - else: - await self.reg[op](data) diff --git a/clamor/utils/__init__.py b/clamor/utils/__init__.py new file mode 100644 index 0000000..8e31281 --- /dev/null +++ b/clamor/utils/__init__.py @@ -0,0 +1 @@ +from .emitter import * diff --git a/clamor/utils/emitter.py b/clamor/utils/emitter.py new file mode 100644 index 0000000..dfb411f --- /dev/null +++ b/clamor/utils/emitter.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- + +from collections import defaultdict +from enum import Enum +from inspect import iscoroutinefunction +from functools import wraps +from typing import Callable, Coroutine, Union, Any, Dict + +from clamor.gateway.exceptions import InvalidListener + +from anyio import create_task_group + + +def check_coroutine(func): + @wraps(func) + def wrapper(self, listener: Callable[..., Coroutine[Any, Any, None]]): + if not iscoroutinefunction(listener): + raise InvalidListener("Listener must be a coroutine") + return func(self, listener) + + return wrapper + + +class Priority(Enum): + BEFORE = 0 + NORMAL = 1 + AFTER = 2 + + +class ListenerPod: + """Event listener module + + Listeners that all follow a certain event will exist in the same pod. Pods will separate + listeners into self-explanatory categories, before, normal, and after. The listeners will + trigger in order from before to after, with each listener triggering synchronously with + listeners from the same category. + + Attributes + ---------- + before : set + Listeners that will trigger 1st + normal : set + Listeners that will trigger 2nd + after : set + Listeners that will trigger 3rd + """ + + def __init__(self): + self.before = set() + self.normal = set() + self.after = set() + + def __bool__(self): + return bool(self.before) or bool(self.normal) or bool(self.after) + + @check_coroutine + def add_before(self, listener: Callable[..., Coroutine[Any, Any, None]]): + """Add listener (Before) + + Add a coroutine to the before category + + Parameters + ---------- + listener : Coroutine + A coroutine to be triggered on it's respective event + """ + self.before.add(listener) + + @check_coroutine + def add_normal(self, listener: Callable[..., Coroutine[Any, Any, None]]): + """Add listener (Normal) + + Add a coroutine to the normal category + + Parameters + ---------- + listener : Coroutine + A coroutine to be triggered on it's respective event + """ + self.normal.add(listener) + + @check_coroutine + def add_after(self, listener: Callable[..., Coroutine[Any, Any, None]]): + """Add listener (After) + + Add a coroutine to the after category + + Parameters + ---------- + listener : Coroutine + A coroutine to be triggered on it's respective event + """ + self.after.add(listener) + + async def emit(self, data: Dict[str, Any]): + """Trigger listeners in the pod + + All listeners in the before category will be spawned with the appropriate payload, and + once all those have finished, the normal category is triggered, and then the after category. + + Parameters + ---------- + data : dict + The payload provided by discord to be distributed + """ + async with create_task_group() as tg: + for listener in self.before: + await tg.spawn(listener, data) + async with create_task_group() as tg: + for listener in self.normal: + await tg.spawn(listener, data) + async with create_task_group() as tg: + for listener in self.after: + await tg.spawn(listener, data) + + +class Emitter: + """Main event emitter + + This is what orchestrates all the event pods, adds listeners, removes them and triggers events. + Events can be either an op code, or a name (for opcode 0). + + Attributes + ---------- + listeners : defaultdict(:class `clamor.gateway.emitter.ListenerPod`:) + A default dict that holders event namess to listener pods. + """ + + def __init__(self): + self.listeners = defaultdict(ListenerPod) + + def add_listener(self, event: Union[int, str], + listener: Callable[..., Coroutine[Any, Any, None]], + order: Priority = Priority.NORMAL): + """Add a listener + + Add a listener to the correct pod and category, which by default is the normal priority. + + Parameters + ---------- + event : int or str + The op code or event to listen too + listener : Coroutine + A coroutine to be triggered on it's respective event + order : :class `clamor.gateway.emitter.Priority` + The order this listener should be triggered in + """ + + # Create a pod if one does not exist for the event, then add the listener + # using the respective method, based on the priority. + getattr(self.listeners[event], "add_" + order.name.lower())(listener) + + async def emit(self, event: Union[int, str], data): + """Emit an event + + Trigger the corresponding ListenerPod if one exists. + + Parameters + ---------- + event: int or str + The op code or event to listen too + data : dict + The payload provided by discord to be distributed + """ + if self.listeners[event]: + await self.listeners[event].emit(data) + + def clear_event(self, event: Union[str, int]): + """Clear all listeners + + Removes all listeners, to matter the category, from the provided event. + + Parameters + ---------- + event : str or int + The op code or event to remove + """ + self.listeners.pop(event) + + def remove_listener(self, event: Union[str, int], + listener: Callable[..., Coroutine[Any, Any, None]]): + """Remove a specific listener from an event + + Removes a the provided listener from an event, no matter the category. + + Parameters + ---------- + event : int or str + The op code or event to search + listener : Coroutine + Listener to remove + """ + if self.listeners[event]: + if listener in self.listeners[event].before: + self.listeners[event].before.remove(listener) + if listener in self.listeners[event].normal: + self.listeners[event].normal.remove(listener) + if listener in self.listeners[event].after: + self.listeners[event].after.remove(listener) diff --git a/tests/test_emitter.py b/tests/test_emitter.py new file mode 100644 index 0000000..9b4ab0b --- /dev/null +++ b/tests/test_emitter.py @@ -0,0 +1,54 @@ +import unittest + +from clamor import Emitter, Priority + +from anyio import run + + +class TestEmitter(unittest.TestCase): + + def test_main_functionality(self): + async def main(): + emitter = Emitter() + goal = [] + + async def early(_): + goal.append(1) + + async def timely(_): + goal.append(2) + + async def late(_): + goal.append(3) + + emitter.add_listener("test", early, Priority.BEFORE) + emitter.add_listener("test", timely) + emitter.add_listener("test", late, Priority.AFTER) + for _ in range(20): # Make sure it wasn't an accident + await emitter.emit("test", {}) + self.assertEqual(goal, [1, 2, 3]) + goal.clear() + run(main) + + def test_removal(self): + async def main(): + emitter = Emitter() + goal = [] + + async def early(_): + goal.append(1) + + async def timely(_): + goal.append(2) + + emitter.add_listener("test", early, Priority.BEFORE) + emitter.add_listener("test", early, Priority.BEFORE) + emitter.add_listener("test", timely) + await emitter.emit("test", {}) + self.assertEqual(len(goal), 3) + goal.clear() + emitter.remove_listener("test", early) + await emitter.emit("test", {}) + self.assertEqual(len(goal), 1) + emitter.clear_event("test") + self.assertFalse(emitter.listeners['test']) diff --git a/tests/test_gateway_connecting.py b/tests/test_gateway_connecting.py index 12dcd1b..8fa14ac 100644 --- a/tests/test_gateway_connecting.py +++ b/tests/test_gateway_connecting.py @@ -9,14 +9,21 @@ class GatewayTests(unittest.TestCase): - def test_gateway_connect(self): + def gateway_connect(self, compressed: bool): async def main(): http = HTTP(os.environ['TEST_BOT_TOKEN']) url = await http.make_request(Routes.GET_GATEWAY) - gw = gateway.DiscordWebsocketClient(url['url']) + gw = gateway.DiscordWebsocketClient(url['url'], zlib_compressed=compressed) + connected = False self.assertIsInstance(gw, gateway.DiscordWebsocketClient) + async def set_connected(data): + nonlocal connected + connected = True + + gw.emitter.add_listener("READY", set_connected) + async def stop_gatway(after): await anyio.sleep(after) await gw.close() @@ -25,4 +32,12 @@ async def stop_gatway(after): await tg.spawn(gw.start, os.environ['TEST_BOT_TOKEN']) await tg.spawn(stop_gatway, 10) + self.assertTrue(connected) + anyio.run(main) + + def test_normal_gateway_connect(self): + self.gateway_connect(False) + + def test_compressed_gateway_connect(self): + self.gateway_connect(True) diff --git a/tests/test_gateway_resume.py b/tests/test_gateway_resume.py index 39af357..d5c3c65 100644 --- a/tests/test_gateway_resume.py +++ b/tests/test_gateway_resume.py @@ -16,18 +16,27 @@ async def main(): gw = gateway.DiscordWebsocketClient(url['url']) self.assertIsInstance(gw, gateway.DiscordWebsocketClient) + reconnected = False - async def trigger_resume(after): - await anyio.sleep(after) + async def got_reconnect(data): + nonlocal reconnected + reconnected = True + + async def trigger_resume(data): + await anyio.sleep(5) await gw.resume() async def stop_gatway(after): await anyio.sleep(after) await gw.close() + gw.emitter.add_listener("RESUMED", got_reconnect) + gw.emitter.add_listener("READY", trigger_resume) + async with anyio.create_task_group() as tg: await tg.spawn(gw.start, os.environ['TEST_BOT_TOKEN']) - await tg.spawn(trigger_resume, 5) - await tg.spawn(stop_gatway, 10) + await tg.spawn(stop_gatway, 90) + + self.assertTrue(reconnected) anyio.run(main)