diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e1de5ec..0782f4a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -25,4 +25,4 @@ jobs: python -m pip install build python -m build - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@v1.11.0 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@v1.12.4 \ No newline at end of file diff --git a/README.md b/README.md index a1b3de8..a3d3b8e 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,69 @@ And a command like the following is called by the Stream Deck software: streamdeck -port 28196 -pluginUUID 63831042F4048F072B096732E0385245 -registerEvent registerPlugin -info '{"application": {...}, "plugin": {"uuid": "my-plugin-name", "version": "1.1.3"}, ...}' ``` +## Custom Event Listeners + +The SDK allows you to create custom event listeners and events by extending the `EventListener` and `EventBase` classes. This is useful when you need to monitor data from external applications and perform specific actions in response to changes or alerts. + +### Creating a Custom Event Listener + +To create a custom event listener: + +1. Create new event model that inherits from `EventBase`. +2. Create a new class that inherits from `EventListener`. + a. Implement the required `listen` and `stop` methods. The `listen` method should yield results as a json string that matches the new event model. + b. List the new event classes in the `event_models` class variable of the new `EventListener` class. +3. Configure your plugin in its `pyproject.toml` file to use your custom listener. + +```python +# custom_listener.py +from collections.abc import Generator +from typing import ClassVar, Literal + +from streamdeck.event_listener import EventListener +from streamdeck.models.events import EventBase + + +class MyCustomEvent(EventBase): + event: Literal["somethingHappened"] + ... # Define additional data attributes here + +class MyCustomEventListener(EventListener): + def listen(self) -> Generator[str | bytes, None, None]: + ... + # Listen/poll for something here in a loop, and yield the result. + # This will be ran in a background thread. + # Ex: + # while self._running is True: + # result = module.check_status() + # if result is not None: + # yield json.dumps({"event": "somethingHappend", "result": result}) + # time.sleep(1) + + def stop(self) -> None: + ... + # Stop the loop or blocking call in the listen method. + # Ex: + # self._running = False +``` + +### Configuring Your Custom Listener + +To use your custom event listener, add it to your `pyproject.toml` file: + +```toml +[tools.streamdeck] + action_scripts = [ + "main.py", + ] + event_listener_modules = [ + "myplugin.custom_listener", + ] +``` + +The `event_listeners` list should contain strings in module format for each module you want to use. + + ## Creating and Packaging Plugins To create a new plugin with all of the necessary files to start from and package it for use on your Stream Deck, use [the Python SDK CLI tool](https://github.com/strohganoff/python-streamdeck-plugin-sdk-cli). diff --git a/streamdeck/__main__.py b/streamdeck/__main__.py index c9c5a9d..67f7a55 100644 --- a/streamdeck/__main__.py +++ b/streamdeck/__main__.py @@ -72,10 +72,18 @@ def main( info=info_data, ) + # Event listeners and their Event models are registered before actions in order to validate the actions' registered events' names. + for event_listener in pyproject.event_listeners: + manager.register_event_listener(event_listener()) + for action in actions: manager.register_action(action) - manager.run() + try: + manager.run() + except Exception as e: + logger.exception("Error in plugin manager") + raise # Also run the plugin if this script is ran as a console script. diff --git a/streamdeck/actions.py b/streamdeck/actions.py index 25d222c..6161848 100644 --- a/streamdeck/actions.py +++ b/streamdeck/actions.py @@ -6,7 +6,7 @@ from logging import getLogger from typing import TYPE_CHECKING, cast -from streamdeck.types import BaseEventHandlerFunc, available_event_names +from streamdeck.types import BaseEventHandlerFunc if TYPE_CHECKING: @@ -22,7 +22,7 @@ class ActionBase(ABC): """Base class for all actions.""" - def __init__(self): + def __init__(self) -> None: """Initialize an Action instance. Args: @@ -42,13 +42,13 @@ def on(self, event_name: EventNameStr, /) -> Callable[[EventHandlerFunc[TEvent_c Raises: KeyError: If the provided event name is not available. """ - if event_name not in available_event_names: - msg = f"Provided event name for action handler does not exist: {event_name}" - raise KeyError(msg) + # if event_name not in DEFAULT_EVENT_NAMES: + # msg = f"Provided event name for action handler does not exist: {event_name}" + # raise KeyError(msg) def _wrapper(func: EventHandlerFunc[TEvent_contra]) -> EventHandlerFunc[TEvent_contra]: # Cast to BaseEventHandlerFunc so that the storage type is consistent. - self._events[event_name].add(cast(BaseEventHandlerFunc, func)) + self._events[event_name].add(cast("BaseEventHandlerFunc", func)) return func @@ -66,15 +66,22 @@ def get_event_handlers(self, event_name: EventNameStr, /) -> Generator[EventHand Raises: KeyError: If the provided event name is not available. """ - if event_name not in available_event_names: - msg = f"Provided event name for pulling handlers from action does not exist: {event_name}" - raise KeyError(msg) + # if event_name not in DEFAULT_EVENT_NAMES: + # msg = f"Provided event name for pulling handlers from action does not exist: {event_name}" + # raise KeyError(msg) if event_name not in self._events: return yield from self._events[event_name] + def get_registered_event_names(self) -> list[str]: + """Get all event names for which event handlers are registered. + + Returns: + list[str]: The list of event names for which event handlers are registered. + """ + return list(self._events.keys()) class GlobalAction(ActionBase): """Represents an action that is performed at the plugin level, meaning it isn't associated with a specific device or action.""" @@ -83,7 +90,7 @@ class GlobalAction(ActionBase): class Action(ActionBase): """Represents an action that can be performed for a specific action, with event handlers for specific event types.""" - def __init__(self, uuid: str): + def __init__(self, uuid: str) -> None: """Initialize an Action instance. Args: diff --git a/streamdeck/command_sender.py b/streamdeck/command_sender.py index 7885d69..9afa230 100644 --- a/streamdeck/command_sender.py +++ b/streamdeck/command_sender.py @@ -257,7 +257,6 @@ def send_to_plugin( def send_action_registration( self, register_event: str, - plugin_registration_uuid: str, ) -> None: """Registers a plugin with the Stream Deck software very shortly after the plugin is started. @@ -270,10 +269,8 @@ def send_action_registration( Args: register_event (str): The registration event type, passed in by the Stream Deck software as -registerEvent option. It's value will almost definitely will be "registerPlugin". - plugin_registration_uuid (str): Randomly-generated unique ID passed in by StreamDeck as -pluginUUID option, - used to send back in the registerPlugin event. Note that this is NOT the manifest.json -configured plugin UUID value. """ self._send_event( event=register_event, - uuid=plugin_registration_uuid, + uuid=self._plugin_registration_uuid, ) diff --git a/streamdeck/event_listener.py b/streamdeck/event_listener.py new file mode 100644 index 0000000..2e1da2e --- /dev/null +++ b/streamdeck/event_listener.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from logging import getLogger +from queue import Queue +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, ClassVar + + from typing_extensions import TypeIs + + from streamdeck.models.events import EventBase + + + +logger = getLogger("streamdeck.event_listener") + + +class _SENTINAL: + """A sentinel object used to signal the end of the event stream. + + Not meant to be instantiated, but rather used as a singleton (e.g. `_SENTINAL`). + """ + @classmethod + def is_sentinal(cls, event: str | bytes | type[_SENTINAL]) -> TypeIs[type[_SENTINAL]]: + """Check if an event is the sentinal object. Provided to enable better type-checking.""" + return event is cls + + +class StopStreaming(Exception): # noqa: N818 + """Raised by an EventListener implementation to signal that the entire EventManagerListener should stop streaming events.""" + + +class EventListenerManager: + """Manages event listeners and provides a shared event queue for them to push events into. + + With this class, a single event stream can be created from multiple listeners. + This allows for us to listen for not only Stream Deck events, but also other events plugin-developer -defined events. + """ + def __init__(self) -> None: + self.event_queue: Queue[str | bytes | type[_SENTINAL]] = Queue() + self.listeners_lookup_by_thread: dict[threading.Thread, EventListener] = {} + self._running = False + + def add_listener(self, listener: EventListener) -> None: + """Registers a listener function that yields events. + + Args: + listener: A function that yields events. + """ + # Create a thread for the listener + thread = threading.Thread( + target=self._listener_wrapper, + args=(listener,), + daemon=True, + ) + self.listeners_lookup_by_thread[thread] = listener + + def _listener_wrapper(self, listener: EventListener) -> None: + """Wraps the listener function: for each event yielded, push it into the shared queue.""" + try: + for event in listener.listen(): + self.event_queue.put(event) + + if not self.running: + break + + except StopStreaming: + logger.debug("Event listener requested to stop streaming.") + self.event_queue.put(_SENTINAL) + + except Exception: + logger.exception("Unexpected error in wrapped listener %s. Stopping just this listener.", listener) + + def stop(self) -> None: + """Stops the event generation loop and waits for all threads to finish. + + Listeners will check the running flag if implemented to stop listening. + """ + # Set the running flag to False to stop the listeners running in separate threads. + self.running = False + # Push the sentinel to immediately unblock the queue.get() in event_stream. + self.event_queue.put(_SENTINAL) + + for thread in self.listeners_lookup_by_thread: + logger.debug("Stopping listener %s.") + self.listeners_lookup_by_thread[thread].stop() + if thread.is_alive(): + thread.join() + + logger.info("All listeners have been stopped.") + + def event_stream(self) -> Generator[str | bytes, None, None]: + """Starts all registered listeners, sets the running flag to True, and yields events from the shared queue.""" + logger.info("Starting event stream.") + # Set the running flag to True and start the listeners in their separate threads. + self.running = True + for thread in self.listeners_lookup_by_thread: + thread.start() + + try: + while True: + event = self.event_queue.get() + if _SENTINAL.is_sentinal(event): + logger.debug("Sentinal received, stopping event stream.") + break # Exit loop immediately if the sentinal is received + yield event + finally: + self.stop() + + +class EventListener(ABC): + """Base class for event listeners. + + Event listeners are classes that listen for events and simply yield them as they come. + The EventListenerManager will handle the threading and pushing the events yielded into a shared queue. + """ + event_models: ClassVar[list[type[EventBase]]] + """A list of event models that the listener can yield. Read in by the PluginManager to model the incoming event data off of. + + The plugin-developer must define this list in their subclass. + """ + + @abstractmethod + def listen(self) -> Generator[str | bytes, Any, None]: + """Start listening for events and yield them as they come. + + This is the method that run in a separate thread. + """ + + @abstractmethod + def stop(self) -> None: + """Stop the listener. This could set an internal flag, close a connection, etc.""" diff --git a/streamdeck/manager.py b/streamdeck/manager.py index 1cd76d6..e74ff63 100644 --- a/streamdeck/manager.py +++ b/streamdeck/manager.py @@ -4,21 +4,24 @@ from logging import getLogger from typing import TYPE_CHECKING +from pydantic import ValidationError + from streamdeck.actions import Action, ActionBase, ActionRegistry from streamdeck.command_sender import StreamDeckCommandSender -from streamdeck.models.events import ContextualEventMixin, event_adapter +from streamdeck.event_listener import EventListener, EventListenerManager +from streamdeck.models.events import ContextualEventMixin, EventAdapter from streamdeck.types import ( EventHandlerBasicFunc, EventHandlerFunc, TEvent_contra, is_bindable_handler, - is_valid_event_name, ) from streamdeck.utils.logging import configure_streamdeck_logger from streamdeck.websocket import WebSocketClient if TYPE_CHECKING: + from collections.abc import Generator from typing import Any, Literal from streamdeck.models.events import EventBase @@ -59,7 +62,21 @@ def __init__( self._register_event = register_event self._info = info - self._registry = ActionRegistry() + self._action_registry = ActionRegistry() + self._event_listener_manager = EventListenerManager() + self._event_adapter = EventAdapter() + + def _ensure_action_has_valid_events(self, action: ActionBase) -> None: + """Ensure that the action's registered events are valid. + + Args: + action (Action): The action to validate. + """ + for event_name in action.get_registered_event_names(): + if not self._event_adapter.event_name_exists(event_name): + msg = f"Invalid event received: {event_name}" + logger.error(msg) + raise KeyError(msg) def register_action(self, action: ActionBase) -> None: """Register an action with the PluginManager, and configure its logger. @@ -67,11 +84,25 @@ def register_action(self, action: ActionBase) -> None: Args: action (Action): The action to register. """ - # First, configure a logger for the action, giving it the last part of its uuid as name (if it has one). + # First, validate that the action's registered events are valid. + self._ensure_action_has_valid_events(action) + + # Next, configure a logger for the action, giving it the last part of its uuid as name (if it has one). action_component_name = action.uuid.split(".")[-1] if isinstance(action, Action) else "global" configure_streamdeck_logger(name=action_component_name, plugin_uuid=self.uuid) - self._registry.register(action) + self._action_registry.register(action) + + def register_event_listener(self, listener: EventListener) -> None: + """Register an event listener with the PluginManager, and add its event models to the event adapter. + + Args: + listener (EventListener): The event listener to register. + """ + self._event_listener_manager.add_listener(listener) + + for event_model in listener.event_models: + self._event_adapter.add_model(event_model) def _inject_command_sender(self, handler: EventHandlerFunc[TEvent_contra], command_sender: StreamDeckCommandSender) -> EventHandlerBasicFunc[TEvent_contra]: """Inject command_sender into handler if it accepts it as a parameter. @@ -88,6 +119,30 @@ def _inject_command_sender(self, handler: EventHandlerFunc[TEvent_contra], comma return handler + def _stream_event_data(self) -> Generator[EventBase, None, None]: + """Stream event data from the event listeners. + + Validate and model the incoming event data before yielding it. + + Yields: + EventBase: The event data received from the event listeners. + """ + for message in self._event_listener_manager.event_stream(): + try: + data: EventBase = self._event_adapter.validate_json(message) + except ValidationError: + logger.exception("Error modeling event data.") + continue + + logger.debug("Event received: %s", data.event) + + # TODO: is this necessary? Or would this be covered by the event adapter validation? + if not self._event_adapter.event_name_exists(data.event): + logger.error("Invalid event received: %s", data.event) + continue + + yield data + def run(self) -> None: """Run the PluginManager by connecting to the WebSocket server and processing incoming events. @@ -95,26 +150,20 @@ def run(self) -> None: and triggers the appropriate action handlers based on the received events. """ with WebSocketClient(port=self._port) as client: - command_sender = StreamDeckCommandSender(client, plugin_registration_uuid=self._registration_uuid) - - command_sender.send_action_registration(register_event=self._register_event, plugin_registration_uuid=self._registration_uuid) - - for message in client.listen(): - data: EventBase = event_adapter.validate_json(message) + # Register the WebSocketClient as an event listener, so we can receive Stream Deck events. + self.register_event_listener(client) - if not is_valid_event_name(data.event): - logger.error("Received event name is not valid: %s", data.event) - continue - - logger.debug("Event received: %s", data.event) + command_sender = StreamDeckCommandSender(client, plugin_registration_uuid=self._registration_uuid) + command_sender.send_action_registration(register_event=self._register_event) + for data in self._stream_event_data(): # If the event is action-specific, we'll pass the action's uuid to the handler to ensure only the correct action is triggered. event_action_uuid = data.action if isinstance(data, ContextualEventMixin) else None - for event_handler in self._registry.get_action_handlers(event_name=data.event, event_action_uuid=event_action_uuid): + for event_handler in self._action_registry.get_action_handlers(event_name=data.event, event_action_uuid=event_action_uuid): processed_handler = self._inject_command_sender(event_handler, command_sender) # TODO: from contextual event occurences, save metadata to the action's properties. processed_handler(data) - + logger.info("PluginManager has stopped processing events.") diff --git a/streamdeck/models/configs.py b/streamdeck/models/configs.py index 94d576d..f097379 100644 --- a/streamdeck/models/configs.py +++ b/streamdeck/models/configs.py @@ -1,7 +1,7 @@ from __future__ import annotations -from types import ModuleType -from typing import TYPE_CHECKING, Annotated +from types import ModuleType # noqa: TC003 +from typing import TYPE_CHECKING, Annotated, ClassVar import tomli as toml from pydantic import ( @@ -14,19 +14,36 @@ ) from streamdeck.actions import ActionBase +from streamdeck.event_listener import EventListener +from streamdeck.models.events import EventBase if TYPE_CHECKING: from collections.abc import Generator from pathlib import Path - from typing import Any +def parse_objects_from_modules(value: list[ModuleType]) -> Generator[object, None, None]: + """Loop through objects in each provided module to be yielded. + + Methods and attributes that are magic, special, or built-in are ignored. + """ + for module in value: + for object_name in dir(module): + obj = getattr(module, object_name) + + # Ignore magic/special/built-in methods and attributes. + if object_name.startswith("__"): + continue + + yield obj + class PyProjectConfigs(BaseModel): """A Pydantic model for the PyProject.toml configuration file to load a Stream Deck plugin's actions.""" - tool: ToolSection = Field(alias="tool") + tool: ToolSection + """The "tool" section of a pyproject.toml file, which contains configs for project tools, including for this Stream Deck plugin.""" @classmethod def validate_from_toml_file(cls, filepath: Path, action_scripts: list[str] | None = None) -> PyProjectConfigs: @@ -59,17 +76,24 @@ def overwrite_action_scripts(cls, data: object, info: ValidationInfo) -> object: return data @property - def streamdeck_plugin_actions(self) -> Generator[ActionBase, Any, None]: + def streamdeck(self) -> StreamDeckToolConfig: + """Reach into the [tool.streamdeck] section of the PyProject.toml file and return the plugin's configuration.""" + return self.tool.streamdeck + + @property + def streamdeck_plugin_actions(self) -> Generator[ActionBase, None, None]: """Reach into the [tool.streamdeck] section of the PyProject.toml file and yield the plugin's actions configured by the developer.""" - for loaded_action_script in self.tool.streamdeck.action_script_modules: - for object_name in dir(loaded_action_script): - obj = getattr(loaded_action_script, object_name) + yield from self.streamdeck.actions - # Ensure the object isn't a magic method or attribute of the loaded module. - if object_name.startswith("__"): - continue + @property + def event_listeners(self) -> Generator[type[EventListener], None, None]: + """Reach into the [tool.streamdeck] section of the PyProject.toml file and yield the plugin's event listeners configured by the developer.""" + yield from self.streamdeck.event_listeners - yield obj + @property + def event_models(self) -> Generator[type[EventBase], None, None]: + """Reach into the [tool.streamdeck] section of the PyProject.toml file and yield the plugin's event models configured by the developer.""" + yield from self.streamdeck.event_models class ToolSection(BaseModel): @@ -78,6 +102,7 @@ class ToolSection(BaseModel): Nothing much to see here, just a wrapper around the model representing the "streamdeck" subsection. """ streamdeck: StreamDeckToolConfig + """The "streamdeck" subsection in the "tool" section of the pyproject.toml file, which contains the developer's configuration for their Stream Deck plugin.""" class StreamDeckToolConfig(BaseModel, arbitrary_types_allowed=True): @@ -91,25 +116,54 @@ class StreamDeckToolConfig(BaseModel, arbitrary_types_allowed=True): This field is filtered to only include objects that are subclasses of ActionBase (as well as the built-in magic methods and attributes typically found in a module). """ + event_listener_modules: list[ImportString[ModuleType]] = [] + """A list of loaded event listener modules with all of their objects. + + This field is filtered to only include objects that are subclasses of EventListener & EventBase (as well as the built-in magic methods and attributes typically found in a module). + """ + + # The following fields are populated by the field validators below, and are not during Pydantic's validation process. + actions: ClassVar[list[ActionBase]] = [] + """The collected ActionBase subclass instances from the loaded modules listed in the `action_script_modules` field.""" + event_listeners: ClassVar[list[type[EventListener]]] = [] + """The collected EventListener subclasses from the loaded modules listed in the `event_listener_modules` field.""" + event_models: ClassVar[list[type[EventBase]]] = [] + """The collected EventBase subclasses from the loaded modules listed in the `event_listener_modules` field.""" + @field_validator("action_script_modules", mode="after") @classmethod - def filter_module_objects(cls, value: list[ModuleType]) -> list[ModuleType]: - """Filter out non- ActionBase subclasses from the list of objects loaded from each action script module.""" - loaded_modules: list[ModuleType] = [] + def filter_action_module_objects(cls, value: list[ModuleType]) -> list[ModuleType]: + """Loop through objects in each configured action script module, and collect ActionBase subclasses. - for module in value: - new_module = ModuleType(module.__name__) + The value arg isn't modified here, it is simply returned as-is at the end of the method. + """ + for obj in parse_objects_from_modules(value): + # Ignore obj if it's not an instance of an ActionBase subclass. + if not isinstance(obj, ActionBase): + continue - for object_name in dir(module): - obj = getattr(module, object_name) + cls.actions.append(obj) - if not isinstance(obj, ActionBase): - continue + return value - setattr(new_module, object_name, obj) + @field_validator("event_listener_modules", mode="after") + @classmethod + def filter_event_listener_module_objects(cls, value: list[ModuleType]) -> list[ModuleType]: + """Loop through objects in each configured event listener module, and collect EventListener and EventBase subclasses. - loaded_modules.append(new_module) + The value arg isn't modified here, it is simply returned as-is at the end of the method. + """ + for obj in parse_objects_from_modules(value): + # Ensure obj is a type (class definition), but not the base classes themselves. + if not isinstance(obj, type) or obj in (EventListener, EventBase): + continue - return loaded_modules + # Collect obj if it is an EventListener subclass + if issubclass(obj, EventListener): + cls.event_listeners.append(obj) + # Collect obj if it is an EventBase subclass + if issubclass(obj, EventBase): + cls.event_models.append(obj) + return value diff --git a/streamdeck/models/events.py b/streamdeck/models/events.py index bad3cd2..b13789d 100644 --- a/streamdeck/models/events.py +++ b/streamdeck/models/events.py @@ -1,33 +1,79 @@ from __future__ import annotations from abc import ABC -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Any, Final, Literal, Union, get_args, get_type_hints from pydantic import BaseModel, ConfigDict, Field, TypeAdapter -from typing_extensions import TypedDict +from typing_extensions import LiteralString, TypedDict, TypeIs # noqa: UP035 # TODO: Create more explicitly-defined payload objects. class EventBase(BaseModel, ABC): + """Base class for event models that represent Stream Deck Plugin SDK events.""" + # Configure to use the docstrings of the fields as the field descriptions. model_config = ConfigDict(use_attribute_docstrings=True) event: str - """Name of the event used to identify what occurred.""" + """Name of the event used to identify what occurred. + Subclass models must define this field as a Literal type with the event name string that the model represents. + """ + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Validate that the event field is a Literal[str] type.""" + super().__init_subclass__(**kwargs) + + model_event_type = get_type_hints(cls)["event"] + + if not is_literal_str_type(model_event_type): + msg = f"The event field annotation must be a Literal[str] type. Given type: {model_event_type}" + raise TypeError(msg) + + @classmethod + def get_model_event_name(cls) -> tuple[str, ...]: + """Get the value of the subclass model's event field Literal annotation.""" + model_event_type = get_type_hints(cls)["event"] + + # Ensure that the event field annotation is a Literal type. + if not is_literal_str_type(model_event_type): + msg = "The `event` field annotation of an Event model must be a Literal[str] type." + raise TypeError(msg) + + return get_args(model_event_type) + + +def is_literal_str_type(value: object | None) -> TypeIs[LiteralString]: + """Check if a type is a Literal type.""" + if value is None: + return False + + event_field_base_type = getattr(value, "__origin__", None) + + if event_field_base_type is not Literal: + return False + + return all(isinstance(literal_value, str) for literal_value in get_args(value)) + + +## Mixin classes for common event model fields. class ContextualEventMixin: + """Mixin class for event models that have action and context fields.""" action: str """Unique identifier of the action""" context: str """Identifies the instance of an action that caused the event, i.e. the specific key or dial.""" class DeviceSpecificEventMixin: + """Mixin class for event models that have a device field.""" device: str """Unique identifier of the Stream Deck device that this event is associated with.""" +## EventBase implementation models of the Stream Deck Plugin SDK events. + class ApplicationDidLaunch(EventBase): event: Literal["applicationDidLaunch"] # type: ignore[override] payload: dict[Literal["application"], str] @@ -148,6 +194,7 @@ class WillDisappear(EventBase, ContextualEventMixin, DeviceSpecificEventMixin): payload: dict[str, Any] +## Default event models and names. event_adapter: TypeAdapter[EventBase] = TypeAdapter( @@ -177,3 +224,83 @@ class WillDisappear(EventBase, ContextualEventMixin, DeviceSpecificEventMixin): Field(discriminator="event") ] ) + + +DEFAULT_EVENT_MODELS: Final[list[type[EventBase]]] = [ + ApplicationDidLaunch, + ApplicationDidTerminate, + DeviceDidConnect, + DeviceDidDisconnect, + DialDown, + DialRotate, + DialUp, + DidReceiveDeepLink, + KeyUp, + KeyDown, + DidReceivePropertyInspectorMessage, + PropertyInspectorDidAppear, + PropertyInspectorDidDisappear, + DidReceiveGlobalSettings, + DidReceiveSettings, + SystemDidWakeUp, + TitleParametersDidChange, + TouchTap, + WillAppear, + WillDisappear, +] + + +def _get_default_event_names() -> set[str]: + default_event_names: set[str] = set() + + for event_model in DEFAULT_EVENT_MODELS: + default_event_names.update(event_model.get_model_event_name()) + + return default_event_names + + +DEFAULT_EVENT_NAMES: Final[set[str]] = _get_default_event_names() + + +## EventAdapter class for handling and extending available event models. + +class EventAdapter: + """TypeAdapter-encompassing class for handling and extending available event models.""" + def __init__(self) -> None: + self._models: list[type[EventBase]] = [] + self._type_adapter: TypeAdapter[EventBase] | None = None + + self._event_names: set[str] = set() + """A set of all event names that have been registered with the adapter. + This set starts out containing the default event models defined by the library. + """ + + for model in DEFAULT_EVENT_MODELS: + self.add_model(model) + + def add_model(self, model: type[EventBase]) -> None: + """Add a model to the adapter, and add the event name of the model to the set of registered event names.""" + self._models.append(model) + self._event_names.update(model.get_model_event_name()) + + def event_name_exists(self, event_name: str) -> bool: + """Check if an event name has been registered with the adapter.""" + return event_name in self._event_names + + @property + def type_adapter(self) -> TypeAdapter[EventBase]: + """Get the TypeAdapter instance for the event models.""" + if self._type_adapter is None: + self._type_adapter = TypeAdapter( + Annotated[ + Union[tuple(self._models)], # noqa: UP007 + Field(discriminator="event") + ] + ) + + return self._type_adapter + + def validate_json(self, data: str | bytes) -> EventBase: + """Validate a JSON string or bytes object as an event model.""" + return self.type_adapter.validate_json(data) + diff --git a/streamdeck/types.py b/streamdeck/types.py index 52a6f71..a2d3994 100644 --- a/streamdeck/types.py +++ b/streamdeck/types.py @@ -1,77 +1,36 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Literal, Protocol, TypeVar, Union +from typing import TYPE_CHECKING, Protocol, TypeVar, Union -from typing_extensions import TypeAlias, TypeIs # noqa: UP035 - -from streamdeck.models.events import EventBase +from streamdeck.models.events import DEFAULT_EVENT_NAMES, EventBase if TYPE_CHECKING: + from typing_extensions import TypeAlias, TypeIs # noqa: UP035 + from streamdeck.command_sender import StreamDeckCommandSender -available_event_names: set[EventNameStr] = { - "applicationDidLaunch", - "applicationDidTerminate", - "deviceDidConnect", - "deviceDidDisconnect", - "dialDown", - "dialRotate", - "dialUp", - "didReceiveGlobalSettings", - "didReceiveDeepLink", - "didReceiveSettings", - "sendToPlugin", # DidReceivePropertyInspectorMessage event - "keyDown", - "keyUp", - "propertyInspectorDidAppear", - "propertyInspectorDidDisappear", - "systemDidWakeUp", - "titleParametersDidChange", - "touchTap", - "willAppear", - "willDisappear", -} - - -# For backwards compatibility with older versions of Python, we can't just use available_event_names as the values of the Literal EventNameStr. -EventNameStr: TypeAlias = Literal[ # noqa: UP040 - "applicationDidLaunch", - "applicationDidTerminate", - "deviceDidConnect", - "deviceDidDisconnect", - "dialDown", - "dialRotate", - "dialUp", - "didReceiveGlobalSettings", - "didReceiveDeepLink", - "didReceiveSettings", - "sendToPlugin", # DidReceivePropertyInspectorMessage event - "keyDown", - "keyUp", - "propertyInspectorDidAppear", - "propertyInspectorDidDisappear", - "systemDidWakeUp", - "titleParametersDidChange", - "touchTap", - "willAppear", - "willDisappear" -] + +EventNameStr: TypeAlias = str # noqa: UP040 +"""Type alias for the event name string. + +We don't define literal string values here, as the list of available event names can be added to dynamically. +""" def is_valid_event_name(event_name: str) -> TypeIs[EventNameStr]: """Check if the event name is one of the available event names.""" - return event_name in available_event_names + return event_name in DEFAULT_EVENT_NAMES ### Event Handler Type Definitions ### ## Protocols for event handler functions that act on subtypes of EventBase instances in a Generic way. -# A type variable for a subtype of EventBase TEvent_contra = TypeVar("TEvent_contra", bound=EventBase, contravariant=True) +"""Type variable for a subtype of EventBase.""" class EventHandlerBasicFunc(Protocol[TEvent_contra]): @@ -84,8 +43,8 @@ class EventHandlerBindableFunc(Protocol[TEvent_contra]): def __call__(self, event_data: TEvent_contra, command_sender: StreamDeckCommandSender) -> None: ... -# Type alias for an event handler function that takes an event (of subtype of EventBase), and optionally a command sender. EventHandlerFunc = Union[EventHandlerBasicFunc[TEvent_contra], EventHandlerBindableFunc[TEvent_contra]] # noqa: UP007 +"""Type alias for an event handler function that takes an event (of subtype of EventBase), and optionally a command sender.""" ## Protocols for event handler functions that act on EventBase instances. @@ -97,8 +56,11 @@ class BaseEventHandlerBindableFunc(EventHandlerBindableFunc[EventBase]): """Protocol for an event handler function that takes an event (of subtype of EventBase) and a command sender.""" -# Type alias for a base event handler function that expects an actual EventBase instance (and optionally a command sender) — used for type hinting internal storage of event handlers. BaseEventHandlerFunc = Union[BaseEventHandlerBasicFunc, BaseEventHandlerBindableFunc] # noqa: UP007 +"""Type alias for a base event handler function that takes an actual EventBase instance argument, and optionally a command sender. + +This is used for type hinting internal storage of event handlers. +""" diff --git a/streamdeck/utils/helper_actions.py b/streamdeck/utils/helper_actions.py index 9255bd2..6589e55 100644 --- a/streamdeck/utils/helper_actions.py +++ b/streamdeck/utils/helper_actions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from streamdeck.actions import Action -from streamdeck.types import available_event_names +from streamdeck.models.events import DEFAULT_EVENT_NAMES if TYPE_CHECKING: @@ -24,7 +24,7 @@ def log_event(event_data: EventBase) -> None: logger.info("Action %s — event %s", logging_action.__class__, event_data.event) # Register the above function for every event - for event_name in available_event_names: + for event_name in DEFAULT_EVENT_NAMES: logging_action.on(event_name)(log_event) return logging_action @@ -45,7 +45,7 @@ def write_event(event_data: EventBase) -> None: file.flush() # Register the above function for every event - for event_name in available_event_names: + for event_name in DEFAULT_EVENT_NAMES: file_writing_action.on(event_name)(write_event) return file_writing_action diff --git a/streamdeck/websocket.py b/streamdeck/websocket.py index 8651837..8fa0504 100644 --- a/streamdeck/websocket.py +++ b/streamdeck/websocket.py @@ -1,16 +1,25 @@ +"""A client for connecting to the Stream Deck device's WebSocket server and sending/receiving events. + +Inherits from the EventListener class to work with the EventListenerManager for processing events. +""" from __future__ import annotations import json from logging import getLogger from typing import TYPE_CHECKING +from websockets import ConnectionClosedError, WebSocketException from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.sync.client import ClientConnection, connect +from streamdeck.event_listener import EventListener, StopStreaming +from streamdeck.models import events + if TYPE_CHECKING: from collections.abc import Generator - from typing import Any + from types import TracebackType + from typing import Any, ClassVar from typing_extensions import Self # noqa: UP035 @@ -19,10 +28,12 @@ -class WebSocketClient: +class WebSocketClient(EventListener): """A client for connecting to the Stream Deck device's WebSocket server and sending/receiving events.""" _client: ClientConnection | None + event_models: ClassVar[list[type[events.EventBase]]] = events.DEFAULT_EVENT_MODELS + def __init__(self, port: int): """Initialize a WebSocketClient instance. @@ -53,21 +64,30 @@ def listen(self) -> Generator[str | bytes, Any, None]: Yields: Union[str, bytes]: The received message from the WebSocket server. """ - # TODO: Check that self._client is a connected thing. + if self._client is None: + msg = "WebSocket connection not established yet." + raise ValueError(msg) + try: while True: message: str | bytes = self._client.recv() yield message - except ConnectionClosedOK as exc: - logger.debug("Connection was closed normally, stopping the client.") - logger.exception(dir(exc)) + except WebSocketException as exc: + if isinstance(exc, ConnectionClosedOK): + logger.debug("Connection was closed normally, stopping the client.") + elif isinstance(exc, ConnectionClosedError): + logger.exception("Connection was terminated with an error.") + elif isinstance(exc, ConnectionClosed): + logger.exception("Connection was already closed.") + else: + logger.exception("Connection is closed due to an unexpected WebSocket error.") - except ConnectionClosed: - logger.exception("Connection was closed with an error.") + raise StopStreaming from None except Exception: logger.exception("Failed to receive messages from websocket server due to unexpected error.") + raise def start(self) -> None: """Start the connection to the websocket server.""" @@ -91,7 +111,7 @@ def __enter__(self) -> Self: self.start() return self - def __exit__(self, *args, **kwargs) -> None: + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: """Close the WebSocket connection, if open.""" self.stop() diff --git a/tests/actions/test_action.py b/tests/actions/test_action.py deleted file mode 100644 index c5751b8..0000000 --- a/tests/actions/test_action.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest -from streamdeck.actions import Action -from streamdeck.types import available_event_names - - -if TYPE_CHECKING: - from streamdeck.models.events import EventBase - from streamdeck.types import EventNameStr - - -@pytest.mark.parametrize("event_name", list(available_event_names)) -def test_action_register_event_handler(event_name: EventNameStr): - """Test that an event handler can be registered for each valid event name.""" - action = Action("test.uuid.for.action") - - @action.on(event_name) - def handler(event: EventBase) -> None: - pass - - # Ensure the handler is registered for the correct event name - assert len(action._events[event_name]) == 1 - assert handler in action._events[event_name] - - -def test_action_register_invalid_event_handler(): - """Test that attempting to register an invalid event handler raises an exception.""" - action = Action("test.uuid.for.action") - - with pytest.raises(Exception): - @action.on("InvalidEvent") - def handler(event: EventBase): - pass - - -@pytest.mark.parametrize("event_name", list(available_event_names)) -def test_action_get_event_handlers(event_name: EventNameStr): - """Test that the correct event handlers are retrieved for each event name.""" - action = Action("test.uuid.for.action") - - # Register a handler for the event name - @action.on(event_name) - def handler(event: EventBase): - pass - - # Retrieve the handlers using the generator - handlers = list(action.get_event_handlers(event_name)) - - # Ensure that the correct handler is retrieved - assert len(handlers) == 1 - assert handlers[0] == handler - - -def test_action_get_event_handlers_no_event_registered(): - """Test that attempting to get handlers for an event with no registered handlers raises an exception.""" - action = Action("test.uuid.for.action") - - with pytest.raises(Exception): - list(action.get_event_handlers("InvalidEvent")) # type: ignore - - -def test_action_register_multiple_handlers_for_event(): - """Test that multiple handlers can be registered for the same event on the same action.""" - action = Action("test.uuid.for.action") - - @action.on("keyDown") - def handler_one(event: EventBase): - pass - - @action.on("keyDown") - def handler_two(event: EventBase): - pass - - handlers = list(action.get_event_handlers("keyDown")) - - # Ensure both handlers are retrieved - assert len(handlers) == 2 - assert handler_one in handlers - assert handler_two in handlers - -def test_action_get_event_handlers_invalid_event(): - """Test that getting handlers for an invalid event raises a KeyError.""" - action = Action("test.uuid.for.action") - - with pytest.raises(KeyError): - list(action.get_event_handlers("invalidEvent")) diff --git a/tests/actions/test_event_handler_filtering.py b/tests/actions/test_action_event_handler_filtering.py similarity index 66% rename from tests/actions/test_event_handler_filtering.py rename to tests/actions/test_action_event_handler_filtering.py index 0b075aa..4c5e315 100644 --- a/tests/actions/test_event_handler_filtering.py +++ b/tests/actions/test_action_event_handler_filtering.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from unittest.mock import create_autospec import pytest +from polyfactory.factories.pydantic_factory import ModelFactory from streamdeck.actions import Action, ActionRegistry, GlobalAction from streamdeck.models import events @@ -17,8 +18,6 @@ if TYPE_CHECKING: from unittest.mock import Mock - from polyfactory.factories.pydantic_factory import ModelFactory - @pytest.fixture @@ -29,53 +28,54 @@ def dummy_handler(event: events.EventBase) -> None: return create_autospec(dummy_handler, spec_set=True) -@pytest.mark.parametrize(("event_name","event_factory"), [ - ("keyDown", KeyDownEventFactory), - ("deviceDidConnect", DeviceDidConnectFactory), - ("applicationDidLaunch", ApplicationDidLaunchEventFactory) +@pytest.fixture(params=[ + KeyDownEventFactory, + DeviceDidConnectFactory, + ApplicationDidLaunchEventFactory ]) +def fake_event_data(request: pytest.FixtureRequest) -> events.EventBase: + event_factory = cast(ModelFactory[events.EventBase], request.param) + return event_factory.build() + + def test_global_action_gets_triggered_by_event( mock_event_handler: Mock, - event_name: str, - event_factory: ModelFactory[events.EventBase], -): + fake_event_data: events.EventBase, +) -> None: """Test that a global action's event handlers are triggered by an event. Global actions should be triggered by any event type that is registered with them, regardless of the event's unique identifier properties (or whether they're even present). """ - fake_event_data = event_factory.build() - global_action = GlobalAction() - global_action.on(event_name)(mock_event_handler) + global_action.on(fake_event_data.event)(mock_event_handler) - for handler in global_action.get_event_handlers(event_name): + for handler in global_action.get_event_handlers(fake_event_data.event): handler(fake_event_data) assert mock_event_handler.call_count == 1 assert fake_event_data in mock_event_handler.call_args.args -@pytest.mark.parametrize(("event_name","event_factory"), [ - ("keyDown", KeyDownEventFactory), - ("deviceDidConnect", DeviceDidConnectFactory), - ("applicationDidLaunch", ApplicationDidLaunchEventFactory) -]) -def test_action_gets_triggered_by_event(mock_event_handler: Mock, event_name: str, event_factory: ModelFactory[events.EventBase]): - # Create a fake event model instance - fake_event_data: events.EventBase = event_factory.build() +def test_action_gets_triggered_by_event( + mock_event_handler: Mock, + fake_event_data: events.EventBase, +) -> None: + """Test that an action's event handlers are triggered by an event. + + Actions should only be triggered by events that have the same unique identifier properties as the action. + """ # Extract the action UUID from the fake event data, or use a default value - # action_uuid: str = fake_event_data.action if fake_event_data.is_action_specific() else "my-fake-action-uuid" action_uuid: str = fake_event_data.action if isinstance(fake_event_data, events.ContextualEventMixin) else "my-fake-action-uuid" action = Action(uuid=action_uuid) # Register the mock event handler with the action - action.on(event_name)(mock_event_handler) + action.on(fake_event_data.event)(mock_event_handler) # Get the action's event handlers for the event and call them - for handler in action.get_event_handlers(event_name): + for handler in action.get_event_handlers(fake_event_data.event): handler(fake_event_data) # For some reason, assert_called_once() and assert_called_once_with() are returning False here... @@ -85,23 +85,18 @@ def test_action_gets_triggered_by_event(mock_event_handler: Mock, event_name: st -@pytest.mark.parametrize(("event_name","event_factory"), [ - ("keyDown", KeyDownEventFactory), - ("deviceDidConnect", DeviceDidConnectFactory), - ("applicationDidLaunch", ApplicationDidLaunchEventFactory) -]) -def test_global_action_registry_get_action_handlers_filtering(mock_event_handler: Mock, event_name: str, event_factory: ModelFactory[events.EventBase]): - # Create a fake event model instance - fake_event_data: events.EventBase = event_factory.build() +def test_global_action_registry_get_action_handlers_filtering( + mock_event_handler: Mock, + fake_event_data: events.EventBase, +) -> None: # Extract the action UUID from the fake event data, or use a default value - # action_uuid: str = fake_event_data.action if fake_event_data.is_action_specific() else None action_uuid: str | None = fake_event_data.action if isinstance(fake_event_data, events.ContextualEventMixin) else None registry = ActionRegistry() # Create an Action instance, without an action UUID as global actions aren't associated with a specific action global_action = GlobalAction() - global_action.on(event_name)(mock_event_handler) + global_action.on(fake_event_data.event)(mock_event_handler) # Register the global action with the registry registry.register(global_action) @@ -117,23 +112,18 @@ def test_global_action_registry_get_action_handlers_filtering(mock_event_handler -@pytest.mark.parametrize(("event_name","event_factory"), [ - ("keyDown", KeyDownEventFactory), - ("deviceDidConnect", DeviceDidConnectFactory), - ("applicationDidLaunch", ApplicationDidLaunchEventFactory) -]) -def test_action_registry_get_action_handlers_filtering(mock_event_handler: Mock, event_name: str, event_factory: ModelFactory[events.EventBase]): - # Create a fake event model instance - fake_event_data: events.EventBase = event_factory.build() +def test_action_registry_get_action_handlers_filtering( + mock_event_handler: Mock, + fake_event_data: events.EventBase, +) -> None: # Extract the action UUID from the fake event data, or use a default value - # action_uuid: str = fake_event_data.action if fake_event_data.is_action_specific() else None action_uuid: str | None = fake_event_data.action if isinstance(fake_event_data, events.ContextualEventMixin) else None registry = ActionRegistry() # Create an Action instance, using either the fake event's action UUID or a default value action = Action(uuid=action_uuid or "my-fake-action-uuid") - action.on(event_name)(mock_event_handler) + action.on(fake_event_data.event)(mock_event_handler) # Register the action with the registry registry.register(action) @@ -149,7 +139,7 @@ def test_action_registry_get_action_handlers_filtering(mock_event_handler: Mock, -def test_multiple_actions_filtering(): +def test_multiple_actions_filtering() -> None: registry = ActionRegistry() action = Action("my-fake-action-uuid-1") global_action = GlobalAction() @@ -158,12 +148,12 @@ def test_multiple_actions_filtering(): action_event_handler_called = False @global_action.on("applicationDidLaunch") - def _global_app_did_launch_action_handler(event: events.EventBase): + def _global_app_did_launch_action_handler(event: events.EventBase) -> None: nonlocal global_action_event_handler_called global_action_event_handler_called = True @action.on("keyDown") - def _action_key_down_event_handler(event: events.EventBase): + def _action_key_down_event_handler(event: events.EventBase) -> None: nonlocal action_event_handler_called action_event_handler_called = True diff --git a/tests/actions/test_action_registry.py b/tests/actions/test_action_registry.py index 7efecd8..0909ef2 100644 --- a/tests/actions/test_action_registry.py +++ b/tests/actions/test_action_registry.py @@ -11,7 +11,7 @@ ) -def test_register_action(): +def test_register_action() -> None: """Test that an action can be registered.""" registry = ActionRegistry() action = Action("my-fake-action-uuid") @@ -23,7 +23,7 @@ def test_register_action(): assert registry._plugin_actions[0] == action -def test_get_action_handlers_no_handlers(): +def test_get_action_handlers_no_handlers() -> None: """Test that getting action handlers when there are no handlers yields nothing.""" registry = ActionRegistry() action = Action("my-fake-action-uuid") @@ -35,13 +35,13 @@ def test_get_action_handlers_no_handlers(): assert len(handlers) == 0 -def test_get_action_handlers_with_handlers(): +def test_get_action_handlers_with_handlers() -> None: """Test that registered event handlers can be retrieved correctly.""" registry = ActionRegistry() action = Action("my-fake-action-uuid") @action.on("dialDown") - def dial_down_handler(event: events.EventBase): + def dial_down_handler(event: events.EventBase) -> None: pass registry.register(action) @@ -53,7 +53,7 @@ def dial_down_handler(event: events.EventBase): assert handlers[0] == dial_down_handler -def test_get_action_handlers_multiple_actions(): +def test_get_action_handlers_multiple_actions() -> None: """Test that multiple actions with registered handlers return all handlers.""" registry = ActionRegistry() @@ -61,11 +61,11 @@ def test_get_action_handlers_multiple_actions(): action2 = Action("fake-action-uuid-2") @action1.on("keyUp") - def key_up_handler1(event): + def key_up_handler1(event) -> None: pass @action2.on("keyUp") - def key_up_handler2(event): + def key_up_handler2(event) -> None: pass registry.register(action1) @@ -78,14 +78,3 @@ def key_up_handler2(event): assert len(handlers) == 2 assert key_up_handler1 in handlers assert key_up_handler2 in handlers - - -def test_get_action_handlers_event_not_available(): - """Test that a KeyError is raised if an unavailable event name is provided.""" - registry = ActionRegistry() - action = Action("my-fake-action-uuid") - - registry.register(action) - - with pytest.raises(KeyError): - list(registry.get_action_handlers("nonExistentEvent")) diff --git a/tests/actions/test_actions.py b/tests/actions/test_actions.py new file mode 100644 index 0000000..8decc51 --- /dev/null +++ b/tests/actions/test_actions.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +import pytest +from streamdeck.actions import Action, ActionBase, GlobalAction +from streamdeck.models.events import DEFAULT_EVENT_NAMES + + +if TYPE_CHECKING: + from streamdeck.models.events import EventBase + + +@pytest.fixture(params=[[Action, ("test.uuid.for.action",)], [GlobalAction, []]]) +def action(request: pytest.FixtureRequest) -> ActionBase: + """Fixture for initializing the Action and GlobalAction classes to parameterize the tests. + + We have to initialize the classes here to ensure fresh instances are used to avoid sharing data between tests. + """ + action_class, init_args = cast(tuple[type[ActionBase], tuple[Any]], request.param) + return action_class(*init_args) + + +@pytest.mark.parametrize("event_name", list(DEFAULT_EVENT_NAMES)) +def test_action_register_event_handler(action: ActionBase, event_name: str) -> None: + """Test that an event handler can be registered for each valid event name.""" + @action.on(event_name) + def handler(event: EventBase) -> None: + pass + + # Ensure the handler is registered for the correct event name + assert len(action._events[event_name]) == 1 + assert handler in action._events[event_name] + + +def test_action_get_event_handlers(action: ActionBase) -> None: + """Test that the correct event handlers are retrieved for each event name.""" + # Each iteration will add to the action's event handlers, thus we're checking that + # even with multiple event names, the handlers are correctly retrieved. + for i, event_name in enumerate(DEFAULT_EVENT_NAMES): + # Register a handler for the given event name + @action.on(event_name) + def handler(event: EventBase) -> None: + pass + + # Retrieve the handlers using the generator + handlers = list(action.get_event_handlers(event_name)) + + # Ensure that the correct handler is retrieved + assert len(handlers) == 1 + assert handlers[0] == handler + + +def test_action_get_event_handlers_no_event_registered(action: ActionBase) -> None: + """Test that attempting to get handlers for an event with no registered handlers returns an empty list. + + An implicit assumption is that an exception is not raised. + """ + handlers = list(action.get_event_handlers("dialDown")) + assert len(handlers) == 0 + + +def test_action_get_event_handlers_invalid_event_name(action: ActionBase) -> None: + """Test that attempting to get handlers for an event with an invalid event name returns an empty list. + + An implicit assumption is that an exception is not raised. + """ + events = list(action.get_event_handlers("invalidEvent")) + assert len(events) == 0 + + +@pytest.mark.parametrize("action", [Action("test.uuid.for.action"), GlobalAction()]) +def test_action_register_multiple_handlers_for_event(action: ActionBase) -> None: + """Test that multiple handlers can be registered for the same event on the same action.""" + @action.on("keyDown") + def handler_one(event: EventBase) -> None: + pass + + @action.on("keyDown") + def handler_two(event: EventBase) -> None: + pass + + handlers = list(action.get_event_handlers("keyDown")) + + # Ensure both handlers are retrieved + assert len(handlers) == 2 + assert handler_one in handlers + assert handler_two in handlers diff --git a/tests/actions/test_global_action.py b/tests/actions/test_global_action.py deleted file mode 100644 index 8654a96..0000000 --- a/tests/actions/test_global_action.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest -from streamdeck.actions import GlobalAction -from streamdeck.types import available_event_names - - -if TYPE_CHECKING: - from streamdeck.models.events import EventBase - - -def test_global_action_register_event_handler(): - """Test that an event handler can be registered for each valid event name.""" - global_action = GlobalAction() - - for event_name in available_event_names: - @global_action.on(event_name) - def handler(event: EventBase) -> None: - pass - - # Ensure the handler is registered for the correct event name - assert len(global_action._events[event_name]) == 1 - assert handler in global_action._events[event_name] - - -def test_global_action_register_invalid_event_handler(): - """Test that attempting to register an invalid event handler raises an exception.""" - global_action = GlobalAction() - - with pytest.raises(Exception): - @global_action.on("InvalidEvent") - def handler(event: EventBase): - pass - - -def test_global_action_get_event_handlers(): - """Test that the correct event handlers are retrieved for each event name.""" - global_action = GlobalAction() - - for event_name in available_event_names: - # Register a handler for the event name - @global_action.on(event_name) - def handler(event: EventBase): - pass - - # Retrieve the handlers using the generator - handlers = list(global_action.get_event_handlers(event_name)) - - # Ensure that the correct handler is retrieved - assert len(handlers) == 1 - assert handlers[0] == handler - - -def test_global_action_get_event_handlers_no_event_registered(): - """Test that attempting to get handlers for an event with no registered handlers raises an exception.""" - global_action = GlobalAction() - - with pytest.raises(Exception): - list(global_action.get_event_handlers("InvalidEvent")) - - -def test_global_action_register_multiple_handlers_for_event(): - """Test that multiple handlers can be registered for an event.""" - global_action = GlobalAction() - - @global_action.on("keyDown") - def handler_one(event: EventBase): - pass - - @global_action.on("keyDown") - def handler_two(event: EventBase): - pass - - handlers = list(global_action.get_event_handlers("keyDown")) - - # Ensure both handlers are registered for the event - assert len(handlers) == 2 - assert handler_one in handlers - assert handler_two in handlers - - -def test_global_action_get_event_handlers_invalid_event(): - """Test that attempting to get handlers for an invalid event raises a KeyError.""" - global_action = GlobalAction() - - with pytest.raises(KeyError): - list(global_action.get_event_handlers("InvalidEvent")) diff --git a/tests/data/dummy_event_listener.py b/tests/data/dummy_event_listener.py new file mode 100644 index 0000000..7bdc3da --- /dev/null +++ b/tests/data/dummy_event_listener.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from streamdeck.event_listener import EventListener +from streamdeck.models.events import EventBase +from typing_extensions import override # noqa: UP035 + + +if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, ClassVar + + +class DummyEvent(EventBase): + """A dummy event for testing purposes.""" + event: Literal["dummy"] # type: ignore[assignment] + something: int + + +class DummyEventListener(EventListener): + """A dummy event listener for testing purposes.""" + event_models: ClassVar[list[type[EventBase]]] = [DummyEvent] + + @override + def listen(self) -> Generator[str | bytes, Any, None]: + """Yields a dummy event.""" + yield '{"event": "dummy", "something": 42}' + + @override + def stop(self) -> None: + """Doesn't do anything.""" diff --git a/tests/data/pyproject.toml b/tests/data/pyproject.toml index 00f9c53..498a7a1 100644 --- a/tests/data/pyproject.toml +++ b/tests/data/pyproject.toml @@ -5,4 +5,7 @@ action_scripts = [ "tests.data.test_action1", "tests.data.test_action2", + ] + "event_listener_modules" = [ + "tests.data.dummy_event_listener", ] \ No newline at end of file diff --git a/tests/event_listener/__init__.py b/tests/event_listener/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/event_listener/test_event_listener.py b/tests/event_listener/test_event_listener.py new file mode 100644 index 0000000..0ba9782 --- /dev/null +++ b/tests/event_listener/test_event_listener.py @@ -0,0 +1,182 @@ +import threading +import time +from collections.abc import Generator +from typing import Any, ClassVar +from unittest.mock import Mock, patch + +import pytest +from streamdeck.event_listener import EventListener, EventListenerManager +from streamdeck.models.events import ApplicationDidLaunch, EventBase + + +class MockEventListener(EventListener): + """Mock implementation of EventListener for testing.""" + event_models: ClassVar[list[type[EventBase]]] = [ApplicationDidLaunch] + + def __init__(self): + self._running = True + self.events = ["event1", "event2", "event3"] + self.listen_called = False + self.stop_called = False + + def listen(self) -> Generator[str, None, None]: + self.listen_called = True + for event in self.events: + if not self._running: + break + yield event + + def stop(self) -> None: + self._running = False + self.stop_called = True + + +class SlowMockEventListener(EventListener): + """Mock implementation of EventListener that yields events with a delay.""" + event_models: ClassVar[list[type[EventBase]]] = [ApplicationDidLaunch] + + def __init__(self, delay: float = 0.1): + self._running = True + self.delay = delay + self.events = ["slow1", "slow2", "slow3"] + + def listen(self) -> Generator[str, None, None]: + for event in self.events: + if not self._running: + break + time.sleep(self.delay) # simulate processing time + yield event + + def stop(self) -> None: + self._running = False + + +class ExceptionEventListener(EventListener): + """Mock implementation of EventListener that raises an exception.""" + event_models: ClassVar[list[type[EventBase]]] = [ApplicationDidLaunch] + + def listen(self) -> Generator[str, None, None]: + self._running = True + # Raise exception after yielding one event + yield "event_before_exception" + raise ValueError("Test exception in listener") + + def stop(self) -> None: + pass + + + + +def test_add_listener(): + """Test adding a listener to the manager.""" + manager = EventListenerManager() + listener = Mock() + + assert len(manager.listeners_lookup_by_thread) == 0 + + manager.add_listener(listener) + + # Check that a thread was created and added to the lookup dict + assert len(manager.listeners_lookup_by_thread) == 1 + thread = next(iter(manager.listeners_lookup_by_thread.keys())) + assert isinstance(thread, threading.Thread) + assert manager.listeners_lookup_by_thread[thread] is listener + + +def test_event_stream_basic(): + """Test that event_stream yields events from added listeners.""" + manager = EventListenerManager() + listener = MockEventListener() + + manager.add_listener(listener) + + # Collect the first few events + events = [] + for event in manager.event_stream(): + events.append(event) + if len(events) >= 3: # We expect 3 events from MockEventListener + manager.stop() + break + + # Check that we got the expected events + assert events == listener.events + + +def test_event_stream_multiple_listeners(): + """Test that events from multiple listeners are interleaved.""" + manager = EventListenerManager() + listener1 = MockEventListener() + listener2 = SlowMockEventListener(delay=0.01) # Small delay to ensure deterministic order + + manager.add_listener(listener1) + manager.add_listener(listener2) + + # Collect all events + events = [] + for event in manager.event_stream(): + events.append(event) + if len(events) >= 6: # We expect 6 events total (3 from each listener) + manager.stop() + break + + # Check that we got all expected events (order may vary) + assert len(events) == 6 + assert set(events) == {"event1", "event2", "event3", "slow1", "slow2", "slow3"} + + +def test_stop(): + """Test that stop() correctly stops all listeners.""" + manager = EventListenerManager() + listener1 = MockEventListener() + listener2 = MockEventListener() + + manager.add_listener(listener1) + manager.add_listener(listener2) + + # Start the event stream but then immediately stop it + event_iter = manager.event_stream() + next(event_iter) # Get the first event to ensure threads start + manager.stop() + + # Check that all listeners were stopped + assert listener1.stop_called + assert listener2.stop_called + assert not manager.running + + +@patch("streamdeck.event_listener.logger") +def test_listener_exception_handling(mock_logger: Mock): + """Test that exceptions in listeners are properly caught and logged.""" + manager = EventListenerManager() + listener = ExceptionEventListener() + + manager.add_listener(listener) + + # Collect events - should get one event before the exception + events: list[Any] = [] + for event in manager.event_stream(): + events.append(event) + if len(events) >= 1: # We expect 1 event before exception + time.sleep(0.1) # Give time for exception to be raised + manager.stop() + break + + assert "event_before_exception" in events + # Check that the exception was logged + mock_logger.exception.assert_called_with("Unexpected error in wrapped listener %s. Stopping just this listener.", listener) + + +def test_listener_wrapper(): + """Test the _listener_wrapper method that runs listeners in threads.""" + manager = EventListenerManager() + listener = MockEventListener() + + # Run the wrapper directly + manager.running = True + manager._listener_wrapper(listener) + + # Check that events were put in the queue + assert manager.event_queue.qsize() == 3 + assert manager.event_queue.get() == "event1" + assert manager.event_queue.get() == "event2" + assert manager.event_queue.get() == "event3" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_event_adapter.py b/tests/models/test_event_adapter.py new file mode 100644 index 0000000..efcecc2 --- /dev/null +++ b/tests/models/test_event_adapter.py @@ -0,0 +1,135 @@ +import json +from typing import Literal + +import pytest +from pydantic import ValidationError +from streamdeck.models.events import ( + DEFAULT_EVENT_MODELS, + DEFAULT_EVENT_NAMES, + EventAdapter, + EventBase, + KeyDown, +) + +from tests.test_utils.fake_event_factories import KeyDownEventFactory + + +def test_init_with_default_models() -> None: + """Test that the EventAdapter initializes with just the default models and their event names.""" + adapter = EventAdapter() + + # Check that all default event names are registered + assert len(adapter._event_names) == len(DEFAULT_EVENT_NAMES) + for event_name in DEFAULT_EVENT_NAMES: + assert event_name in adapter._event_names + + # Check that all default models are registered + assert len(adapter._models) == len(DEFAULT_EVENT_MODELS) + for model in DEFAULT_EVENT_MODELS: + assert model in adapter._models + +def test_add_model() -> None: + """Test that models can be added to the adapter.""" + adapter = EventAdapter() + + # Create a fake event model + class DummyEvent(EventBase): + event: Literal["dummyEvent"] + payload: dict[str, str] + + # Add the custom model + adapter.add_model(DummyEvent) + + # Check that the custom event name is registered + assert "dummyEvent" in adapter._event_names + assert len(adapter._event_names) == len(DEFAULT_EVENT_NAMES) + 1 + + # Check that the custom model is registered + assert DummyEvent in adapter._models + assert len(adapter._models) == len(DEFAULT_EVENT_MODELS) + 1 + +def test_event_name_exists() -> None: + """Test that event_name_exists correctly identifies registered events.""" + adapter = EventAdapter() + + # Should return True for default events + assert adapter.event_name_exists("keyDown") + assert adapter.event_name_exists("willAppear") + + # Should return False for non-registered events + assert not adapter.event_name_exists("nonExistentEvent") + +def test_type_adapter_creation() -> None: + """Test that the type_adapter property creates a TypeAdapter.""" + adapter = EventAdapter() + + # First access should create the type adapter + type_adapter = adapter.type_adapter + assert type_adapter is not None + + # Second access should return the same instance + assert adapter.type_adapter is type_adapter + +def test_validate_json_with_valid_data() -> None: + """Test that validate_json correctly parses valid JSON into the correct event model instance.""" + adapter = EventAdapter() + + # Create valid JSON for a keyDown event + fake_key_down_event: KeyDown = KeyDownEventFactory.build() + fake_key_down_json: str = fake_key_down_event.model_dump_json() + + # "Adapt" the JSON to an event model instance + event = adapter.validate_json(fake_key_down_json) + + # Check that the event is parsed correctly + assert isinstance(event, KeyDown) + assert event == fake_key_down_event + +@pytest.mark.parametrize("invalid_event_json", [ + json.dumps({"event": "keyDown"}), # Missing required fields: action, context, device + json.dumps({"event": "unknownEvent", "action": "com.example.plugin.action", "context": "context123"}) # Unknown event +]) +def test_validate_json_with_invalid_data(invalid_event_json: str) -> None: + """Test that validate_json raises an error for invalid JSON. + + Parameter 1: json object that is missing required fields. + Parameter 2: json object with an unknown event name. + """ + adapter = EventAdapter() + + # Validate the JSON - should raise ValidationError + with pytest.raises(ValidationError): + adapter.validate_json(invalid_event_json) + + +def test_adding_custom_event_allows_validation() -> None: + """Test that adding a custom event allows validating JSON for that event.""" + adapter = EventAdapter() + + # Create a fake event model + class DummyEvent(EventBase): + event: Literal["dummyEvent"] + action: str + context: str + device: str + payload: dict[str, str] + + # Add the custom model + adapter.add_model(DummyEvent) + + # Create valid JSON for the custom event + fake_event = DummyEvent( + event="dummyEvent", + action="com.example.plugin.action", + context="context123", + device="device123", + payload={"key": "value"} + ) + fake_event_json = fake_event.model_dump_json() + + # Validate the JSON + event = adapter.validate_json(fake_event_json) + + # Check that the event is parsed correctly + assert isinstance(event, DummyEvent) + assert event == fake_event diff --git a/tests/plugin_manager/conftest.py b/tests/plugin_manager/conftest.py index ff140a8..02a4194 100644 --- a/tests/plugin_manager/conftest.py +++ b/tests/plugin_manager/conftest.py @@ -1,12 +1,20 @@ from __future__ import annotations -import uuid +from typing import TYPE_CHECKING from unittest.mock import Mock, create_autospec import pytest +from streamdeck.event_listener import EventListenerManager from streamdeck.manager import PluginManager from streamdeck.websocket import WebSocketClient +from tests.test_utils.fake_event_factories import KeyDownEventFactory + + +if TYPE_CHECKING: + import pytest_mock + from streamdeck.models import events + @pytest.fixture def plugin_manager(port_number: int, plugin_registration_uuid: str) -> PluginManager: @@ -39,6 +47,7 @@ def patch_websocket_client(monkeypatch: pytest.MonkeyPatch) -> Mock: monkeypatch: pytest's monkeypatch fixture. Returns: + Mock: Mocked instance of WebSocketClient """ mock_websocket_client: Mock = create_autospec(WebSocketClient, spec_set=True) mock_websocket_client.__enter__.return_value = mock_websocket_client @@ -61,3 +70,43 @@ def mock_command_sender(mocker: pytest_mock.MockerFixture) -> Mock: mock_cmd_sender = Mock() mocker.patch("streamdeck.manager.StreamDeckCommandSender", return_value=mock_cmd_sender) return mock_cmd_sender + + + +@pytest.fixture +def patch_event_listener_manager(monkeypatch: pytest.MonkeyPatch) -> Mock: + """Fixture that uses pytest's MonkeyPatch to mock EventListenerManager for the PluginManager run method. + + The mocked EventListenerManager can be given fake event messages to yield when event_stream() is called: + ```patch_event_listener_manager.event_stream.return_value = [fake_event_json1, fake_event_json2, ...]``` + + Args: + monkeypatch: pytest's monkeypatch fixture. + + Returns: + Mock: Mocked instance of EventListenerManager + """ + mock_event_listener_manager: Mock = create_autospec(EventListenerManager, spec_set=True) + + monkeypatch.setattr("streamdeck.manager.EventListenerManager", (lambda *args, **kwargs: mock_event_listener_manager)) + + return mock_event_listener_manager + + + +@pytest.fixture +def mock_event_listener_manager_with_fake_events(patch_event_listener_manager: Mock) -> tuple[Mock, list[events.EventBase]]: + """Fixture that mocks the EventListenerManager and provides a list of fake event messages yielded by the mock manager. + + Returns: + tuple: Mocked instance of EventListenerManager, and a list of fake event messages. + """ + print("MOCK EVENT LISTENER MANAGER") + # Create a list of fake event messages, and convert them to json strings to be passed back by the client.listen() method. + fake_event_messages: list[events.EventBase] = [ + KeyDownEventFactory.build(action="my-fake-action-uuid"), + ] + + patch_event_listener_manager.event_stream.return_value = [event.model_dump_json() for event in fake_event_messages] + + return patch_event_listener_manager, fake_event_messages diff --git a/tests/plugin_manager/test_command_sender_binding.py b/tests/plugin_manager/test_command_sender_binding.py index f251a08..70f2090 100644 --- a/tests/plugin_manager/test_command_sender_binding.py +++ b/tests/plugin_manager/test_command_sender_binding.py @@ -1,31 +1,27 @@ from __future__ import annotations import inspect -from functools import partial from typing import TYPE_CHECKING, Any, cast from unittest.mock import Mock, create_autospec import pytest -from pprintpp import pprint from streamdeck.actions import Action -from streamdeck.command_sender import StreamDeckCommandSender -from streamdeck.websocket import WebSocketClient - -from tests.test_utils.fake_event_factories import KeyDownEventFactory if TYPE_CHECKING: - from collections.abc import Callable + from functools import partial + from streamdeck.command_sender import StreamDeckCommandSender from streamdeck.manager import PluginManager from streamdeck.models import events - from typing_extensions import TypeAlias # noqa: UP035 - - EventHandlerFunc: TypeAlias = Callable[[events.EventBase], None] | Callable[[events.EventBase, StreamDeckCommandSender], None] + from streamdeck.types import ( + EventHandlerBasicFunc, + EventHandlerFunc, + ) -def create_event_handler(include_command_sender_param: bool = False) -> EventHandlerFunc: +def create_event_handler(include_command_sender_param: bool = False) -> EventHandlerFunc[events.EventBase]: """Create a dummy event handler function that matches the EventHandlerFunc TypeAlias. Args: @@ -35,12 +31,12 @@ def create_event_handler(include_command_sender_param: bool = False) -> EventHan Callable[[events.EventBase], None] | Callable[[events.EventBase, StreamDeckCommandSender], None]: A dummy event handler function. """ if not include_command_sender_param: - def dummy_handler_without_cmd_sender(event: events.EventBase) -> None: + def dummy_handler_without_cmd_sender(event_data: events.EventBase) -> None: """Dummy event handler function that matches the EventHandlerFunc TypeAlias without `command_sender` param.""" return dummy_handler_without_cmd_sender - def dummy_handler_with_cmd_sender(event: events.EventBase, command_sender: StreamDeckCommandSender) -> None: + def dummy_handler_with_cmd_sender(event_data: events.EventBase, command_sender: StreamDeckCommandSender) -> None: """Dummy event handler function that matches the EventHandlerFunc TypeAlias with `command_sender` param.""" return dummy_handler_with_cmd_sender @@ -49,39 +45,19 @@ def dummy_handler_with_cmd_sender(event: events.EventBase, command_sender: Strea @pytest.fixture(params=[True, False]) def mock_event_handler(request: pytest.FixtureRequest) -> Mock: include_command_sender_param: bool = request.param - dummy_handler: EventHandlerFunc = create_event_handler(include_command_sender_param) + dummy_handler: EventHandlerFunc[events.EventBase] = create_event_handler(include_command_sender_param) return create_autospec(dummy_handler, spec_set=True) -@pytest.fixture -def mock_websocket_client_with_fake_events(patch_websocket_client: Mock) -> tuple[Mock, list[events.EventBase]]: - """Fixture that mocks the WebSocketClient and provides a list of fake event messages yielded by the mock client. - - Args: - patch_websocket_client: Mocked instance of the patched WebSocketClient. - - Returns: - tuple: Mocked instance of WebSocketClient, and a list of fake event messages. - """ - # Create a list of fake event messages, and convert them to json strings to be passed back by the client.listen() method. - fake_event_messages: list[events.EventBase] = [ - KeyDownEventFactory.build(action="my-fake-action-uuid"), - ] - - patch_websocket_client.listen.return_value = [event.model_dump_json() for event in fake_event_messages] - - return patch_websocket_client, fake_event_messages - - def test_inject_command_sender_func( plugin_manager: PluginManager, mock_event_handler: Mock, -): +) -> None: """Test that the command_sender is injected into the handler.""" mock_command_sender = Mock() - result_handler = plugin_manager._inject_command_sender(mock_event_handler, mock_command_sender) + result_handler: EventHandlerBasicFunc[events.EventBase] = plugin_manager._inject_command_sender(mock_event_handler, mock_command_sender) resulting_handler_params = inspect.signature(result_handler).parameters @@ -92,7 +68,7 @@ def test_inject_command_sender_func( assert result_handler != mock_event_handler # Check that the `command_sender` parameter is bound to the correct value. - resulting_handler_bound_kwargs: dict[str, Any] = cast(partial[Any], result_handler).keywords + resulting_handler_bound_kwargs: dict[str, Any] = cast("partial[Any]", result_handler).keywords assert resulting_handler_bound_kwargs["command_sender"] == mock_command_sender # If there isn't a `command_sender` parameter, then the `result_handler` is the original handler unaltered. @@ -100,27 +76,30 @@ def test_inject_command_sender_func( assert result_handler == mock_event_handler +@pytest.mark.usefixtures("patch_websocket_client") def test_run_manager_events_handled_with_correct_params( - mock_websocket_client_with_fake_events: tuple[Mock, list[events.EventBase]], + mock_event_listener_manager_with_fake_events: tuple[Mock, list[events.EventBase]], plugin_manager: PluginManager, mock_command_sender: Mock, -): +) -> None: """Test that the PluginManager runs and triggers event handlers with the correct parameter binding. This test will: - Register an action with the PluginManager. - Create and register mock event handlers with and without the `command_sender` parameter. - - Run the PluginManager and let it process the fake event messages generated by the mocked WebSocketClient. + - Run the PluginManager and let it process the fake event messages generated by the mocked EventListenerManager. - Ensure that mocked event handlers were called with the correct params, binding the `command_sender` parameter if defined in the handler's signature. + NOTE: The WebSocketClient is mocked so as to be essentially ignored in this test. + Args: - mock_websocket_client_with_fake_events (tuple[Mock, list[events.EventBase]]): Mocked instance of WebSocketClient, and a list of fake event messages it will yield. + mock_event_listener_manager_with_fake_events (tuple[Mock, list[events.EventBase]]): Mocked instance of EventListenerManager, and a list of fake event messages it will yield. plugin_manager (PluginManager): Instance of PluginManager with test parameters. mock_command_sender (Mock): Patched instance of StreamDeckCommandSender. Used here to ensure that the `command_sender` parameter is bound correctly. """ # As of now, fake_event_messages is a list of one KeyDown event. If this changes, I'll need to update this test. - fake_event_message: events.KeyDown = mock_websocket_client_with_fake_events[1][0] + fake_event_message: events.KeyDown = mock_event_listener_manager_with_fake_events[1][0] action = Action(fake_event_message.action) @@ -134,7 +113,8 @@ def test_run_manager_events_handled_with_correct_params( plugin_manager.register_action(action) - # Run the PluginManager and let it process the fake event messages generated by the mocked WebSocketClient. + # Run the PluginManager and let it process the fake event messages generated by the mocked EventListenerManager. + # Since the EventListenerManager.event_stream() method is mocked to return a finite list of fake event messages, it will stop after yielding all of them rather than running indefinitely. plugin_manager.run() # Ensure that mocked event handlers were called with the correct params, binding the `command_sender` parameter if defined in the handler's signature. diff --git a/tests/plugin_manager/test_plugin_manager.py b/tests/plugin_manager/test_plugin_manager.py index edd1444..4900482 100644 --- a/tests/plugin_manager/test_plugin_manager.py +++ b/tests/plugin_manager/test_plugin_manager.py @@ -4,28 +4,12 @@ import pytest import pytest_mock from streamdeck.actions import Action -from streamdeck.manager import PluginManager -from streamdeck.models.events import DialRotate, EventBase, event_adapter - -from tests.test_utils.fake_event_factories import DialRotateEventFactory - - -@pytest.fixture -def mock_websocket_client_with_event(patch_websocket_client: Mock) -> tuple[Mock, EventBase]: - """Fixture that mocks the WebSocketClient and provides a fake DialRotateEvent message. - - Args: - patch_websocket_client: Mocked instance of the patched WebSocketClient. - - Returns: - tuple: Mocked instance of WebSocketClient, and a fake DialRotateEvent. - """ - # Create a fake event message, and convert it to a json string to be passed back by the client.listen() method. - fake_event_message: DialRotate = DialRotateEventFactory.build() - patch_websocket_client.listen.return_value = [fake_event_message.model_dump_json()] - - return patch_websocket_client, fake_event_message - +from streamdeck.manager import EventAdapter, PluginManager +from streamdeck.models.events import ( #, event_adapter + DEFAULT_EVENT_MODELS, + DEFAULT_EVENT_NAMES, + EventBase, +) @pytest.fixture @@ -41,7 +25,7 @@ def _spy_action_registry_get_action_handlers( Returns: None """ - mocker.spy(plugin_manager._registry, "get_action_handlers") + mocker.spy(plugin_manager._action_registry, "get_action_handlers") @pytest.fixture @@ -54,55 +38,108 @@ def _spy_event_adapter_validate_json(mocker: pytest_mock.MockerFixture) -> None: Returns: None """ - mocker.spy(event_adapter, "validate_json") + mocker.spy(EventAdapter, "validate_json") -def test_plugin_manager_register_action(plugin_manager: PluginManager): +def test_plugin_manager_register_action(plugin_manager: PluginManager) -> None: """Test that an action can be registered in the PluginManager.""" - assert len(plugin_manager._registry._plugin_actions) == 0 + assert len(plugin_manager._action_registry._plugin_actions) == 0 action = Action("my-fake-action-uuid") plugin_manager.register_action(action) - assert len(plugin_manager._registry._plugin_actions) == 1 - assert plugin_manager._registry._plugin_actions[0] == action + assert len(plugin_manager._action_registry._plugin_actions) == 1 + assert plugin_manager._action_registry._plugin_actions[0] == action + + +def test_plugin_manager_register_event_listener(plugin_manager: PluginManager) -> None: + """Test that an event listener can be registered in the PluginManager.""" + mock_event_model_class = Mock(get_model_event_name=lambda: ["fake_event_name"]) + listener = Mock(event_models=[mock_event_model_class]) + + assert len(plugin_manager._event_listener_manager.listeners_lookup_by_thread) == 0 + + plugin_manager.register_event_listener(listener) + + # Validate that the PluginManager's EventListenerManager has the listener properly registered. + assert len(plugin_manager._event_listener_manager.listeners_lookup_by_thread) == 1 + assert next(iter(plugin_manager._event_listener_manager.listeners_lookup_by_thread.values())) == listener + # Validate that the PluginManager's EventAdapter has the event model class properly registered. + assert len(plugin_manager._event_adapter._models) == len(DEFAULT_EVENT_MODELS) + 1 + assert mock_event_model_class in plugin_manager._event_adapter._models + # Also validate that the event name is in the set of registered event names. + assert len(plugin_manager._event_adapter._event_names) == len(DEFAULT_EVENT_NAMES) + 1 + assert "fake_event_name" in plugin_manager._event_adapter._event_names -@pytest.mark.usefixtures("mock_websocket_client_with_event") + +@pytest.mark.usefixtures("patch_websocket_client", "patch_event_listener_manager") def test_plugin_manager_sends_registration_event( - mock_command_sender: Mock, plugin_manager: PluginManager -): - """Test that StreamDeckCommandSender.send_action_registration() method is called with correct arguments within the PluginManager.run() method.""" + mock_command_sender: Mock, + plugin_manager: PluginManager, +) -> None: + """Test that StreamDeckCommandSender.send_action_registration() method is called with correct arguments within the PluginManager.run() method. + + When the PluginManager.run() method is called, it should register the plugin with the StreamDeck software by sending an action registration event via the StreamDeckCommandSender instance. + + NOTE: The WebSocketClient and EventListenerManager are mocked so as to be essentially ignored in this test. + """ plugin_manager.run() mock_command_sender.send_action_registration.assert_called_once_with( register_event=plugin_manager._register_event, - plugin_registration_uuid=plugin_manager._registration_uuid, ) -@pytest.mark.usefixtures("_spy_action_registry_get_action_handlers") -@pytest.mark.usefixtures("_spy_event_adapter_validate_json") +def test_plugin_manager_adds_websocket_event_listener( + patch_event_listener_manager: Mock, + patch_websocket_client: Mock, + plugin_manager: PluginManager, # This fixture must come after patch_event_listener_manager to ensure monkeypatching occurs. +) -> None: + """Test that the PluginManager adds the WebSocketClient as an event listener. + + The PluginManager.run() method should add the WebSocketClient as an event listener to the EventListenerManager. + + Args: + patch_event_listener_manager (Mock): Mocked instance of EventListenerManager. + Patched by the fixture, and used here to check if this instance's .add_listener() method was called with the appropriate arguments. + patch_websocket_client (Mock): Mocked instance of WebSocketClient. + Patched by the fixture, and used here to check if this instance was passed to the EventListenerManager.add_listener() method as an argument. + plugin_manager (PluginManager): PluginManager fixture + """ + plugin_manager.run() + + patch_event_listener_manager.add_listener.assert_called_once_with(patch_websocket_client) + + +@pytest.mark.integration +@pytest.mark.usefixtures("patch_websocket_client") def test_plugin_manager_process_event( - mock_websocket_client_with_event: tuple[Mock, EventBase], plugin_manager: PluginManager -): + mock_event_listener_manager_with_fake_events: tuple[Mock, list[EventBase]], + _spy_action_registry_get_action_handlers: None, # This fixture must come after mock_event_listener_manager_with_fake_events to ensure monkeypatching occurs. + _spy_event_adapter_validate_json: None, # This fixture must come after mock_event_listener_manager_with_fake_events to ensure monkeypatching occurs. + plugin_manager: PluginManager, # This fixture must come after patch_event_listener_manager and spy-fixtures to ensure things are mocked and spied correctly. +) -> None: """Test that PluginManager processes events correctly, calling event_adapter.validate_json and action_registry.get_action_handlers.""" - mock_websocket_client, fake_event_message = mock_websocket_client_with_event + mock_event_listener_mgr, fake_event_messages = mock_event_listener_manager_with_fake_events + fake_event_message = fake_event_messages[0] plugin_manager.run() - # First check that the WebSocketClient's listen() method was called. + # First check that the EventListenerManager's event_stream() method was called. # This has been stubbed to return the fake_event_message's json string. - mock_websocket_client.listen.assert_called_once() + mock_event_listener_mgr.event_stream.assert_called_once() # Check that the event_adapter.validate_json method was called with the stub json string returned by listen(). - spied_event_adapter_validate_json = cast(Mock, event_adapter.validate_json) - spied_event_adapter_validate_json.assert_called_once_with(fake_event_message.model_dump_json()) + spied_event_adapter__validate_json = cast("Mock", EventAdapter.validate_json) + # Since this is an instance method, the first argument is the instance itself. + spied_event_adapter__validate_json.assert_called_once_with(plugin_manager._event_adapter, fake_event_message.model_dump_json()) # Check that the validate_json method returns the same event type model as the fake_event_message. - assert spied_event_adapter_validate_json.spy_return == fake_event_message + assert spied_event_adapter__validate_json.spy_return == fake_event_message # Check that the action_registry.get_action_handlers method was called with the event name and action uuid. - cast(Mock, plugin_manager._registry.get_action_handlers).assert_called_once_with( + spied_action_registry__get_action_handlers = cast("Mock", plugin_manager._action_registry.get_action_handlers) + spied_action_registry__get_action_handlers.assert_called_once_with( event_name=fake_event_message.event, event_action_uuid=fake_event_message.action ) diff --git a/tests/test_logging.py b/tests/test_logging.py index 23f6094..d9b3d3a 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -49,7 +49,7 @@ def fake_streamdeck_log_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> return tmp_path -def test_local_logger(fake_plugin_local_log_dir: Path): +def test_local_logger(fake_plugin_local_log_dir: Path) -> None: """Test the configuration of a local plugin-specific logger. This test verifies that a logger configured for a specific plugin writes logs to the correct local directory. @@ -86,7 +86,7 @@ def test_local_logger(fake_plugin_local_log_dir: Path): # assert fake_name in actual_log_file_output -def test_streamdeck_logger(fake_streamdeck_log_dir: Path): +def test_streamdeck_logger(fake_streamdeck_log_dir: Path) -> None: """Test the configuration of a centralized Stream Deck logger. This test verifies that a logger configured for the Stream Deck writes logs to the correct centralized directory. diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 42befd3..bf61233 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,12 +1,20 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from functools import partial +from typing import TYPE_CHECKING, Any from unittest.mock import Mock import pytest +from streamdeck.event_listener import StopStreaming from streamdeck.websocket import WebSocketClient -from websockets import ConnectionClosedOK, WebSocketException +from websockets import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidHeader, + WebSocketException, +) if TYPE_CHECKING: @@ -18,13 +26,16 @@ def mock_connection() -> Mock: """Fixture to mock the ClientConnection object returned by websockets.sync.client.connect.""" return Mock() + @pytest.fixture def patched_connect(mocker: MockerFixture, mock_connection: Mock) -> Mock: """Fixture to mock the ClientConnection object returned by websockets.sync.client.connect.""" return mocker.patch("streamdeck.websocket.connect", return_value=mock_connection) -def test_initialization_calls_connect_correctly(patched_connect: Mock, mock_connection: Mock, port_number: int): +def test_initialization_calls_connect_correctly( + patched_connect: Mock, mock_connection: Mock, port_number: int +) -> None: """Test that WebSocketClient initializes correctly by calling the connect function with the appropriate URI.""" with WebSocketClient(port=port_number) as client: # Assert that 'connect' was called once with the correct URI. @@ -35,7 +46,7 @@ def test_initialization_calls_connect_correctly(patched_connect: Mock, mock_conn @pytest.mark.usefixtures("patched_connect") -def test_send_event_serializes_and_sends(mock_connection: Mock, port_number: int): +def test_send_event_serializes_and_sends(mock_connection: Mock, port_number: int) -> None: """Test that the send_event method corrently serializes the data to JSON and sends it via the websocket connection.""" with WebSocketClient(port=port_number) as client: fake_data = {"event": "test_event", "payload": {"key": "value"}} @@ -47,12 +58,41 @@ def test_send_event_serializes_and_sends(mock_connection: Mock, port_number: int @pytest.mark.usefixtures("patched_connect") -def test_listen_yields_messages(mock_connection: Mock, port_number: int): +def test_listen_yields_messages(mock_connection: Mock, port_number: int) -> None: """Test that listen yields messages from the WebSocket connection.""" # Set up the mocked connection to return messages until closing - mock_connection.recv.side_effect = ["message1", b"message2", WebSocketException()] + expected_results = ["message1", b"message2", "message3"] + mock_connection.recv.side_effect = expected_results with WebSocketClient(port=port_number) as client: - messages = list(client.listen()) + actual_messages: list[Any] = [] + for i, msg in enumerate(client.listen()): + actual_messages.append(msg) + if i == 2: + break + + assert actual_messages == expected_results + + +@pytest.mark.parametrize( + "exception_class", + [ + partial(ConnectionClosedOK, None, None), + partial(ConnectionClosedError, None, None), + partial(InvalidHeader, "header-name", None), + partial(ConnectionClosed, None, None), + WebSocketException, + ], +) +@pytest.mark.usefixtures("patched_connect") +def test_listen_raises_StopStreaming_from_WebSocketException( + mock_connection: Mock, port_number: int, exception_class: type[WebSocketException] +) -> None: + """Test that listen raises a StopStreaming exception when a WebSocketException is raised.""" + # Set up the mocked connection to return messages until closing + mock_connection.recv.side_effect = ["message1", b"message2", exception_class()] - assert messages == ["message1", b"message2"] + # This should raise a StopStreaming exception when any WebSocketException is raised + with WebSocketClient(port=port_number) as client, pytest.raises(StopStreaming): + for _ in client.listen(): + pass