diff --git a/.devcontainer/config.json b/.devcontainer/config.json new file mode 100644 index 0000000..3943043 --- /dev/null +++ b/.devcontainer/config.json @@ -0,0 +1,17 @@ +{ + "result_backend": { + "redis": { + "url": "redis://redis:6379/0" + } + }, + "state_backend": { + "redis": { + "url": "redis://redis:6379/0" + } + }, + "broker": { + "amqp": { + "url": "amqp://rabbitmq:5672" + } + } +} \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 35c9f1b..30e7a7d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,7 +3,8 @@ { "name": "Python 3", "dockerComposeFile": [ - "docker-compose.yaml" + "docker-compose.yaml", + "docker-compose-devcontainer.yaml" ], "workspaceFolder": "/workspaces/project-mognet", "service": "mognet", diff --git a/.devcontainer/docker-compose-devcontainer.yaml b/.devcontainer/docker-compose-devcontainer.yaml new file mode 100644 index 0000000..ad09b8e --- /dev/null +++ b/.devcontainer/docker-compose-devcontainer.yaml @@ -0,0 +1,35 @@ +version: "2.4" + +services: + mognet: + build: + context: .. + dockerfile: .devcontainer/Dockerfile + args: + VARIANT: "3.8" + INSTALL_NODE: "false" + + mem_limit: 4g + cpus: 2 + working_dir: /workspaces/project-mognet + volumes: + - ..:/workspaces/project-mognet:cached + - vscode_server_data:/home/vscode/.vscode-server + - cache:/home/vscode/.cache + + environment: + - MOGNET_CONFIG_FILE=/workspaces/project-mognet/.devcontainer/config.json + + command: + - bash + - -c + - | + set -e + sudo chown -R vscode:vscode /home/vscode/.vscode-server + sudo chown -R vscode:vscode /home/vscode/.cache + exec sleep infinity + +volumes: + results_backend_data: {} + vscode_server_data: {} + cache: {} diff --git a/.devcontainer/docker-compose.yaml b/.devcontainer/docker-compose.yaml index b2528cb..1853720 100644 --- a/.devcontainer/docker-compose.yaml +++ b/.devcontainer/docker-compose.yaml @@ -1,31 +1,6 @@ version: "2.4" services: - mognet: - build: - context: .. - dockerfile: .devcontainer/Dockerfile - args: - VARIANT: "3.8" - INSTALL_NODE: "false" - - mem_limit: 4g - cpus: 2 - working_dir: /workspaces/project-mognet - volumes: - - ..:/workspaces/project-mognet:cached - - vscode_server_data:/home/vscode/.vscode-server - - cache:/home/vscode/.cache - - command: - - bash - - -c - - | - set -e - sudo chown -R vscode:vscode /home/vscode/.vscode-server - sudo chown -R vscode:vscode /home/vscode/.cache - exec sleep infinity - redis: image: redis mem_limit: 128m @@ -33,24 +8,23 @@ services: cpus: 0.25 volumes: - results_backend_data:/data - networks: - default: - aliases: - - cps-results-backend + ports: + - 6379:6379 rabbitmq: image: rabbitmq:management - mem_limit: 512m + mem_limit: 1024m mem_reservation: 128m - cpus: 0.5 + cpus: 1 # Prevent huge memory usage by limiting the number # of file descriptors ulimits: nofile: soft: 8192 hard: 8192 + ports: + - 5672:5672 + - 15672:15672 volumes: results_backend_data: {} - vscode_server_data: {} - cache: {} diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..620d3e9 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +ignore = E501,W503 +exclude = .git,__pycache__,docs/,build,dist,demo +# max-complexity = 10 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..2a7c577 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,3 @@ +# Apply black and isort to project +# See https://black.readthedocs.io/en/stable/guides/introducing_black_to_your_project.html +bc7b568fe0f925558718a71fdd42b5289a20cab8 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 981445a..9920c95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,17 +12,51 @@ env: POETRY_VIRTUALENVS_CREATE: false jobs: - test-build-docs: + test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v3 + + - name: Run docker-compose stack for testing + run: make docker-up + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + # python-version: "^3.8" + cache: "pip" + cache-dependency-path: "poetry.lock" + - name: Install poetry run: pip install poetry - - uses: actions/setup-python@v3 + + - name: Install dependencies + run: poetry install + + - name: Lint + run: make lint + + - name: Test + # Wait for RabbitMQ to start. + run: sleep 10 && make test + + test-build-docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: - python-version: 3.x - cache: 'pip' - cache-dependency-path: 'poetry.lock' + python-version: "3.10" + cache: "pip" + cache-dependency-path: "poetry.lock" + + - name: Install poetry + run: pip install poetry + - name: Install dependencies run: poetry install - - run: mkdocs build --verbose --clean --strict + + - run: poetry run mkdocs build --verbose --clean --strict diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7d890ef..27497e5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,7 +8,6 @@ on: env: POETRY_VIRTUALENVS_CREATE: false - jobs: deploy: permissions: @@ -16,14 +15,17 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Install poetry - run: pip install poetry - uses: actions/setup-python@v3 with: - python-version: 3.x - cache: 'pip' - cache-dependency-path: 'poetry.lock' + python-version: "3.10" + cache: "pip" + cache-dependency-path: "poetry.lock" + + - name: Install poetry + run: pip install poetry + - name: Install dependencies run: poetry install + - name: Build and push docs - run: mkdocs gh-deploy --force + run: poetry run mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index 488605e..91511a0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,windows,macos,virtualenv,linux,vim,emacs # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,windows,macos,virtualenv,linux,vim,emacs +/config.json + ### Emacs ### # -*- mode: gitignore; -*- *~ diff --git a/.pylintrc b/.pylintrc index fb37b99..62a3d22 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,7 +6,7 @@ extension-pkg-whitelist= # Specify a score threshold to be exceeded before program exits with error. -fail-under=10.0 +fail-under=8.0 # Add files or directories to the blacklist. They should be base names, not # paths. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5c4dfa8 --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ + +DOCKER_COMPOSE_FILE=.devcontainer/docker-compose.yaml + +.PHONY: format-black +format-black: + @poetry run black . + +.PHONY: format-isort +format-isort: + @poetry run isort . + +.PHONY: format +format: format-isort format-black + +.PHONY: lint-isort +lint-isort: + @poetry run isort --check . + +.PHONY: lint-black +lint-black: + @poetry run black --check . + +lint-pylint: + @poetry run pylint mognet + +lint-flake8: + @poetry run flake8 + +lint-mypy: + @poetry run mypy mognet + +.PHONY: lint +lint: lint-isort lint-black lint-flake8 lint-pylint lint-mypy + +.PHONY: docker-up +docker-up: + @docker-compose \ + -f $(DOCKER_COMPOSE_FILE) \ + up -d + +.PHONY: docker-down +docker-down: + @docker-compose \ + -f $(DOCKER_COMPOSE_FILE) \ + down + +.PHONY: test +test: + @poetry run pytest test/ diff --git a/README.md b/README.md index ce30bea..99a3d83 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ Mognet is a fast, simple framework to build distributed applications using task queues. +[![PyPI](https://img.shields.io/pypi/v/mognet)](https://www.pypi.org/project/mognet) +![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mognet) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + ## Installing Mognet can be installed via pip, with: diff --git a/demo/mognet_demo/api.py b/demo/mognet_demo/api.py index 568d8fe..8b0f1ce 100644 --- a/demo/mognet_demo/api.py +++ b/demo/mognet_demo/api.py @@ -3,8 +3,8 @@ """ import asyncio -from contextlib import suppress import os +from contextlib import suppress from typing import Optional from uuid import UUID, uuid4 @@ -18,13 +18,12 @@ Response, UploadFile, ) -from pydantic import confloat - from mognet_demo.config import DemoConfig from mognet_demo.models import Job, Upload, UploadJobResult from mognet_demo.mognet_app import app as mognet_app from mognet_demo.s3 import get_s3_client from mognet_demo.tasks import process_document_upload +from pydantic import confloat app = FastAPI( title="Mognet Demo API", diff --git a/demo/mognet_demo/config.py b/demo/mognet_demo/config.py index e3ff9d5..4bb75c0 100644 --- a/demo/mognet_demo/config.py +++ b/demo/mognet_demo/config.py @@ -11,10 +11,10 @@ from pydantic import BaseModel, BaseSettings from mognet import AppConfig -from mognet.app.app_config import ResultBackendConfig, StateBackendConfig, BrokerConfig +from mognet.app.app_config import BrokerConfig, ResultBackendConfig, StateBackendConfig from mognet.backend.backend_config import RedisResultBackendSettings -from mognet.state.state_backend_config import RedisStateBackendSettings from mognet.broker.broker_config import AmqpBrokerSettings +from mognet.state.state_backend_config import RedisStateBackendSettings class S3Config(BaseModel): diff --git a/demo/mognet_demo/middleware/auto_shutdown.py b/demo/mognet_demo/middleware/auto_shutdown.py index 552936e..4896e4b 100644 --- a/demo/mognet_demo/middleware/auto_shutdown.py +++ b/demo/mognet_demo/middleware/auto_shutdown.py @@ -16,11 +16,12 @@ import asyncio import logging from typing import TYPE_CHECKING, NoReturn, Optional -from mognet.middleware.middleware import Middleware + from mognet.cli.exceptions import GracefulShutdown +from mognet.middleware.middleware import Middleware if TYPE_CHECKING: - from mognet import Result, Context, App + from mognet import App, Context, Result _log = logging.getLogger(__name__) diff --git a/demo/mognet_demo/models.py b/demo/mognet_demo/models.py index ccb5855..e00b8e7 100644 --- a/demo/mognet_demo/models.py +++ b/demo/mognet_demo/models.py @@ -3,8 +3,8 @@ """ from typing import List, Optional from uuid import UUID -from pydantic import BaseModel +from pydantic import BaseModel from mognet.model.result_state import ResultState diff --git a/demo/mognet_demo/mognet_app.py b/demo/mognet_demo/mognet_app.py index bb08bac..0aa3d28 100644 --- a/demo/mognet_demo/mognet_app.py +++ b/demo/mognet_demo/mognet_app.py @@ -4,10 +4,9 @@ It can be used both for submitting jobs and for launching the worker process via the CLI """ -from mognet import App - from mognet_demo.config import DemoConfig -from mognet_demo.middleware.auto_shutdown import AutoShutdownMiddleware + +from mognet import App _mognet_config = DemoConfig.instance().mognet diff --git a/demo/mognet_demo/s3.py b/demo/mognet_demo/s3.py index 3bd91fd..2b307dd 100644 --- a/demo/mognet_demo/s3.py +++ b/demo/mognet_demo/s3.py @@ -5,10 +5,8 @@ from contextlib import asynccontextmanager from aiobotocore.session import get_session - -from mognet_demo.config import DemoConfig - from botocore.exceptions import ClientError +from mognet_demo.config import DemoConfig @asynccontextmanager diff --git a/demo/mognet_demo/tasks.py b/demo/mognet_demo/tasks.py index 73a1736..5ebf12d 100644 --- a/demo/mognet_demo/tasks.py +++ b/demo/mognet_demo/tasks.py @@ -7,15 +7,15 @@ import shutil import tempfile from pathlib import Path -from typing import Set, List +from typing import List, Set from uuid import uuid4 -from mognet import Context, Request, task - from mognet_demo.config import DemoConfig from mognet_demo.models import Document, Upload, UploadResult from mognet_demo.s3 import get_s3_client +from mognet import Context, Request, task + _log = logging.getLogger(__name__) diff --git a/demo/test/conftest.py b/demo/test/conftest.py index 2ccc950..fd28ed7 100644 --- a/demo/test/conftest.py +++ b/demo/test/conftest.py @@ -1,11 +1,9 @@ import pytest - - -from mognet.testing.pytest_integration import create_app_fixture from mognet_demo.config import DemoConfig - from mognet_demo.mognet_app import app +from mognet.testing.pytest_integration import create_app_fixture + @pytest.fixture def config(): diff --git a/demo/test/test_api.py b/demo/test/test_api.py index 45c670e..5cf3139 100644 --- a/demo/test/test_api.py +++ b/demo/test/test_api.py @@ -1,9 +1,10 @@ import time -import requests from pathlib import Path + +import requests from mognet_demo.models import Job, UploadJobResult -from mognet.model.result_state import READY_STATES, SUCCESS_STATES +from mognet.model.result_state import READY_STATES, SUCCESS_STATES _cwd = Path(__file__).parent diff --git a/demo/test/test_worker_task.py b/demo/test/test_worker_task.py index 57137df..36e4dc5 100644 --- a/demo/test/test_worker_task.py +++ b/demo/test/test_worker_task.py @@ -9,12 +9,12 @@ from uuid import uuid4 import pytest -from mognet import App, Request from mognet_demo.config import DemoConfig from mognet_demo.models import Upload, UploadResult from mognet_demo.s3 import get_s3_client from mognet_demo.tasks import InvalidFile +from mognet import App, Request _cwd = Path(__file__).parent diff --git a/mognet/__init__.py b/mognet/__init__.py index 5d414d3..597466c 100644 --- a/mognet/__init__.py +++ b/mognet/__init__.py @@ -1,11 +1,11 @@ from .app.app import App from .app.app_config import AppConfig from .context.context import Context -from .model.result import Result -from .primitives.request import Request from .decorators.task_decorator import task -from .model.result_state import ResultState from .middleware.middleware import Middleware +from .model.result import Result +from .model.result_state import ResultState +from .primitives.request import Request __all__ = [ "App", diff --git a/mognet/app/app.py b/mognet/app/app.py index 99fc071..6731583 100644 --- a/mognet/app/app.py +++ b/mognet/app/app.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import importlib import logging @@ -15,6 +17,7 @@ List, Optional, Set, + Type, TypeVar, Union, cast, @@ -22,7 +25,7 @@ ) from aio_pika.tools import shield -from mognet.backend.base_result_backend import BaseResultBackend + from mognet.backend.redis_result_backend import RedisResultBackend from mognet.broker.amqp_broker import AmqpBroker from mognet.broker.base_broker import ( @@ -34,7 +37,8 @@ from mognet.context.context import Context from mognet.exceptions.base_exceptions import CouldNotSubmit, ImproperlyConfigured from mognet.middleware.middleware import Middleware -from mognet.model.result import Result, ResultState +from mognet.model.result import Result +from mognet.model.result_state import ResultState from mognet.primitives.queries import QueryRequestMessage, StatusResponseMessage from mognet.primitives.request import Request from mognet.primitives.revoke import Revoke @@ -46,19 +50,23 @@ from mognet.worker.worker import MessageCancellationAction, Worker if sys.version_info >= (3, 10): - from typing import ParamSpec, Concatenate + from typing import Concatenate, ParamSpec else: - from typing_extensions import ParamSpec, Concatenate + from typing_extensions import Concatenate, ParamSpec if TYPE_CHECKING: from mognet.app.app_config import AppConfig + from mognet.backend.base_result_backend import BaseResultBackend + _log = logging.getLogger(__name__) _P = ParamSpec("_P") _Return = TypeVar("_Return") +_AnySvc = Union[Type[ClassService[Any]], Callable[..., Any]] + class App: """ @@ -85,7 +93,7 @@ class App: # Mapping of [service name] -> dependency object, # should be accessed via Context#get_service. - services: Dict[Any, Callable] + services: Dict[_AnySvc, _AnySvc] # Holds references to all the tasks. task_registry: TaskRegistry @@ -99,10 +107,10 @@ class App: worker: Optional[Worker] # Background tasks spawned by this app. - _consume_control_task: Optional[Future] = None - _heartbeat_task: Optional[Future] = None + _consume_control_task: Optional["Future[None]"] = None + _heartbeat_task: Optional["Future[None]"] = None - _worker_task: Optional[Future] + _worker_task: Optional["Future[None]"] _middleware: List[Middleware] @@ -139,9 +147,9 @@ def __init__( self.worker = None # Event that gets set when the app is closed - self._run_result = None + self._run_result: Optional["asyncio.Future[None]"] = None - def add_middleware(self, mw_inst: Middleware): + def add_middleware(self, mw_inst: Middleware) -> None: """ Adds middleware to this app. @@ -153,7 +161,7 @@ def add_middleware(self, mw_inst: Middleware): self._middleware.append(mw_inst) - async def start(self): + async def start(self) -> None: """ Starts the app. """ @@ -211,7 +219,9 @@ async def get_current_status_of_nodes( finally: await responses.aclose() - async def submit(self, req: "Request", context: Optional[Context] = None) -> Result: + async def submit( + self, req: "Request[_Return]", context: Optional[Context] = None + ) -> Result[_Return]: """ Submits a request for execution. @@ -230,7 +240,7 @@ async def submit(self, req: "Request", context: Optional[Context] = None) -> Res req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs) _log.debug("Set default kwargs_repr on Request %r", req) - res = Result( + res: Result[_Return] = Result( self.result_backend, id=req.id, name=req.name, @@ -332,7 +342,7 @@ def create_request( func: Callable[Concatenate["Context", _P], Any], *args: _P.args, **kwargs: _P.kwargs, - ) -> Request: + ) -> Request[Any]: """ Creates a Request object from the function that was decorated with @task, and the provided arguments. @@ -369,7 +379,7 @@ async def run( ... - async def run(self, request, *args, **kwargs) -> Any: + async def run(self, request: Any, *args: Any, **kwargs: Any) -> Any: if not isinstance(request, Request): request = self.create_request(*args, **kwargs) @@ -378,7 +388,9 @@ async def run(self, request, *args, **kwargs) -> Any: return await res - async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result: + async def revoke( + self, request_id: uuid.UUID, *, force: bool = False + ) -> Result[Any]: """ Revoke the execution of a request. @@ -422,7 +434,7 @@ async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result: return res - async def connect(self): + async def connect(self) -> None: """Connect this app and its components to their respective backends.""" if self._connected: return @@ -448,16 +460,16 @@ async def connect(self): _log.debug("Connected to state backend %s", self.state_backend) - async def __aenter__(self): + async def __aenter__(self): # type: ignore await self.connect() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await self.close() - @shield - async def close(self): + @shield # type: ignore + async def close(self) -> None: """Close this app and its components's backends.""" _log.info("Closing app") @@ -469,7 +481,7 @@ async def close(self): _log.info("Closed app") - async def _stop(self): + async def _stop(self) -> None: await self._call_on_stopping_middleware() if self._heartbeat_task is not None: @@ -570,11 +582,11 @@ def _create_result_backend(self) -> BaseResultBackend: def _create_state_backend(self) -> BaseStateBackend: return RedisStateBackend(self.config.state_backend, app=self) - def _load_modules(self): + def _load_modules(self) -> None: for module in self.config.imports: importlib.import_module(module) - def _log_tasks_and_queues(self): + def _log_tasks_and_queues(self) -> None: all_tasks = self.task_registry.registered_task_names @@ -590,7 +602,7 @@ def _log_tasks_and_queues(self): _log.info("Registered %r queues:\n%s", len(all_queues), queues_msg) - async def _setup_broker(self): + async def _setup_broker(self) -> None: _log.debug("Connecting to broker %s", self.broker) await self.broker.connect() @@ -604,7 +616,7 @@ async def _setup_broker(self): _log.debug("Setup queues") - async def _stop_worker(self): + async def _stop_worker(self) -> None: if self.worker is None or not self._worker_task: _log.debug("No worker running") return @@ -628,7 +640,7 @@ async def _stop_worker(self): self.worker = None self._worker_task = None - async def _background_heartbeat(self): + async def _background_heartbeat(self) -> None: """ Background task that checks if the event loop was blocked for too long. @@ -658,7 +670,7 @@ async def _background_heartbeat(self): else: _log.debug("Event loop heartbeat: %.2fs", diff) - async def _consume_control_queue(self): + async def _consume_control_queue(self) -> None: """ Reads messages from the control queue and dispatches them. """ @@ -677,7 +689,7 @@ async def _consume_control_queue(self): "Could not process control queue message %r", msg, exc_info=exc ) - async def _process_control_message(self, msg: IncomingMessagePayload): + async def _process_control_message(self, msg: IncomingMessagePayload) -> None: _log.debug("Received control message id=%r", msg.id) try: @@ -735,14 +747,16 @@ async def _process_control_message(self, msg: IncomingMessagePayload): finally: await msg.ack() - async def _on_submitting(self, req: "Request", context: Optional["Context"]): + async def _on_submitting( + self, req: Request[Any], context: Optional["Context"] + ) -> None: for mw_inst in self._middleware: try: await mw_inst.on_request_submitting(req, context=context) except Exception as mw_exc: # pylint: disable=broad-except _log.error("Middleware failed", exc_info=mw_exc) - def _get_task_route(self, req: Union[str, Request]): + def _get_task_route(self, req: Union[str, Request[Any]]) -> str: if isinstance(req, Request): if req.queue_name is not None: _log.debug( @@ -774,7 +788,7 @@ def _get_task_route(self, req: Union[str, Request]): return default_queue - async def _call_on_starting_middleware(self): + async def _call_on_starting_middleware(self) -> None: for mw in self._middleware: try: await mw.on_app_starting(self) @@ -783,14 +797,14 @@ async def _call_on_starting_middleware(self): "Middleware %r failed on 'on_app_starting'", mw, exc_info=exc ) - async def _call_on_started_middleware(self): + async def _call_on_started_middleware(self) -> None: for mw in self._middleware: try: await mw.on_app_started(self) except Exception as exc: # pylint: disable=broad-except _log.debug("Middleware %r failed on 'on_app_started'", mw, exc_info=exc) - async def _call_on_stopping_middleware(self): + async def _call_on_stopping_middleware(self) -> None: for mw in self._middleware: try: await mw.on_app_stopping(self) @@ -799,7 +813,7 @@ async def _call_on_stopping_middleware(self): "Middleware %r failed on 'on_app_stopping'", mw, exc_info=exc ) - async def _call_on_stopped_middleware(self): + async def _call_on_stopped_middleware(self) -> None: for mw in self._middleware: try: await mw.on_app_stopped(self) diff --git a/mognet/app/app_config.py b/mognet/app/app_config.py index 46630c5..98c103d 100644 --- a/mognet/app/app_config.py +++ b/mognet/app/app_config.py @@ -2,11 +2,12 @@ import socket from typing import Any, Dict, List, Optional, Set +from pydantic.fields import Field +from pydantic.main import BaseModel + from mognet.backend.backend_config import ResultBackendConfig from mognet.broker.broker_config import BrokerConfig from mognet.state.state_backend_config import StateBackendConfig -from pydantic.fields import Field -from pydantic.main import BaseModel def _default_node_id() -> str: @@ -18,10 +19,10 @@ class Queues(BaseModel): exclude: Set[str] = Field(default_factory=set) @property - def is_valid(self): + def is_valid(self) -> bool: return not (len(self.include) > 0 and len(self.exclude) > 0) - def ensure_valid(self): + def ensure_valid(self) -> None: if not self.is_valid: raise ValueError( "Cannot specify both 'include' and 'exclude'. Choose either or none." diff --git a/mognet/backend/backend_config.py b/mognet/backend/backend_config.py index 27a0a2f..cb52315 100644 --- a/mognet/backend/backend_config.py +++ b/mognet/backend/backend_config.py @@ -1,6 +1,7 @@ from datetime import timedelta from enum import Enum from typing import Optional + from pydantic import BaseModel diff --git a/mognet/backend/base_result_backend.py b/mognet/backend/base_result_backend.py index 6d2fa44..df4403d 100644 --- a/mognet/backend/base_result_backend.py +++ b/mognet/backend/base_result_backend.py @@ -1,22 +1,14 @@ +from __future__ import annotations + import asyncio from abc import ABCMeta, abstractmethod from datetime import timedelta -from typing import ( - Any, - AsyncGenerator, - Dict, - List, - Optional, - Protocol, - Union, -) +from typing import Any, AsyncIterable, Dict, List, Optional, Protocol from uuid import UUID from mognet.backend.backend_config import ResultBackendConfig from mognet.model.result import Result, ResultValueHolder -ResultOrId = Union[UUID, "Result"] - class AppParameters(Protocol): name: str @@ -35,14 +27,14 @@ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None: self.app = app @abstractmethod - async def get(self, result_id: UUID) -> Optional[Result]: + async def get(self, result_id: UUID) -> Optional[Result[Any]]: """ Get a Result by it's ID. If it doesn't exist, this method returns None. """ raise NotImplementedError - async def get_many(self, *result_ids: UUID) -> List[Result]: + async def get_many(self, *result_ids: UUID) -> List[Result[Any]]: """ Get a list of Results by specifying their IDs. Results that don't exist will be removed from this list. @@ -51,7 +43,7 @@ async def get_many(self, *result_ids: UUID) -> List[Result]: return [r for r in all_results if r if r is not None] - async def get_or_create(self, result_id: UUID) -> Result: + async def get_or_create(self, result_id: UUID) -> Result[Any]: """ Get a Result by it's ID. If it doesn't exist, this method creates one. @@ -68,7 +60,7 @@ async def get_or_create(self, result_id: UUID) -> Result: return res @abstractmethod - async def set(self, result_id: UUID, result: Result) -> None: + async def set(self, result_id: UUID, result: Result[Any]) -> None: """ Save a Result. """ @@ -76,14 +68,14 @@ async def set(self, result_id: UUID, result: Result) -> None: async def wait( self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1 - ) -> Result: + ) -> Result[Any]: """ Wait until a result is ready. Raises `asyncio.TimeoutError` if a timeout is set and exceeded. """ - async def waiter(): + async def waiter() -> Result[Any]: while True: result = await self.get(result_id) @@ -109,7 +101,7 @@ async def get_children_count(self, parent_result_id: UUID) -> int: @abstractmethod def iterate_children_ids( self, parent_result_id: UUID, *, count: Optional[int] = None - ) -> AsyncGenerator[UUID, None]: + ) -> AsyncIterable[UUID]: """ Get an AsyncGenerator for the IDs for the children of a Result. @@ -119,7 +111,7 @@ def iterate_children_ids( def iterate_children( self, parent_result_id: UUID, *, count: Optional[int] = None - ) -> AsyncGenerator[Result, None]: + ) -> AsyncIterable[Result[Any]]: """ Get an AsyncGenerator for the children of a Result. @@ -127,19 +119,19 @@ def iterate_children( """ raise NotImplementedError - async def __aenter__(self): + async def __aenter__(self): # type: ignore return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: return None - async def connect(self): + async def connect(self) -> None: """ Explicit method to connect to the backend provided by this Result backend. """ - async def close(self): + async def close(self) -> None: """ Explicit method to close the backend provided by this Result backend. diff --git a/mognet/backend/memory_result_backend.py b/mognet/backend/memory_result_backend.py index f416efe..252fbca 100644 --- a/mognet/backend/memory_result_backend.py +++ b/mognet/backend/memory_result_backend.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from datetime import timedelta -from typing import Any, AsyncGenerator, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Set from uuid import UUID + from mognet.backend.backend_config import ResultBackendConfig from mognet.backend.base_result_backend import AppParameters, BaseResultBackend from mognet.exceptions.result_exceptions import ResultValueLost -from mognet.model.result import Result, ResultValueHolder + +if TYPE_CHECKING: + from mognet.model.result import Result, ResultValueHolder class MemoryResultBackend(BaseResultBackend): @@ -16,22 +21,22 @@ class MemoryResultBackend(BaseResultBackend): def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None: super().__init__(config, app) - self._results: Dict[UUID, Result] = {} + self._results: Dict[UUID, Result[Any]] = {} self._result_tree: Dict[UUID, Set[UUID]] = {} self._values: Dict[UUID, ResultValueHolder] = {} self._metadata: Dict[UUID, Dict[str, Any]] = {} - async def get(self, result_id: UUID) -> Optional[Result]: + async def get(self, result_id: UUID) -> Optional[Result[Any]]: return self._results.get(result_id, None) - async def set(self, result_id: UUID, result: Result): + async def set(self, result_id: UUID, result: Result[Any]) -> None: self._results[result_id] = result async def get_children_count(self, parent_result_id: UUID) -> int: return len(self._result_tree.get(parent_result_id, set())) async def iterate_children_ids( - self, parent_result_id: UUID, *, count: int = None + self, parent_result_id: UUID, *, count: Optional[int] = None ) -> AsyncGenerator[UUID, None]: children = self._result_tree[parent_result_id] @@ -42,8 +47,8 @@ async def iterate_children_ids( break async def iterate_children( - self, parent_result_id: UUID, *, count: int = None - ) -> AsyncGenerator[Result, None]: + self, parent_result_id: UUID, *, count: Optional[int] = None + ) -> AsyncGenerator[Result[Any], None]: async for child_id in self.iterate_children_ids(parent_result_id, count=count): child = self._results.get(child_id, None) @@ -61,7 +66,7 @@ async def get_value(self, result_id: UUID) -> ResultValueHolder: return value - async def set_value(self, result_id: UUID, value: ResultValueHolder): + async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None: self._values[result_id] = value async def get_metadata(self, result_id: UUID) -> Dict[str, Any]: @@ -71,7 +76,7 @@ async def get_metadata(self, result_id: UUID) -> Dict[str, Any]: async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None: self._metadata.setdefault(result_id, {}).update(kwargs) - async def delete(self, result_id: UUID, include_children: bool = True): + async def delete(self, result_id: UUID, include_children: bool = True) -> None: if include_children: for child_id in self._result_tree.get(result_id, set()): await self.delete(child_id, include_children=include_children) @@ -82,10 +87,10 @@ async def delete(self, result_id: UUID, include_children: bool = True): async def set_ttl( self, result_id: UUID, ttl: timedelta, include_children: bool = True - ): + ) -> None: pass - async def close(self): + async def close(self) -> None: self._metadata = {} self._result_tree = {} self._results = {} diff --git a/mognet/backend/redis_result_backend.py b/mognet/backend/redis_result_backend.py index 5383f02..69f050e 100644 --- a/mognet/backend/redis_result_backend.py +++ b/mognet/backend/redis_result_backend.py @@ -5,13 +5,31 @@ import gzip import json import logging +import sys from asyncio import shield from datetime import timedelta -from typing import Any, AnyStr, Dict, Iterable, List, Optional, Set +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + AsyncIterable, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) from uuid import UUID from pydantic.tools import parse_raw_as -from redis.asyncio import Redis, from_url +from redis.asyncio import from_url from redis.exceptions import ConnectionError, TimeoutError from mognet.backend.backend_config import Encoding, ResultBackendConfig @@ -22,15 +40,29 @@ from mognet.model.result_state import READY_STATES, ResultState from mognet.tools.urls import censor_credentials +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias + +if TYPE_CHECKING: + from redis.asyncio import Redis # noqa: F401 + _log = logging.getLogger(__name__) -def _retry(func): +_EncodedHSetPayload = Mapping[Union[str, bytes], Union[bytes, float, int, str]] + +_F = TypeVar("_F", bound=Callable[..., Awaitable[Any]]) +_Redis: TypeAlias = "Redis[Any]" + + +def _retry(func: _F) -> _F: @functools.wraps(func) - async def retry_wrapper(self: RedisResultBackend, *args, **kwargs): + async def retry_wrapper(self: RedisResultBackend, *args: Any, **kwargs: Any) -> Any: last_err = None - sleep_s = 1 + sleep_s = 1.0 for attempt in range(self._retry_connect_attempts): try: @@ -60,7 +92,7 @@ async def retry_wrapper(self: RedisResultBackend, *args, **kwargs): if last_err is not None: raise last_err - return retry_wrapper + return cast(_F, retry_wrapper) class RedisResultBackend(BaseResultBackend): @@ -72,11 +104,11 @@ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None: super().__init__(config, app) self._url = config.redis.url - self.__redis = None + self.__redis: Optional[_Redis] = None self._connected = False # Holds references to tasks which are spawned by .wait() - self._waiters: List[asyncio.Future] = [] + self._waiters: List["asyncio.Future[Any]"] = [] # Attributes for @_retry self._retry_connect_attempts = self.config.redis.retry_connect_attempts @@ -84,21 +116,21 @@ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None: self._retry_lock = asyncio.Lock() @property - def _redis(self) -> Redis: + def _redis(self) -> _Redis: if self.__redis is None: raise NotConnected return self.__redis @_retry - async def get(self, result_id: UUID) -> Optional[Result]: + async def get(self, result_id: UUID) -> Optional[Result[Any]]: obj_key = self._format_key(result_id) async with self._redis.pipeline(transaction=True) as pip: # Since HGETALL returns an empty HASH for keys that don't exist, # test if it exists at all and use that to check if we should return null. - pip.exists(obj_key) - pip.hgetall(obj_key) + _ = pip.exists(obj_key) + _ = pip.hgetall(obj_key) exists, value, *_ = await shield(pip.execute()) @@ -108,7 +140,7 @@ async def get(self, result_id: UUID) -> Optional[Result]: return self._decode_result(value) @_retry - async def get_or_create(self, result_id: UUID) -> Result: + async def get_or_create(self, result_id: UUID) -> Result[Any]: """ Gets a result, or creates one if it doesn't exist. """ @@ -116,11 +148,11 @@ async def get_or_create(self, result_id: UUID) -> Result: result_key = self._format_key(result_id) - pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode()) - pip.hgetall(result_key) + _ = pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode()) + _ = pip.hgetall(result_key) if self.config.redis.result_ttl is not None: - pip.expire(result_key, self.config.redis.result_ttl) + _ = pip.expire(result_key, self.config.redis.result_ttl) # Also set the value, to a default holding an absence of result. value_key = self._format_key(result_id, "value") @@ -129,10 +161,10 @@ async def get_or_create(self, result_id: UUID) -> Result: encoded = self._encode_result_value(default_not_ready) if self.config.redis.result_value_ttl is not None: - pip.expire(value_key, self.config.redis.result_value_ttl) + _ = pip.expire(value_key, self.config.redis.result_value_ttl) for encoded_k, encoded_v in encoded.items(): - pip.hsetnx(value_key, encoded_k, encoded_v) + _ = pip.hsetnx(value_key, encoded_k, encoded_v) existed, value, *_ = await shield(pip.execute()) @@ -141,7 +173,7 @@ async def get_or_create(self, result_id: UUID) -> Result: return self._decode_result(value) - def _encode_result_value(self, value: ResultValueHolder) -> Dict[str, bytes]: + def _encode_result_value(self, value: ResultValueHolder) -> _EncodedHSetPayload: contents = value.json().encode() encoding = b"null" @@ -167,21 +199,21 @@ def _decode_result_value(self, encoded: Dict[bytes, bytes]) -> ResultValueHolder return ResultValueHolder.parse_raw(contents, content_type="application/json") @_retry - async def set(self, result_id: UUID, result: Result): + async def set(self, result_id: UUID, result: Result[Any]) -> Any: key = self._format_key(result_id) async with self._redis.pipeline(transaction=True) as pip: encoded = _encode_result(result) - pip.hset(key, None, None, encoded) + _ = pip.hset(key, None, None, encoded) if self.config.redis.result_ttl is not None: - pip.expire(key, self.config.redis.result_ttl) + _ = pip.expire(key, self.config.redis.result_ttl) await shield(pip.execute()) - def _format_key(self, result_id: UUID, subkey: str = None) -> str: + def _format_key(self, result_id: UUID, subkey: Optional[str] = None) -> str: key = f"{self.app.name}.mognet.result.{str(result_id)}" if subkey: @@ -194,7 +226,7 @@ def _format_key(self, result_id: UUID, subkey: str = None) -> str: return key @_retry - async def add_children(self, result_id: UUID, *children: UUID): + async def add_children(self, result_id: UUID, *children: UUID) -> None: if not children: return @@ -204,10 +236,10 @@ async def add_children(self, result_id: UUID, *children: UUID): async with self._redis.pipeline(transaction=True) as pip: - pip.sadd(children_key, *_encode_children(children)) + _ = pip.sadd(children_key, *_encode_children(children)) if self.config.redis.result_ttl is not None: - pip.expire(children_key, self.config.redis.result_ttl) + _ = pip.expire(children_key, self.config.redis.result_ttl) await shield(pip.execute()) @@ -216,31 +248,31 @@ async def get_value(self, result_id: UUID) -> ResultValueHolder: async with self._redis.pipeline(transaction=True) as pip: - pip.exists(value_key) - pip.hgetall(value_key) + _ = pip.exists(value_key) + _ = pip.hgetall(value_key) exists, contents = await shield(pip.execute()) - if not exists: - raise ResultValueLost(result_id) + if not exists: + raise ResultValueLost(result_id) - return self._decode_result_value(contents) + return self._decode_result_value(contents) - async def set_value(self, result_id: UUID, value: ResultValueHolder): + async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None: value_key = self._format_key(result_id, "value") encoded = self._encode_result_value(value) async with self._redis.pipeline(transaction=True) as pip: - pip.hset(value_key, None, None, encoded) + _ = pip.hset(value_key, None, None, encoded) if self.config.redis.result_value_ttl is not None: - pip.expire(value_key, self.config.redis.result_value_ttl) + _ = pip.expire(value_key, self.config.redis.result_value_ttl) await shield(pip.execute()) - async def delete(self, result_id: UUID, include_children: bool = True): + async def delete(self, result_id: UUID, include_children: bool = True) -> None: if include_children: async for child_id in self.iterate_children_ids(result_id): await self.delete(child_id, include_children=True) @@ -254,7 +286,7 @@ async def delete(self, result_id: UUID, include_children: bool = True): async def set_ttl( self, result_id: UUID, ttl: timedelta, include_children: bool = True - ): + ) -> None: if include_children: async for child_id in self.iterate_children_ids(result_id): await self.set_ttl(child_id, ttl, include_children=True) @@ -269,7 +301,7 @@ async def set_ttl( await shield(self._redis.expire(value_key, ttl)) await shield(self._redis.expire(metadata_key, ttl)) - async def connect(self): + async def connect(self) -> None: if self._connected: return @@ -277,7 +309,7 @@ async def connect(self): await self._connect() - async def close(self): + async def close(self) -> None: self._connected = False await self._close_waiters() @@ -290,8 +322,8 @@ async def get_children_count(self, parent_result_id: UUID) -> int: return await shield(self._redis.scard(children_key)) async def iterate_children_ids( - self, parent_result_id: UUID, *, count: Optional[float] = None - ): + self, parent_result_id: UUID, *, count: Optional[int] = None + ) -> AsyncIterable[UUID]: children_key = self._format_key(parent_result_id, "children") raw_child_id: bytes @@ -300,8 +332,8 @@ async def iterate_children_ids( yield child_id async def iterate_children( - self, parent_result_id: UUID, *, count: Optional[float] = None - ): + self, parent_result_id: UUID, *, count: Optional[int] = None + ) -> AsyncIterable[Result[Any]]: async for child_id in self.iterate_children_ids(parent_result_id, count=count): child = await self.get(child_id) @@ -311,8 +343,8 @@ async def iterate_children( @_retry async def wait( self, result_id: UUID, timeout: Optional[float] = None, poll: float = 1 - ) -> Result: - async def waiter(): + ) -> Result[Any]: + async def waiter() -> Result[Any]: key = self._format_key(result_id=result_id) # Type def for the state key. It can (but shouldn't) @@ -322,7 +354,7 @@ async def waiter(): while True: raw_state = await shield(self._redis.hget(key, "state")) or b"null" - state = parse_raw_as(t, raw_state) + state = parse_raw_as(cast(Type[Optional[ResultState]], t), raw_state) if state is None: raise ResultValueLost(result_id) @@ -366,25 +398,25 @@ async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None: async with self._redis.pipeline(transaction=True) as pip: - pip.hset(key, None, None, _dict_to_json_dict(kwargs)) + _ = pip.hset(key, None, None, _dict_to_json_dict(kwargs)) if self.config.redis.result_ttl is not None: - pip.expire(key, self.config.redis.result_ttl) + _ = pip.expire(key, self.config.redis.result_ttl) await shield(pip.execute()) - def __repr__(self): + def __repr__(self) -> str: return f"RedisResultBackend(url={censor_credentials(self._url)!r})" - async def __aenter__(self): + async def __aenter__(self) -> RedisResultBackend: await self.connect() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await self.close() - async def _close_waiters(self): + async def _close_waiters(self) -> None: """ Cancel any wait loop we have running. """ @@ -401,9 +433,9 @@ async def _close_waiters(self): except Exception as exc: # pylint: disable=broad-except _log.debug("Error on waiter task %r", waiter_task, exc_info=exc) - async def _create_redis(self): + async def _create_redis(self) -> _Redis: _log.debug("Creating Redis connection") - redis: Redis = await from_url( + redis: _Redis = await from_url( self._url, max_connections=self.config.redis.max_connections, ) @@ -411,13 +443,13 @@ async def _create_redis(self): return redis @_retry - async def _connect(self): + async def _connect(self) -> None: if self.__redis is None: self.__redis = await self._create_redis() await shield(self._redis.ping()) - async def _disconnect(self): + async def _disconnect(self) -> None: redis = self.__redis if redis is not None: @@ -425,18 +457,18 @@ async def _disconnect(self): _log.debug("Closing Redis connection") await redis.close() - def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result: + def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result[Any]: # Load the dict of JSON values first; then update it with overrides. value = _decode_json_dict(json_dict) return Result(self, **value) -def _encode_result(result: Result) -> Dict[str, bytes]: - json_dict: dict = json.loads(result.json()) +def _encode_result(result: Result[Any]) -> _EncodedHSetPayload: + json_dict: Dict[str, Any] = json.loads(result.json()) return _dict_to_json_dict(json_dict) -def _dict_to_json_dict(value: Dict[str, Any]) -> Dict[str, bytes]: +def _dict_to_json_dict(value: Dict[str, Any]) -> _EncodedHSetPayload: return {k: json.dumps(v).encode() for k, v in value.items()} @@ -446,7 +478,7 @@ def _encode_children(children: Iterable[UUID]) -> Set[bytes]: return {c.bytes for c in children} -def _decode_json_dict(json_dict: Dict[bytes, AnyStr]) -> dict: +def _decode_json_dict(json_dict: Dict[bytes, AnyStr]) -> Dict[str, Any]: """ Decode a dict coming from hgetall/hmget, which is encoded with bytes as keys and bytes or str as values, containing JSON values. diff --git a/mognet/broker/amqp_broker.py b/mognet/broker/amqp_broker.py index 8ff7ea9..e71ba22 100644 --- a/mognet/broker/amqp_broker.py +++ b/mognet/broker/amqp_broker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import json import logging @@ -7,8 +9,8 @@ TYPE_CHECKING, Any, AsyncGenerator, - Awaitable, Callable, + Coroutine, Dict, List, Optional, @@ -23,6 +25,8 @@ from aio_pika.message import IncomingMessage, Message from aio_pika.queue import Queue from aiormq.exceptions import AMQPChannelError, ChannelInvalidStateError +from pydantic.fields import PrivateAttr + from mognet.broker.base_broker import ( BaseBroker, IncomingMessagePayload, @@ -36,7 +40,6 @@ from mognet.primitives.queries import QueryResponseMessage from mognet.tools.retries import retryableasyncmethod from mognet.tools.urls import censor_credentials -from pydantic.fields import PrivateAttr _log = logging.getLogger(__name__) @@ -58,7 +61,7 @@ def __init__( self._incoming_message = incoming_message self._processed = False - async def ack(self): + async def ack(self) -> bool: if self._processed: return True @@ -76,7 +79,7 @@ async def ack(self): _log.error("Could not ACK message %r", self.id, exc_info=exc) return False - async def nack(self): + async def nack(self) -> bool: if self._processed: return True @@ -124,12 +127,12 @@ def __init__(self, app: "App", config: BrokerConfig) -> None: super().__init__() self._connected = False - self.__connection = None + self.__connection: Optional[aio_pika.RobustConnection] = None self.config = config self._task_queues = {} - self._control_queue = None + self._control_queue: Optional[Queue] = None # Lock to prevent duplicate queue declaration self._lock = Lock() @@ -142,7 +145,7 @@ def __init__(self, app: "App", config: BrokerConfig) -> None: # List of callbacks for when connection drops self._on_connection_failed_callbacks: List[ - Callable[[Optional[BaseException]], Awaitable] + Callable[[Optional[BaseException]], Coroutine[Any, Any, Any]] ] = [] @property @@ -152,18 +155,18 @@ def _connection(self) -> Connection: return self.__connection - async def ack(self, delivery_tag: str): + async def ack(self, delivery_tag: str) -> None: await self._task_channel.channel.basic_ack(delivery_tag) - async def nack(self, delivery_tag: str): + async def nack(self, delivery_tag: str) -> None: await self._task_channel.channel.basic_nack(delivery_tag) @_retry - async def set_task_prefetch(self, prefetch: int): + async def set_task_prefetch(self, prefetch: int) -> None: await self._task_channel.set_qos(prefetch_count=prefetch, global_=True) @_retry - async def send_task_message(self, queue: str, payload: MessagePayload): + async def send_task_message(self, queue: str, payload: MessagePayload) -> None: amqp_queue = self._task_queue_name(queue) msg = Message( @@ -199,7 +202,7 @@ async def consume_control_queue( yield message @_retry - async def send_control_message(self, payload: MessagePayload): + async def send_control_message(self, payload: MessagePayload) -> None: msg = Message( body=payload.json().encode(), content_type="application/json", @@ -212,7 +215,7 @@ async def send_control_message(self, payload: MessagePayload): await self._control_exchange.publish(msg, "") @_retry - async def _send_query_message(self, payload: MessagePayload): + async def _send_query_message(self, payload: MessagePayload) -> Queue: callback_queue = await self._task_channel.declare_queue( name=self._callback_queue_name, durable=False, @@ -253,7 +256,7 @@ async def send_query_message( async with callback_queue.iterator() as iterator: async for message in iterator: async with message.process(): - contents: dict = json.loads(message.body) + contents: Dict[str, Any] = json.loads(message.body) msg = _AmqpIncomingMessagePayload( broker=self, incoming_message=message, **contents ) @@ -262,15 +265,15 @@ async def send_query_message( if callback_queue is not None: await callback_queue.delete() - async def setup_control_queue(self): + async def setup_control_queue(self) -> None: await self._get_or_create_control_queue() - async def setup_task_queue(self, queue: TaskQueue): + async def setup_task_queue(self, queue: TaskQueue) -> None: await self._get_or_create_task_queue(queue) @_retry - async def _create_connection(self): - connection = await aio_pika.connect_robust( + async def _create_connection(self) -> aio_pika.RobustConnection: + connection: aio_pika.RobustConnection = await aio_pika.connect_robust( self.config.amqp.url, reconnect_interval=self.app.config.reconnect_interval, client_properties={ @@ -284,11 +287,13 @@ async def _create_connection(self): return connection def add_connection_failed_callback( - self, cb: Callable[[Optional[BaseException]], Awaitable] - ): + self, cb: Callable[[Optional[BaseException]], Coroutine[Any, Any, Any]] + ) -> None: self._on_connection_failed_callbacks.append(cb) - def _send_connection_failed_events(self, connection, exc=None): + def _send_connection_failed_events( + self, connection: aio_pika.RobustConnection, exc: Optional[BaseException] = None + ) -> None: if not self._connected: _log.debug( "Not sending connection closed events because we are disconnected" @@ -304,7 +309,7 @@ def _send_connection_failed_events(self, connection, exc=None): len(tasks), ) - def notify_task_completion_callback(fut: asyncio.Future): + def notify_task_completion_callback(fut: "asyncio.Future[Any]") -> None: exc = fut.exception() if exc and not fut.cancelled(): @@ -314,7 +319,7 @@ def notify_task_completion_callback(fut: asyncio.Future): notify_task = asyncio.create_task(task) notify_task.add_done_callback(notify_task_completion_callback) - async def connect(self): + async def connect(self) -> None: if self._connected: return @@ -337,10 +342,10 @@ async def connect(self): _log.debug("Connected") - async def set_control_prefetch(self, prefetch: int): + async def set_control_prefetch(self, prefetch: int) -> None: await self._control_channel.set_qos(prefetch_count=prefetch, global_=False) - async def close(self): + async def close(self) -> None: self._connected = False connection = self.__connection @@ -349,11 +354,13 @@ async def close(self): self.__connection = None _log.debug("Closing connections") - await connection.close() + await connection.close() # type: ignore _log.debug("Connection closed") @_retry - async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload): + async def send_reply( + self, message: IncomingMessagePayload, reply: MessagePayload + ) -> None: if not message.reply_to: raise ValueError("Message has no reply_to set") @@ -395,22 +402,22 @@ async def purge_control_queue(self) -> int: result = await self._control_queue.purge() - return result.message_count + return result.message_count # type: ignore - def __repr__(self): + def __repr__(self) -> str: return f"AmqpBroker(url={censor_credentials(self.config.amqp.url)!r})" - async def __aenter__(self): + async def __aenter__(self) -> AmqpBroker: await self.connect() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await self.close() return None - async def _create_exchanges(self): + async def _create_exchanges(self) -> None: self._direct_exchange = await self._task_channel.declare_exchange( self._direct_exchange_name, type=ExchangeType.DIRECT, @@ -430,7 +437,7 @@ async def _consume( async for msg in queue_iterator: try: - contents: dict = json.loads(msg.body) + contents: Dict[str, Any] = json.loads(msg.body) payload = _AmqpIncomingMessagePayload( broker=self, incoming_message=msg, **contents @@ -494,6 +501,8 @@ async def _get_or_create_control_queue(self) -> Queue: _log.debug("Prepared control queue=%r", self._control_queue.name) + assert self._control_queue is not None + return self._control_queue @_retry diff --git a/mognet/broker/base_broker.py b/mognet/broker/base_broker.py index a78422b..8a6bc13 100644 --- a/mognet/broker/base_broker.py +++ b/mognet/broker/base_broker.py @@ -1,15 +1,16 @@ from abc import ABCMeta, abstractmethod -from mognet.model.queue_stats import QueueStats -from mognet.primitives.queries import QueryResponseMessage -from typing import AsyncGenerator, Awaitable, Callable, List, Optional +from typing import Any, AsyncGenerator, Callable, Coroutine, Optional from pydantic.main import BaseModel +from mognet.model.queue_stats import QueueStats +from mognet.primitives.queries import QueryResponseMessage + class MessagePayload(BaseModel): id: str kind: str - payload: dict + payload: Any priority: int = 5 @@ -35,11 +36,11 @@ class TaskQueue(BaseModel): class BaseBroker(metaclass=ABCMeta): @abstractmethod - async def send_task_message(self, queue: str, payload: MessagePayload): + async def send_task_message(self, queue: str, payload: MessagePayload) -> None: raise NotImplementedError @abstractmethod - async def send_control_message(self, payload: MessagePayload): + async def send_control_message(self, payload: MessagePayload) -> None: raise NotImplementedError @abstractmethod @@ -50,30 +51,30 @@ def consume_tasks(self, queue: str) -> AsyncGenerator[IncomingMessagePayload, No def consume_control_queue(self) -> AsyncGenerator[IncomingMessagePayload, None]: raise NotImplementedError - async def __aenter__(self): + async def __aenter__(self): # type: ignore return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: return None - async def connect(self): + async def connect(self) -> None: pass - async def close(self): + async def close(self) -> None: pass @abstractmethod - async def setup_task_queue(self, queue: TaskQueue): + async def setup_task_queue(self, queue: TaskQueue) -> None: raise NotImplementedError @abstractmethod - async def setup_control_queue(self): + async def setup_control_queue(self) -> None: raise NotImplementedError - async def set_task_prefetch(self, prefetch: int): + async def set_task_prefetch(self, prefetch: int) -> None: pass - async def set_control_prefetch(self, prefetch: int): + async def set_control_prefetch(self, prefetch: int) -> None: pass @abstractmethod @@ -83,7 +84,9 @@ def send_query_message( raise NotImplementedError @abstractmethod - async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload): + async def send_reply( + self, message: IncomingMessagePayload, reply: MessagePayload + ) -> None: raise NotImplementedError @abstractmethod @@ -95,8 +98,8 @@ async def purge_control_queue(self) -> int: raise NotImplementedError def add_connection_failed_callback( - self, cb: Callable[[Optional[BaseException]], Awaitable] - ): + self, cb: Callable[[Optional[BaseException]], Coroutine[Any, Any, Any]] + ) -> None: pass @abstractmethod diff --git a/mognet/broker/broker_config.py b/mognet/broker/broker_config.py index c96dfd3..2e73799 100644 --- a/mognet/broker/broker_config.py +++ b/mognet/broker/broker_config.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel diff --git a/mognet/broker/memory_broker.py b/mognet/broker/memory_broker.py index 793312e..621d4fd 100644 --- a/mognet/broker/memory_broker.py +++ b/mognet/broker/memory_broker.py @@ -1,8 +1,10 @@ -from asyncio import Queue import asyncio +from asyncio import Queue +from typing import Any, AsyncGenerator, Dict from uuid import uuid4 from pydantic.fields import PrivateAttr + from mognet.broker.base_broker import ( BaseBroker, IncomingMessagePayload, @@ -11,15 +13,16 @@ ) from mognet.model.queue_stats import QueueStats from mognet.primitives.queries import QueryResponseMessage -from typing import AsyncGenerator, Dict class _InMemoryIncomingMessagePayload(IncomingMessagePayload): _broker: "InMemoryBroker" = PrivateAttr() - _queue: Queue = PrivateAttr() + _queue: "Queue[MessagePayload]" = PrivateAttr() - def __init__(self, broker: "InMemoryBroker", queue: Queue, **data): + def __init__( + self, broker: "InMemoryBroker", queue: "Queue[MessagePayload]", **data: Any + ) -> None: super().__init__(**data) self._broker = broker @@ -49,10 +52,10 @@ def __init__(self) -> None: self._callback_queues: Dict[str, "Queue[MessagePayload]"] = {} - async def send_task_message(self, queue: str, payload: MessagePayload): + async def send_task_message(self, queue: str, payload: MessagePayload) -> None: await self._task_queues[queue].put(payload) - async def send_control_message(self, payload: MessagePayload): + async def send_control_message(self, payload: MessagePayload) -> None: await self._control_queue.put(payload) async def consume_tasks( @@ -91,27 +94,27 @@ async def consume_control_queue( reply_to=None, ) - async def setup_task_queue(self, queue: TaskQueue): + async def setup_task_queue(self, queue: TaskQueue) -> None: self._task_queues[queue.name] = Queue() - async def setup_control_queue(self): + async def setup_control_queue(self) -> None: pass - def _update_task_event(self): + def _update_task_event(self) -> None: if self._unacked_task_count < self._task_prefetch: self._task_event.set() else: self._task_event.clear() - async def ack_task(self): + async def ack_task(self) -> None: self._unacked_task_count = max(0, self._unacked_task_count - 1) self._update_task_event() - async def set_task_prefetch(self, prefetch: int): + async def set_task_prefetch(self, prefetch: int) -> None: self._task_prefetch = prefetch self._update_task_event() - async def set_control_prefetch(self, prefetch: int): + async def set_control_prefetch(self, prefetch: int) -> None: self._control_prefetch = prefetch async def send_query_message( @@ -139,7 +142,9 @@ async def send_query_message( finally: self._callback_queues.pop(queue_id, None) - async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload): + async def send_reply( + self, message: IncomingMessagePayload, reply: MessagePayload + ) -> None: if not message.reply_to: raise ValueError("No one to send reply to") @@ -167,7 +172,7 @@ async def task_queue_stats(self, task_queue_name: str) -> QueueStats: ) -def _purge_queue(q: Queue) -> int: +def _purge_queue(q: Queue[Any]) -> int: old_size = q.qsize() while not q.empty(): diff --git a/mognet/cli/main.py b/mognet/cli/main.py index 60ce9c9..26ecfbf 100644 --- a/mognet/cli/main.py +++ b/mognet/cli/main.py @@ -2,17 +2,14 @@ import logging import os import sys -from typing import TYPE_CHECKING import typer + +from mognet import App from mognet.cli import nodes, queues, run, tasks from mognet.cli.cli_state import state from mognet.cli.models import LogLevel -if TYPE_CHECKING: - from mognet import App - - main = typer.Typer(name="mognet") @@ -33,7 +30,7 @@ def _get_app(app_pointer: str) -> "App": app = getattr(app_module, app_var_name or "app", None) - if app is None: + if not isinstance(app, App): raise _AppNotFound( f"Could not find an app on {app_module_name!r}. Expected to find an attribute named {app_var_name!r}\n" f"You can specify a different attribute after a ':'.\n" @@ -45,12 +42,12 @@ def _get_app(app_pointer: str) -> "App": @main.callback() def callback( - app: str = typer.Argument(..., help="App module to import"), - log_level: LogLevel = typer.Option("INFO", metavar="log-level"), - log_format: str = typer.Option( + app: str = typer.Argument(..., help="App module to import"), # noqa: B008 + log_level: LogLevel = typer.Option("INFO", metavar="log-level"), # noqa: B008 + log_format: str = typer.Option( # noqa: B008 "%(asctime)s:%(name)s:%(levelname)s:%(message)s", metavar="log-format" ), -): +) -> None: """Mognet CLI""" logging.basicConfig( diff --git a/mognet/cli/nodes.py b/mognet/cli/nodes.py index 298cb09..614de34 100644 --- a/mognet/cli/nodes.py +++ b/mognet/cli/nodes.py @@ -1,16 +1,17 @@ import asyncio from datetime import datetime -from typing import List, Optional +from typing import Any, List, Optional import tabulate import typer +from pydantic import BaseModel, Field + from mognet.cli.cli_state import state from mognet.cli.models import OutputFormat from mognet.cli.run_in_loop import run_in_loop from mognet.model.result import Result from mognet.primitives.queries import StatusResponseMessage from mognet.tools.dates import now_utc -from pydantic import BaseModel, Field group = typer.Typer() @@ -18,35 +19,38 @@ @group.command("status") @run_in_loop async def status( - format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"), - text_label_format: str = typer.Option( + format: OutputFormat = typer.Option( # noqa: B008 + OutputFormat.TEXT, metavar="format" + ), # noqa: B008 + text_label_format: str = typer.Option( # noqa: B008 "{name}(id={id!r}, state={state!r})", metavar="text-label-format", help="Label format for text format", ), - json_indent: int = typer.Option(2, metavar="json-indent"), - poll: Optional[int] = typer.Option( + json_indent: int = typer.Option(2, metavar="json-indent"), # noqa: B008 + poll: Optional[int] = typer.Option( # noqa: B008 None, metavar="poll", help="Polling interval, in seconds (default=None)", ), - timeout: int = typer.Option( + timeout: int = typer.Option( # noqa: B008 30, help="Timeout for querying nodes", ), -): +) -> None: """Query each node for their status""" async with state["app_instance"] as app: while True: each_node_status: List[StatusResponseMessage] = [] - async def read_status(): - async for node_status in app.get_current_status_of_nodes(): - each_node_status.append(node_status) + async def read_status() -> List[StatusResponseMessage]: + return [ns async for ns in app.get_current_status_of_nodes()] try: - await asyncio.wait_for(read_status(), timeout=timeout) + each_node_status.extend( + await asyncio.wait_for(read_status(), timeout=timeout) + ) except asyncio.TimeoutError: pass @@ -115,6 +119,6 @@ async def read_status(): class _CliStatusReport(BaseModel): class NodeStatus(BaseModel): node_id: str - running_requests: List[Result] + running_requests: List[Result[Any]] node_status: List[NodeStatus] = Field(default_factory=list) diff --git a/mognet/cli/queues.py b/mognet/cli/queues.py index fc37b65..452db22 100644 --- a/mognet/cli/queues.py +++ b/mognet/cli/queues.py @@ -1,15 +1,16 @@ -import logging - import typer -from mognet.cli.run_in_loop import run_in_loop + from mognet.cli.cli_state import state +from mognet.cli.run_in_loop import run_in_loop group = typer.Typer() @group.command("purge") @run_in_loop -async def purge(force: bool = typer.Option(False)): +async def purge( + force: bool = typer.Option(False), # noqa: B008 +) -> None: """Purge task and control queues""" if not force: diff --git a/mognet/cli/run.py b/mognet/cli/run.py index 180637d..3b8a5f3 100644 --- a/mognet/cli/run.py +++ b/mognet/cli/run.py @@ -1,9 +1,11 @@ -import aiorun -from asyncio import AbstractEventLoop import asyncio import logging -from typing import Optional +from asyncio import AbstractEventLoop +from typing import Any, Dict, Optional + +import aiorun import typer + from mognet.cli.cli_state import state from mognet.cli.exceptions import GracefulShutdown @@ -14,17 +16,17 @@ @group.callback() def run( - include_queues: Optional[str] = typer.Option( + include_queues: Optional[str] = typer.Option( # noqa: B008 None, metavar="include-queues", help="Comma-separated list of the ONLY queues to listen on.", ), - exclude_queues: Optional[str] = typer.Option( + exclude_queues: Optional[str] = typer.Option( # noqa: B008 None, metavar="exclude-queues", help="Comma-separated list of the ONLY queues to NOT listen on.", ), -): +) -> int: """Run the app""" app = state["app_instance"] @@ -40,17 +42,19 @@ def run( queues.ensure_valid() - async def start(): + async def start() -> None: async with app: await app.start() - async def stop(_: AbstractEventLoop): + async def stop(_: AbstractEventLoop) -> None: _log.info("Going to close app as part of a shut down") await app.close() - pending_exception_to_raise = SystemExit(0) + pending_exception_to_raise: BaseException = SystemExit(0) - def custom_exception_handler(loop: AbstractEventLoop, context: dict): + def custom_exception_handler( + loop: AbstractEventLoop, context: Dict[Any, Any] + ) -> None: """See: https://docs.python.org/3/library/asyncio-eventloop.html#error-handling-api""" nonlocal pending_exception_to_raise diff --git a/mognet/cli/run_in_loop.py b/mognet/cli/run_in_loop.py index a93083e..c6cee30 100644 --- a/mognet/cli/run_in_loop.py +++ b/mognet/cli/run_in_loop.py @@ -1,15 +1,21 @@ import asyncio from functools import wraps +from typing import Any, Awaitable, Callable, TypeVar +from typing_extensions import ParamSpec -def run_in_loop(f): +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def run_in_loop(f: Callable[_P, Awaitable[_R]]) -> Callable[_P, _R]: """ Utility to run a click/typer command function in an event loop (because they don't support it out of the box) """ @wraps(f) - def in_loop(*args, **kwargs): + def in_loop(*args: Any, **kwargs: Any) -> _R: loop = asyncio.get_event_loop() return loop.run_until_complete(f(*args, **kwargs)) diff --git a/mognet/cli/tasks.py b/mognet/cli/tasks.py index ca70658..3ab4026 100644 --- a/mognet/cli/tasks.py +++ b/mognet/cli/tasks.py @@ -1,16 +1,17 @@ import asyncio -import treelib import logging from typing import Optional from uuid import UUID import tabulate +import treelib import typer + +from mognet.cli.cli_state import state from mognet.cli.models import OutputFormat from mognet.cli.run_in_loop import run_in_loop from mognet.exceptions.result_exceptions import ResultValueLost from mognet.model.result import _ExceptionInfo -from mognet.cli.cli_state import state from mognet.model.result_tree import ResultTree _log = logging.getLogger(__name__) @@ -21,17 +22,17 @@ @group.command("get") @run_in_loop async def get( - task_id: UUID = typer.Argument( + task_id: UUID = typer.Argument( # noqa: B008 ..., metavar="id", help="Task ID to get", ), - include_value: bool = typer.Option( + include_value: bool = typer.Option( # noqa: B008 False, metavar="include-value", help="If passed, the task's result (or exception) will be printed", ), -): +) -> None: """Get a task's details""" async with state["app_instance"] as app: @@ -77,17 +78,17 @@ async def get( @group.command("revoke") @run_in_loop async def revoke( - task_id: UUID = typer.Argument( + task_id: UUID = typer.Argument( # noqa: B008 ..., metavar="id", help="Task ID to revoke", ), - force: bool = typer.Option( + force: bool = typer.Option( # noqa: B008 False, metavar="force", help="Attempt revoking anyway if the result is complete. Helps cleaning up cases where subtasks may have been spawned.", ), -): +) -> None: """Revoke a task""" async with state["app_instance"] as app: @@ -108,22 +109,24 @@ async def revoke( @group.command("tree") @run_in_loop async def tree( - task_id: UUID = typer.Argument( + task_id: UUID = typer.Argument( # noqa: B008 ..., metavar="id", help="Task ID to get tree from", ), - format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"), - json_indent: int = typer.Option(2, metavar="json-indent"), - text_label_format: str = typer.Option( + format: OutputFormat = typer.Option( # noqa: B008 + OutputFormat.TEXT, metavar="format" + ), # noqa: B008 + json_indent: int = typer.Option(2, metavar="json-indent"), # noqa: B008 + text_label_format: str = typer.Option( # noqa: B008 "{name}(id={id!r}, state={state!r})", metavar="text-label-format", help="Label format for text format", ), - max_depth: int = typer.Option(3, metavar="max-depth"), - max_width: int = typer.Option(16, metavar="max-width"), - poll: Optional[int] = typer.Option(None, metavar="poll"), -): + max_depth: int = typer.Option(3, metavar="max-depth"), # noqa: B008 + max_width: int = typer.Option(16, metavar="max-width"), # noqa: B008 + poll: Optional[int] = typer.Option(None, metavar="poll"), # noqa: B008 +) -> None: """Get the tree (descendants) of a task""" async with state["app_instance"] as app: @@ -143,17 +146,7 @@ async def tree( if format == "text": t = treelib.Tree() - def build_tree(n: ResultTree, parent: Optional[ResultTree] = None): - t.create_node( - tag=text_label_format.format(**n.dict()), - identifier=n.result.id, - parent=None if parent is None else parent.result.id, - ) - - for c in n.children: - build_tree(c, parent=n) - - build_tree(tree) + _build_tree(t, text_label_format, tree) t.show() @@ -164,3 +157,19 @@ def build_tree(n: ResultTree, parent: Optional[ResultTree] = None): break await asyncio.sleep(poll) + + +def _build_tree( + t: treelib.Tree, + text_label_format: str, + n: ResultTree, + parent: Optional[ResultTree] = None, +) -> None: + t.create_node( + tag=text_label_format.format(**n.dict()), + identifier=n.result.id, + parent=None if parent is None else parent.result.id, + ) + + for c in n.children: + _build_tree(t, text_label_format, c, parent=n) diff --git a/mognet/context/context.py b/mognet/context/context.py index eca36e7..3a052b7 100644 --- a/mognet/context/context.py +++ b/mognet/context/context.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import logging @@ -7,9 +9,8 @@ Any, Awaitable, Callable, - Coroutine, - Dict, List, + Optional, Set, Type, TypeVar, @@ -18,15 +19,17 @@ overload, ) +from mognet.model.result import Result + if sys.version_info >= (3, 10): - from typing import ParamSpec, Concatenate + from typing import Concatenate, ParamSpec else: from typing_extensions import ParamSpec, Concatenate from uuid import UUID from mognet.exceptions.result_exceptions import ResultLost -from mognet.model.result import ResultState +from mognet.model.result_state import ResultState from mognet.primitives.request import Request from mognet.service.class_service import ClassService @@ -34,12 +37,12 @@ if TYPE_CHECKING: from mognet.app.app import App - from mognet.model.result import Result from mognet.state.state import State from mognet.worker.worker import Worker _log = logging.getLogger(__name__) +_T = TypeVar("_T") _P = ParamSpec("_P") @@ -55,14 +58,14 @@ class Context: state: "State" - request: "Request" + request: "Request[Any]" _dependencies: Set[UUID] def __init__( self, app: "App", - request: "Request", + request: "Request[Any]", state: "State", worker: "Worker", ): @@ -75,7 +78,7 @@ def __init__( self.create_request = self.app.create_request - async def submit(self, request: "Request"): + async def submit(self, request: "Request[_Return]") -> Result[_Return]: """ Submits a new request as part of this one. @@ -98,7 +101,7 @@ async def run( self, request: Callable[Concatenate["Context", _P], _Return], *args: _P.args, - **kwargs: _P.kwargs + **kwargs: _P.kwargs, ) -> _Return: """ Short-hand method for creating a Request from a function decorated with `@task`, @@ -114,7 +117,7 @@ async def run( self, request: Callable[Concatenate["Context", _P], Awaitable[_Return]], *args: _P.args, - **kwargs: _P.kwargs + **kwargs: _P.kwargs, ) -> _Return: """ Short-hand method for creating a Request from a function decorated with `@task`, @@ -124,7 +127,7 @@ async def run( """ ... - async def run(self, request, *args, **kwargs): + async def run(self, request: Any, *args: Any, **kwargs: Any) -> Any: """ Submits and runs a new request as part of this one. @@ -159,7 +162,7 @@ async def run(self, request, *args, **kwargs): if not self._dependencies and not cancelled: await asyncio.shield(self._resume()) - def _log_dependencies(self): + def _log_dependencies(self) -> None: _log.debug( "Task %r is waiting on %r dependencies", self.request, @@ -167,16 +170,21 @@ def _log_dependencies(self): ) async def gather( - self, *results_or_ids: Union["Result", UUID], return_exceptions: bool = False + self, *results_or_ids: Union[Result[Any], UUID], return_exceptions: bool = False ) -> List[Any]: - results = [] + if not results_or_ids: + return [] + + results: List[Result[Any]] = [] cancelled = False try: + result: Union[UUID, Optional[Result[Any]]] for result in results_or_ids: if isinstance(result, UUID): result = await self.app.result_backend.get(result) - results.append(result) + if result is not None: + results.append(result) # If we transition from having no dependencies # to having some, then we should suspend. @@ -188,7 +196,9 @@ async def gather( if not had_dependencies and self._dependencies: await asyncio.shield(self._suspend()) - return await asyncio.gather(*results, return_exceptions=return_exceptions) + return list( + await asyncio.gather(*results, return_exceptions=return_exceptions) + ) except asyncio.CancelledError: cancelled = True raise @@ -202,7 +212,7 @@ async def gather( @overload def get_service( - self, func: Type[ClassService[_Return]], *args, **kwargs + self, func: Type[ClassService[_Return]], *args: Any, **kwargs: Any ) -> _Return: ... @@ -211,20 +221,22 @@ def get_service( self, func: Callable[Concatenate["Context", _P], _Return], *args: _P.args, - **kwargs: _P.kwargs + **kwargs: _P.kwargs, ) -> _Return: ... - def get_service(self, func, *args, **kwargs): + def get_service(self, func: Any, *args: Any, **kwargs: Any) -> Any: """ Get a service to use in the task function. This can be used for dependency injection purposes. """ + svc: Callable[[Context], Any] + if inspect.isclass(func) and issubclass(func, ClassService): if func not in self.app.services: # This cast() is only here to silence Pylance (because it thinks the class is abstract) - instance: ClassService = cast(Any, func)(self.app.config) + instance: ClassService[Any] = cast(Any, func)(self.app.config) self.app.services[func] = instance.__enter__() svc = self.app.services[func] @@ -233,7 +245,7 @@ def get_service(self, func, *args, **kwargs): return svc(self, *args, **kwargs) - async def _suspend(self): + async def _suspend(self) -> None: _log.debug("Suspending %r", self.request) result = await self.get_result() @@ -243,7 +255,7 @@ async def _suspend(self): await self._worker.add_waiting_task(self.request.id) - async def get_result(self): + async def get_result(self) -> Result[Any]: """ Gets the Result associated with this task. @@ -274,7 +286,7 @@ def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return: """ return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result() - async def set_metadata(self, **kwargs: Any): + async def set_metadata(self, **kwargs: Any) -> None: """ Update metadata on the Result associated with the current task. """ @@ -282,7 +294,7 @@ async def set_metadata(self, **kwargs: Any): result = await self.get_result() return await result.set_metadata(**kwargs) - async def _resume(self): + async def _resume(self) -> None: _log.debug("Resuming %r", self.request) result = await self.get_result() diff --git a/mognet/decorators/task_decorator.py b/mognet/decorators/task_decorator.py index ae2fccd..768df57 100644 --- a/mognet/decorators/task_decorator.py +++ b/mognet/decorators/task_decorator.py @@ -1,14 +1,15 @@ import logging from typing import Any, Callable, Optional, TypeVar, cast + from mognet.tasks.task_registry import TaskRegistry, task_registry _log = logging.getLogger(__name__) -_T = TypeVar("_T") +_T = TypeVar("_T", bound=Callable[..., Any]) -def task(*, name: Optional[str] = None): +def task(*, name: Optional[str] = None) -> Callable[[_T], _T]: """ Register a function as a task that can be run. @@ -29,7 +30,7 @@ def task_decorator(t: _T) -> _T: reg = TaskRegistry() reg.register_globally() - reg.add_task_function(cast(Callable, t), name=name) + reg.add_task_function(cast(Callable[..., Any], t), name=name) return t diff --git a/mognet/exceptions/result_exceptions.py b/mognet/exceptions/result_exceptions.py index 191dd6a..fadbe50 100644 --- a/mognet/exceptions/result_exceptions.py +++ b/mognet/exceptions/result_exceptions.py @@ -1,7 +1,7 @@ -from mognet.exceptions.base_exceptions import MognetError -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID +from mognet.exceptions.base_exceptions import MognetError if TYPE_CHECKING: from mognet.model.result import Result @@ -16,7 +16,7 @@ class ResultNotReady(ResultError): class ResultFailed(ResultError): - def __init__(self, result: "Result") -> None: + def __init__(self, result: "Result[Any]") -> None: super().__init__(result) self.result = result diff --git a/mognet/exceptions/task_exceptions.py b/mognet/exceptions/task_exceptions.py index 5de9b39..96375a3 100644 --- a/mognet/exceptions/task_exceptions.py +++ b/mognet/exceptions/task_exceptions.py @@ -1,6 +1,7 @@ +from typing import List, Tuple, Union + from pydantic import BaseModel from pydantic.error_wrappers import ValidationError -from typing import Any, Dict, List, Tuple, Union # Taken from pydantic.error_wrappers Loc = Tuple[Union[int, str], ...] @@ -33,5 +34,7 @@ def __init__(self, errors: List[InvalidErrorInfo]) -> None: self.errors = errors @classmethod - def from_validation_error(cls, validation_error: ValidationError): + def from_validation_error( + cls, validation_error: ValidationError + ) -> "InvalidTaskArguments": return cls([InvalidErrorInfo.parse_obj(e) for e in validation_error.errors()]) diff --git a/mognet/middleware/middleware.py b/mognet/middleware/middleware.py index bfd1d2d..61c014c 100644 --- a/mognet/middleware/middleware.py +++ b/mognet/middleware/middleware.py @@ -1,20 +1,13 @@ -import sys -from typing import TYPE_CHECKING, Optional - -if sys.version_info >= (3, 10): - from typing import Protocol -else: - from typing_extensions import Protocol - +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: + from mognet.app.app import App from mognet.context.context import Context from mognet.model.result import Result from mognet.primitives.request import Request - from mognet.app.app import App -class Middleware(Protocol): +class Middleware: """ Defines middleware that can hook into different parts of a Mognet App's lifecycle. """ @@ -47,7 +40,7 @@ async def on_app_stopped(self, app: "App") -> None: For example, you can use this for cleaning up objects that were previously set up. """ - async def on_task_starting(self, context: "Context"): + async def on_task_starting(self, context: "Context") -> None: """ Called when a task is starting. @@ -55,8 +48,8 @@ async def on_task_starting(self, context: "Context"): """ async def on_task_completed( - self, result: "Result", context: Optional["Context"] = None - ): + self, result: "Result[Any]", context: Optional["Context"] = None + ) -> None: """ Called when a task has completed it's execution. @@ -64,8 +57,8 @@ async def on_task_completed( """ async def on_request_submitting( - self, request: "Request", context: Optional["Context"] = None - ): + self, request: "Request[Any]", context: Optional["Context"] = None + ) -> None: """ Called when a Request object is going to be submitted to the Broker. @@ -73,7 +66,7 @@ async def on_request_submitting( the Request object (e.g., to modify arguments, or set metadata). """ - async def on_running_task_count_changed(self, running_task_count: int): + async def on_running_task_count_changed(self, running_task_count: int) -> None: """ Called when the Worker's task count changes. diff --git a/mognet/model/result.py b/mognet/model/result.py index 6e4713a..26d7361 100644 --- a/mognet/model/result.py +++ b/mognet/model/result.py @@ -1,41 +1,51 @@ +from __future__ import annotations + import base64 -from mognet.exceptions.result_exceptions import ResultFailed, ResultNotReady, Revoked -import pickle import importlib -import traceback import logging +import pickle +import traceback +from datetime import datetime, timedelta from typing import ( + TYPE_CHECKING, Any, - AsyncGenerator, + AsyncIterable, Dict, + Generator, + Generic, Optional, - TYPE_CHECKING, + Type, + TypeVar, ) -from .result_state import ( +from uuid import UUID + +from pydantic.fields import PrivateAttr +from pydantic.main import BaseModel +from pydantic.tools import parse_obj_as + +from mognet.exceptions.result_exceptions import ResultFailed, ResultNotReady, Revoked +from mognet.model.result_state import ( ERROR_STATES, READY_STATES, - ResultState, SUCCESS_STATES, + ResultState, ) - -from datetime import datetime, timedelta -from pydantic.fields import PrivateAttr -from pydantic.main import BaseModel from mognet.tools.dates import now_utc -from pydantic.tools import parse_obj_as -from uuid import UUID if TYPE_CHECKING: from mognet.backend.base_result_backend import BaseResultBackend + from .result_tree import ResultTree +_Return = TypeVar("_Return") + _log = logging.getLogger(__name__) class ResultChildren: """The children of a Result.""" - def __init__(self, result: "Result", backend: "BaseResultBackend") -> None: + def __init__(self, result: "Result[Any]", backend: BaseResultBackend) -> None: self._result = result self._backend = backend @@ -43,17 +53,17 @@ async def count(self) -> int: """The number of children.""" return await self._backend.get_children_count(self._result.id) - def iter_ids(self, *, count: Optional[int] = None) -> AsyncGenerator[UUID, None]: + def iter_ids(self, *, count: Optional[int] = None) -> AsyncIterable[UUID]: """Iterate the IDs of the children, optionally limited to a set count.""" return self._backend.iterate_children_ids(self._result.id, count=count) def iter_instances( self, *, count: Optional[int] = None - ) -> AsyncGenerator["Result", None]: + ) -> AsyncIterable[Result[Any]]: """Iterate the instances of the children, optionally limited to a set count.""" return self._backend.iterate_children(self._result.id, count=count) - async def add(self, *children_ids: UUID): + async def add(self, *children_ids: UUID) -> None: """For internal use.""" await self._backend.add_children(self._result.id, *children_ids) @@ -82,7 +92,7 @@ def deserialize(self) -> Any: return value @classmethod - def not_ready(cls): + def not_ready(cls) -> ResultValueHolder: """ Creates a value holder which is not ready yet. """ @@ -90,12 +100,12 @@ def not_ready(cls): return cls(value_type=_serialize_name(value), raw_value=value) -class ResultValue: +class ResultValue(Generic[_Return]): """ Represents information about the value of a Result. """ - def __init__(self, result: "Result", backend: "BaseResultBackend") -> None: + def __init__(self, result: "Result[_Return]", backend: BaseResultBackend) -> None: self._result = result self._backend = backend @@ -112,7 +122,7 @@ async def get_raw_value(self) -> Any: holder = await self.get_value_holder() return holder.deserialize() - async def set_raw_value(self, value: Any): + async def set_raw_value(self, value: Any) -> None: if isinstance(value, BaseException): value = _ExceptionInfo.from_exception(value) @@ -130,7 +140,7 @@ class _ExceptionInfo(BaseModel): raw_data_encoding: Optional[str] @classmethod - def from_exception(cls, exception: BaseException): + def from_exception(cls, exception: BaseException) -> _ExceptionInfo: try: raw_data = pickle.dumps(exception) except Exception as unencodable: # pylint: disable=broad-except @@ -154,17 +164,20 @@ def from_exception(cls, exception: BaseException): ) @property - def exception(self) -> Exception: + def exception(self) -> BaseException: if ( self.raw_data is not None and self.raw_data_encoding == "application/python-pickle" ): - return pickle.loads(base64.b64decode(self.raw_data)) + v = pickle.loads(base64.b64decode(self.raw_data)) + + if isinstance(v, BaseException): + return v return Exception(self.message) -class Result(BaseModel): +class Result(BaseModel, Generic[_Return]): """ Represents the result of executing a [`Request`][mognet.Request]. @@ -190,11 +203,11 @@ class Result(BaseModel): request_kwargs_repr: Optional[str] - _backend: "BaseResultBackend" = PrivateAttr() + _backend: BaseResultBackend = PrivateAttr() _children: Optional[ResultChildren] = PrivateAttr() - _value: Optional[ResultValue] = PrivateAttr() + _value: Optional[ResultValue[_Return]] = PrivateAttr() - def __init__(self, backend: "BaseResultBackend", **data) -> None: + def __init__(self, backend: BaseResultBackend, **data: Any) -> None: super().__init__(**data) self._backend = backend self._children = None @@ -210,7 +223,7 @@ def children(self) -> ResultChildren: return self._children @property - def value(self) -> ResultValue: + def value(self) -> ResultValue[_Return]: """Get information about the value of this Result""" if self._value is None: self._value = ResultValue(self, self._backend) @@ -242,7 +255,7 @@ def queue_time(self) -> Optional[timedelta]: return self.started - self.created @property - def done(self): + def done(self) -> bool: """ True if the result is in a terminal state (e.g., SUCCESS, FAILURE). See `READY_STATES`. @@ -250,17 +263,17 @@ def done(self): return self.state in READY_STATES @property - def successful(self): + def successful(self) -> bool: """True if the result was successful.""" return self.state in SUCCESS_STATES @property - def failed(self): + def failed(self) -> bool: """True if the result failed or was revoked.""" return self.state in ERROR_STATES @property - def revoked(self): + def revoked(self) -> bool: """True if the result was revoked.""" return self.state == ResultState.REVOKED @@ -278,7 +291,7 @@ async def wait(self, *, timeout: Optional[float] = None, poll: float = 0.1) -> N await self._refresh(updated_result) - async def revoke(self) -> "Result": + async def revoke(self) -> "Result[_Return]": """ Revoke this Result. @@ -291,7 +304,7 @@ async def revoke(self) -> "Result": await self._backend.set(self.id, self) return self - async def get(self) -> Any: + async def get(self) -> _Return: """ Gets the value of this `Result` instance. @@ -327,13 +340,13 @@ async def get(self) -> Any: raise value - return value + return value # type: ignore async def set_result( self, value: Any, state: ResultState = ResultState.SUCCESS, - ) -> "Result": + ) -> "Result[_Return]": """ Set this Result to a success state, and store the value which will be return when one `get()`s this Result's value. @@ -353,7 +366,7 @@ async def set_error( self, exc: BaseException, state: ResultState = ResultState.FAILURE, - ) -> "Result": + ) -> "Result[_Return]": """ Set this Result to an error state, and store the exception which will be raised if one attempts to `get()` this Result's @@ -373,7 +386,7 @@ async def set_error( return self - async def start(self, *, node_id: Optional[str] = None) -> "Result": + async def start(self, *, node_id: Optional[str] = None) -> "Result[_Return]": """ Sets this `Result` as RUNNING, and logs the event. """ @@ -387,7 +400,7 @@ async def start(self, *, node_id: Optional[str] = None) -> "Result": return self - async def resume(self, *, node_id: Optional[str] = None) -> "Result": + async def resume(self, *, node_id: Optional[str] = None) -> "Result[_Return]": if node_id is not None: self.node_id = node_id @@ -398,7 +411,7 @@ async def resume(self, *, node_id: Optional[str] = None) -> "Result": return self - async def suspend(self) -> "Result": + async def suspend(self) -> "Result[_Return]": """ Sets this `Result` as SUSPENDED, and logs the event. """ @@ -419,7 +432,7 @@ async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree": """ from .result_tree import ResultTree - async def get_tree(result: Result, depth=1): + async def get_tree(result: Result[Any], depth: int = 1) -> ResultTree: _log.debug( "Getting tree of result id=%r, depth=%r max_depth=%r", result.id, @@ -464,7 +477,7 @@ async def set_metadata(self, **kwargs: Any) -> None: """Set metadata on this Result.""" await self._backend.set_metadata(self.id, **kwargs) - async def _refresh(self, updated_result: Optional["Result"] = None): + async def _refresh(self, updated_result: Optional["Result[Any]"] = None) -> None: updated_result = updated_result or await self._backend.get(self.id) if updated_result is None: @@ -476,10 +489,10 @@ async def _refresh(self, updated_result: Optional["Result"] = None): setattr(self, k, v) - async def _update(self): + async def _update(self) -> None: await self._backend.set(self.id, self) - def __repr__(self): + def __repr__(self) -> str: v = f"Result[{self.name or 'unknown'}, id={self.id!r}, state={self.state!r}]" if self.request_kwargs_repr is not None: @@ -491,12 +504,12 @@ def __repr__(self): def __hash__(self) -> int: return hash(f"Result_{self.id}") - def __await__(self): + def __await__(self) -> Generator[Any, None, _Return]: yield from self.wait().__await__() value = yield from self.get().__await__() return value - async def delete(self, include_children: bool = True): + async def delete(self, include_children: bool = True) -> None: """ Delete this Result from the backend. @@ -504,7 +517,7 @@ async def delete(self, include_children: bool = True): """ await self._backend.delete(self.id, include_children=include_children) - async def set_ttl(self, ttl: timedelta, include_children: bool = True): + async def set_ttl(self, ttl: timedelta, include_children: bool = True) -> None: """ Set TTL on this Result. @@ -513,16 +526,16 @@ async def set_ttl(self, ttl: timedelta, include_children: bool = True): await self._backend.set_ttl(self.id, ttl, include_children=include_children) -def _get_attr(obj_spec): +def _get_attr(obj_spec: str) -> Type[Any]: module, cls_name = obj_spec.split(":") mod = importlib.import_module(module) - cls = getattr(mod, cls_name) + cls: Type[Any] = getattr(mod, cls_name) return cls -def _serialize_name(v): +def _serialize_name(v: Any) -> str: if isinstance(v, type): return f"{v.__module__}:{v.__name__}" diff --git a/mognet/model/result_state.py b/mognet/model/result_state.py index 41a6819..174f6be 100644 --- a/mognet/model/result_state.py +++ b/mognet/model/result_state.py @@ -29,7 +29,7 @@ class ResultState(str, Enum): # Invalid task INVALID = "INVALID" - def __repr__(self): + def __repr__(self) -> str: return f"{self.name!r}" diff --git a/mognet/model/result_tree.py b/mognet/model/result_tree.py index e7ca997..8adfe6e 100644 --- a/mognet/model/result_tree.py +++ b/mognet/model/result_tree.py @@ -1,17 +1,18 @@ -from typing import List +from typing import Any, Dict, List + from pydantic import BaseModel -from .result import Result +from mognet.model.result import Result class ResultTree(BaseModel): - result: "Result" + result: Result[Any] children: List["ResultTree"] def __str__(self) -> str: return f"{self.result.name}(id={self.result.id!r}, state={self.result.state!r}, node_id={self.result.node_id!r})" - def dict(self, **kwargs): + def dict(self, **kwargs: Any) -> Dict[str, Any]: return { "id": self.result.id, "name": self.result.name, diff --git a/mognet/primitives/queries.py b/mognet/primitives/queries.py index b87a327..5503af6 100644 --- a/mognet/primitives/queries.py +++ b/mognet/primitives/queries.py @@ -1,5 +1,6 @@ +from typing import Any, List, Literal from uuid import UUID, uuid4 -from typing import List, Literal + from pydantic import BaseModel, Field @@ -13,7 +14,7 @@ class QueryResponseMessage(BaseModel): kind: str node_id: str - payload: dict + payload: Any class StatusResponseMessage(QueryResponseMessage): diff --git a/mognet/primitives/request.py b/mognet/primitives/request.py index 1bb5095..a136edf 100644 --- a/mognet/primitives/request.py +++ b/mognet/primitives/request.py @@ -1,22 +1,33 @@ -from datetime import timedelta, datetime -from typing import Any, Dict, Generic, List, Optional, TypeVar, Union +from datetime import datetime, timedelta +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) from uuid import UUID, uuid4 -from pydantic import conint +from pydantic import BaseModel, conint from pydantic.fields import Field -from pydantic.generics import GenericModel - TReturn = TypeVar("TReturn") -Priority = conint(ge=0, le=10) +if TYPE_CHECKING: + Priority = int +else: + Priority = conint(ge=0, le=10) -class Request(GenericModel, Generic[TReturn]): +class Request(BaseModel, Generic[TReturn]): id: UUID = Field(default_factory=uuid4) name: str - args: tuple = () + args: Tuple[Any, ...] = () kwargs: Dict[str, Any] = Field(default_factory=dict) stack: List[UUID] = Field(default_factory=list) @@ -48,7 +59,7 @@ class Request(GenericModel, Generic[TReturn]): # Task priority. The higher the value, the higher the priority. priority: Priority = 5 - def __repr__(self): + def __repr__(self) -> str: msg = f"{self.name}[id={self.id!r}]" if self.kwargs_repr is not None: diff --git a/mognet/primitives/revoke.py b/mognet/primitives/revoke.py index 380e9b3..5491401 100644 --- a/mognet/primitives/revoke.py +++ b/mognet/primitives/revoke.py @@ -1,5 +1,6 @@ from typing import ClassVar from uuid import UUID + from pydantic.main import BaseModel diff --git a/mognet/service/class_service.py b/mognet/service/class_service.py index 471baba..f43861c 100644 --- a/mognet/service/class_service.py +++ b/mognet/service/class_service.py @@ -1,6 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, TypeVar, Generic - +from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: from mognet.app.app_config import AppConfig @@ -8,6 +7,8 @@ _TReturn = TypeVar("_TReturn") +_TSelf = TypeVar("_TSelf") + class ClassService(Generic[_TReturn], metaclass=ABCMeta): """ @@ -27,17 +28,17 @@ def __init__(self, config: "AppConfig") -> None: self.config = config @abstractmethod - def __call__(self, context: "Context", *args, **kwds) -> _TReturn: + def __call__(self, context: "Context", *args: Any, **kwds: Any) -> _TReturn: raise NotImplementedError - def __enter__(self): + def __enter__(self: _TSelf) -> _TSelf: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def close(self): + def close(self) -> None: pass - async def wait_closed(self): - pass \ No newline at end of file + async def wait_closed(self) -> None: + pass diff --git a/mognet/state/base_state_backend.py b/mognet/state/base_state_backend.py index 25a4caa..a7cee5b 100644 --- a/mognet/state/base_state_backend.py +++ b/mognet/state/base_state_backend.py @@ -8,34 +8,34 @@ class BaseStateBackend(metaclass=ABCMeta): @abstractmethod async def get( - self, request_id: UUID, key: str, default: _TValue = None + self, request_id: UUID, key: str, default: Optional[_TValue] = None ) -> Optional[_TValue]: raise NotImplementedError @abstractmethod - async def set(self, request_id: UUID, key: str, value: Any): + async def set(self, request_id: UUID, key: str, value: Any) -> None: raise NotImplementedError @abstractmethod async def pop( - self, request_id: UUID, key: str, default: _TValue = None + self, request_id: UUID, key: str, default: Optional[_TValue] = None ) -> Optional[_TValue]: raise NotImplementedError @abstractmethod - async def clear(self, request_id: UUID): + async def clear(self, request_id: UUID) -> None: raise NotImplementedError - async def __aenter__(self): + async def __aenter__(self): # type: ignore return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: return None @abstractmethod - async def connect(self): + async def connect(self) -> None: raise NotImplementedError @abstractmethod - async def close(self): + async def close(self) -> None: raise NotImplementedError diff --git a/mognet/state/redis_state_backend.py b/mognet/state/redis_state_backend.py index c840cf5..1139a3f 100644 --- a/mognet/state/redis_state_backend.py +++ b/mognet/state/redis_state_backend.py @@ -1,19 +1,32 @@ +from __future__ import annotations + import json import logging -from typing import TYPE_CHECKING, Any, Optional, TypeVar +import sys +from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from uuid import UUID -from redis.asyncio import Redis, from_url +from redis.asyncio import from_url from mognet.exceptions.base_exceptions import NotConnected from mognet.state.base_state_backend import BaseStateBackend from mognet.state.state_backend_config import StateBackendConfig from mognet.tools.urls import censor_credentials +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias + + if TYPE_CHECKING: + from redis.asyncio import Redis # noqa: F401 + from mognet.app.app import App _TValue = TypeVar("_TValue") +_Redis: TypeAlias = "Redis[Any]" _log = logging.getLogger(__name__) @@ -23,26 +36,26 @@ def __init__(self, config: StateBackendConfig, app: "App") -> None: super().__init__() self.config = config - self.__redis = None + self.__redis: Optional[_Redis] = None self.app = app @property - def _redis(self) -> Redis: + def _redis(self) -> _Redis: if self.__redis is None: raise NotConnected return self.__redis async def get( - self, request_id: str, key: str, default: _TValue = None + self, request_id: UUID, key: str, default: Optional[_TValue] = None ) -> Optional[_TValue]: state_key = self._format_key(request_id) async with self._redis.pipeline(transaction=True) as tr: - tr.hexists(state_key, key) - tr.hget(state_key, key) - tr.expire(state_key, self.config.redis.state_ttl) + _ = tr.hexists(state_key, key) + _ = tr.hget(state_key, key) + _ = tr.expire(state_key, self.config.redis.state_ttl) exists, value, *_ = await tr.execute() @@ -54,29 +67,29 @@ async def get( ) return default - return json.loads(value) + return cast(_TValue, json.loads(value)) - async def set(self, request_id: str, key: str, value: Any): + async def set(self, request_id: UUID, key: str, value: Any) -> None: state_key = self._format_key(request_id) async with self._redis.pipeline(transaction=True) as tr: - tr.hset(state_key, key, json.dumps(value).encode()) - tr.expire(state_key, self.config.redis.state_ttl) + _ = tr.hset(state_key, key, json.dumps(value).encode()) + _ = tr.expire(state_key, self.config.redis.state_ttl) await tr.execute() async def pop( - self, request_id: str, key: str, default: _TValue = None + self, request_id: UUID, key: str, default: Optional[_TValue] = None ) -> Optional[_TValue]: state_key = self._format_key(request_id) async with self._redis.pipeline(transaction=True) as tr: - tr.hexists(state_key, key) - tr.hget(state_key, key) - tr.hdel(state_key, key) - tr.expire(state_key, self.config.redis.state_ttl) + _ = tr.hexists(state_key, key) + _ = tr.hget(state_key, key) + _ = tr.hdel(state_key, key) + _ = tr.expire(state_key, self.config.redis.state_ttl) exists, value, *_ = await tr.execute() @@ -88,42 +101,42 @@ async def pop( ) return default - return json.loads(value) + return cast(_TValue, json.loads(value)) - async def clear(self, request_id: str): + async def clear(self, request_id: UUID) -> None: state_key = self._format_key(request_id) _log.debug("Clearing state of id=%r", state_key) - return await self._redis.delete(state_key) + await self._redis.delete(state_key) - def _format_key(self, result_id: str) -> str: + def _format_key(self, result_id: UUID) -> str: key = f"{self.app.name}.mognet.state.{result_id}" _log.debug("Formatted state key=%r for id=%r", key, result_id) return key - async def __aenter__(self): + async def __aenter__(self) -> RedisStateBackend: await self.connect() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await self.close() - async def connect(self): - redis: Redis = from_url( + async def connect(self) -> None: + redis: _Redis = from_url( self.config.redis.url, max_connections=self.config.redis.max_connections, ) self.__redis = redis - async def close(self): + async def close(self) -> None: redis = self.__redis if redis is not None: self.__redis = None await redis.close() - def __repr__(self): + def __repr__(self) -> str: return f"RedisStateBackend(url={censor_credentials(self.config.redis.url)!r})" diff --git a/mognet/state/state.py b/mognet/state/state.py index 0ca5196..dad9082 100644 --- a/mognet/state/state.py +++ b/mognet/state/state.py @@ -1,9 +1,11 @@ -from typing import Any, TYPE_CHECKING -from uuid import UUID +from __future__ import annotations +from typing import TYPE_CHECKING, Any +from uuid import UUID if TYPE_CHECKING: from mognet import App + from mognet.state.base_state_backend import BaseStateBackend class State: @@ -22,14 +24,14 @@ def __init__(self, app: "App", request_id: UUID) -> None: self.request_id = request_id @property - def _backend(self): + def _backend(self) -> BaseStateBackend: return self._app.state_backend async def get(self, key: str, default: Any = None) -> Any: """Get a value.""" return await self._backend.get(self.request_id, key, default) - async def set(self, key: str, value: Any): + async def set(self, key: str, value: Any) -> None: """Set a value.""" return await self._backend.set(self.request_id, key, value) @@ -37,6 +39,6 @@ async def pop(self, key: str, default: Any = None) -> Any: """Delete a value from the state and return it's value.""" return await self._backend.pop(self.request_id, key, default) - async def clear(self): + async def clear(self) -> None: """Clear all values.""" return await self._backend.clear(self.request_id) diff --git a/mognet/state/state_backend_config.py b/mognet/state/state_backend_config.py index a77c649..4a79de9 100644 --- a/mognet/state/state_backend_config.py +++ b/mognet/state/state_backend_config.py @@ -1,4 +1,5 @@ from typing import Optional + from pydantic import BaseModel diff --git a/mognet/tasks/task_registry.py b/mognet/tasks/task_registry.py index d5df0c1..71a2fea 100644 --- a/mognet/tasks/task_registry.py +++ b/mognet/tasks/task_registry.py @@ -1,10 +1,8 @@ -from contextvars import ContextVar import logging +from contextvars import ContextVar +from typing import Any, Dict, List, Optional, Protocol from mognet.context.context import Context -from typing import Any, Callable, Dict, List, Optional - -from typing import Protocol _log = logging.getLogger(__name__) @@ -21,7 +19,7 @@ def __init__(self, task_name: str) -> None: super().__init__(task_name) self.task_name = task_name - def __str__(self): + def __str__(self) -> str: return f"Unknown task: {self.task_name!r}; did you forget to register it, or import it's module in the app's configuration?" @@ -30,7 +28,7 @@ class TaskRegistry: _names_to_tasks: Dict[str, TaskProtocol] _tasks_to_names: Dict[TaskProtocol, str] - def __init__(self): + def __init__(self) -> None: self._names_to_tasks = {} self._tasks_to_names = {} @@ -47,7 +45,9 @@ def get_task_name(self, func: TaskProtocol) -> str: def registered_task_names(self) -> List[str]: return list(self._names_to_tasks) - def add_task_function(self, func: Callable, *, name: Optional[str] = None): + def add_task_function( + self, func: TaskProtocol, *, name: Optional[str] = None + ) -> None: full_func_name = _get_full_func_name(func) if name is None: @@ -63,11 +63,11 @@ def add_task_function(self, func: Callable, *, name: Optional[str] = None): _log.debug("Registered function %r as task %r", full_func_name, name) - def register_globally(self): + def register_globally(self) -> None: task_registry.set(self) -def _get_full_func_name(func) -> str: +def _get_full_func_name(func: Any) -> str: func_name = getattr(func, "__qualname__", None) or getattr(func, "__name__", None) full_func_name = ".".join( diff --git a/mognet/testing/pytest_integration.py b/mognet/testing/pytest_integration.py index 4e6793e..fb4aa4f 100644 --- a/mognet/testing/pytest_integration.py +++ b/mognet/testing/pytest_integration.py @@ -1,14 +1,15 @@ import asyncio -import pytest +from typing import Any, AsyncIterable + import pytest_asyncio + from mognet import App -def create_app_fixture(app: App): +def create_app_fixture(app: App) -> Any: """Create a Pytest fixture for a Mognet application.""" - @pytest_asyncio.fixture - async def app_fixture(): + async def app_fixture() -> AsyncIterable[App]: async with app: start_task = asyncio.create_task(app.start()) yield app @@ -20,4 +21,4 @@ async def app_fixture(): except BaseException: # pylint: disable=broad-except pass - return app_fixture + return pytest_asyncio.fixture(app_fixture) diff --git a/mognet/tools/backports/aioitertools.py b/mognet/tools/backports/aioitertools.py index 50d6350..9dfcbed 100644 --- a/mognet/tools/backports/aioitertools.py +++ b/mognet/tools/backports/aioitertools.py @@ -6,7 +6,7 @@ import asyncio from contextlib import suppress -from typing import AsyncIterable, Iterable, TypeVar, AsyncGenerator +from typing import Any, AsyncGenerator, AsyncIterable, Dict, Iterable, TypeVar T = TypeVar("T") @@ -34,7 +34,7 @@ async def generator(x): ... # intermixed values yielded from gen1 and gen2 """ - queue: asyncio.Queue[dict] = asyncio.Queue() + queue: asyncio.Queue[Dict[Any, Any]] = asyncio.Queue() tailer_count: int = 0 diff --git a/mognet/tools/kwargs_repr.py b/mognet/tools/kwargs_repr.py index e9569de..4783161 100644 --- a/mognet/tools/kwargs_repr.py +++ b/mognet/tools/kwargs_repr.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Any, Dict, Optional, Tuple def _format_value(value: Any, *, max_length: Optional[int]) -> str: @@ -11,8 +11,8 @@ def _format_value(value: Any, *, max_length: Optional[int]) -> str: def format_kwargs_repr( - args: tuple, - kwargs: dict, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], *, value_max_length: Optional[int] = 64, ) -> str: diff --git a/mognet/tools/retries.py b/mognet/tools/retries.py index 11765f9..af70203 100644 --- a/mognet/tools/retries.py +++ b/mognet/tools/retries.py @@ -1,26 +1,15 @@ import asyncio +import inspect import logging from functools import wraps -import inspect -from typing import ( - Any, - Awaitable, - Callable, - NamedTuple, - Tuple, - Type, - TypeVar, - Union, - cast, -) - +from typing import Any, Awaitable, Callable, Optional, Tuple, Type, TypeVar, Union, cast _T = TypeVar("_T") _log = logging.getLogger(__name__) -async def _noop(*args, **kwargs): +async def _noop(*args: Any, **kwargs: Any) -> None: pass @@ -29,9 +18,9 @@ def retryableasyncmethod( *, max_attempts: Union[int, str], wait_timeout: Union[float, str], - lock: Union[asyncio.Lock, str] = None, - on_retry: Union[Callable[[BaseException], Awaitable], str] = None, -): + lock: Optional[Union[asyncio.Lock, str]] = None, + on_retry: Optional[Union[Callable[[BaseException], Awaitable[Any]], str]] = None, +) -> Callable[[_T], _T]: """ Decorator to wrap an async method and make it retryable. """ @@ -43,10 +32,12 @@ def make_retryable(func: _T) -> _T: f: Any = cast(Any, func) @wraps(f) - async def async_retryable_decorator(self, *args, **kwargs): + async def async_retryable_decorator( + self: Any, *args: Any, **kwargs: Any + ) -> Any: last_exc = None - retry = _noop + retry: Callable[..., Any] = _noop if isinstance(on_retry, str): retry = getattr(self, on_retry) elif callable(on_retry): diff --git a/mognet/worker/worker.py b/mognet/worker/worker.py index 96ab4a3..d560a51 100644 --- a/mognet/worker/worker.py +++ b/mognet/worker/worker.py @@ -1,32 +1,27 @@ -from datetime import datetime, timedelta -from enum import Enum -import inspect +from __future__ import annotations -from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause import asyncio +import inspect import logging from asyncio.futures import Future -from mognet.broker.base_broker import IncomingMessagePayload -from typing import ( - AsyncGenerator, - Optional, - Set, - TYPE_CHECKING, - Dict, - List, -) +from datetime import datetime, timedelta +from enum import Enum +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Set from uuid import UUID -from mognet.tools.backports.aioitertools import as_generated +from pydantic import ValidationError +from pydantic.decorator import ValidatedFunction + from mognet.broker.base_broker import IncomingMessagePayload from mognet.context.context import Context +from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause from mognet.exceptions.too_many_retries import TooManyRetries -from mognet.model.result import Result, ResultState +from mognet.model.result import Result +from mognet.model.result_state import ResultState +from mognet.primitives.request import Request from mognet.state.state import State from mognet.tasks.task_registry import UnknownTask -from pydantic import ValidationError -from pydantic.decorator import ValidatedFunction -from mognet.primitives.request import Request +from mognet.tools.backports.aioitertools import as_generated if TYPE_CHECKING: from mognet.app.app import App @@ -47,7 +42,7 @@ class Worker: loop, for the task queues that are configured. """ - running_tasks: Dict[UUID, "_RequestProcessorHolder"] + running_tasks: Dict[UUID, _RequestProcessorHolder] # Set of tasks that are suspended _waiting_tasks: Set[UUID] @@ -58,7 +53,7 @@ def __init__( self, *, app: "App", - middleware: List["Middleware"] = None, + middleware: Optional[List["Middleware"]] = None, ) -> None: self.app = app self.running_tasks = {} @@ -67,10 +62,12 @@ def __init__( self._current_prefetch = 1 - self._queue_consumption_tasks: List[AsyncGenerator] = [] - self._consume_task = None + self._queue_consumption_tasks: List[ + AsyncGenerator[IncomingMessagePayload, None] + ] = [] + self._consume_task: Optional[asyncio.Task[None]] = None - async def run(self): + async def run(self) -> None: _log.debug("Starting worker") try: @@ -83,13 +80,17 @@ async def run(self): except Exception as exc: # pylint: disable=broad-except _log.error("Error during consumption", exc_info=exc) - async def _handle_connection_lost(self, exc: BaseException = None): + async def _handle_connection_lost( + self, exc: Optional[BaseException] = None + ) -> None: _log.error("Handling connection lost event, stopping all tasks", exc_info=exc) # No point in NACKing, because we have been disconnected await self._cancel_all_tasks(message_action=MessageCancellationAction.NOTHING) - async def _cancel_all_tasks(self, *, message_action: MessageCancellationAction): + async def _cancel_all_tasks( + self, *, message_action: MessageCancellationAction + ) -> None: all_req_ids = list(self.running_tasks) _log.debug("Cancelling all %r running tasks", len(all_req_ids)) @@ -102,7 +103,7 @@ async def _cancel_all_tasks(self, *, message_action: MessageCancellationAction): finally: self._waiting_tasks.clear() - async def stop_consuming(self): + async def stop_consuming(self) -> None: _log.debug("Closing queue consumption tasks") consumers = self._queue_consumption_tasks @@ -132,7 +133,7 @@ async def stop_consuming(self): except Exception as consume_err: # pylint: disable=broad-except _log.error("Error shutting down consumer task", exc_info=consume_err) - async def close(self): + async def close(self) -> None: """ Stops execution, cancelling all running tasks. """ @@ -146,18 +147,20 @@ async def close(self): _log.debug("Closed worker") - def _remove_running_task(self, req_id: UUID): + def _remove_running_task(self, req_id: UUID) -> Optional[_RequestProcessorHolder]: fut = self.running_tasks.pop(req_id, None) asyncio.create_task(self._emit_running_task_count_change()) return fut - def _add_running_task(self, req_id: UUID, holder: "_RequestProcessorHolder"): + def _add_running_task(self, req_id: UUID, holder: _RequestProcessorHolder) -> None: self.running_tasks[req_id] = holder asyncio.create_task(self._emit_running_task_count_change()) - async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction): + async def cancel( + self, req_id: UUID, *, message_action: MessageCancellationAction + ) -> None: """ Cancels, if any, the execution of a request. Whoever calls this method is responsible for updating the result on the backend @@ -197,7 +200,7 @@ async def cancel(self, req_id: UUID, *, message_action: MessageCancellationActio _log.debug("Stopped handler of task id=%r", req_id) - def _create_context(self, request: "Request") -> "Context": + def _create_context(self, request: "Request[Any]") -> "Context": if not self.app.state_backend: raise RuntimeError("No state backend defined") @@ -208,7 +211,7 @@ def _create_context(self, request: "Request") -> "Context": self, ) - async def _run_request(self, req: Request) -> None: + async def _run_request(self, req: Request[Any]) -> None: """ Processes a request, validating it before running. """ @@ -219,9 +222,9 @@ async def _run_request(self, req: Request) -> None: # for cases when a request is cancelled before it's started. # Even worse, check that we're not trying to start a request whose # result might have been evicted. - result = await self.app.result_backend.get(req.id) + res = await self.app.result_backend.get(req.id) - if result is None: + if res is None: _log.error( "Attempting to run task %r, but it's result doesn't exist on the backend. Discarding", req, @@ -229,6 +232,9 @@ async def _run_request(self, req: Request) -> None: await self.remove_suspended_task(req.id) return + # Shut up, mypy. 'res' cannot be None after this point. + result: Result[Any] = res + context = self._create_context(req) if result.done: @@ -389,7 +395,7 @@ async def _run_request(self, req: Request) -> None: value = await fut if req.id in self.running_tasks: - await asyncio.shield(result.set_result(value)) + result = await asyncio.shield(result.set_result(value)) _log.info( "Request %r finished with status %r in %.2fs", @@ -447,7 +453,7 @@ async def _run_request(self, req: Request) -> None: exc_info=exc, ) - async def _on_complete(self, context: "Context", result: Result): + async def _on_complete(self, context: "Context", result: Result[Any]) -> None: if result.done: await context.state.clear() @@ -462,7 +468,7 @@ async def _on_complete(self, context: "Context", result: Result): except Exception as mw_exc: # pylint: disable=broad-except _log.error("Middleware %r failed", middleware, exc_info=mw_exc) - async def _on_starting(self, context: "Context"): + async def _on_starting(self, context: "Context") -> None: _log.info("Starting task %r", context.request) for middleware in self._middleware: @@ -472,7 +478,9 @@ async def _on_starting(self, context: "Context"): except Exception as mw_exc: # pylint: disable=broad-except _log.error("Middleware %r failed", middleware, exc_info=mw_exc) - def _process_request_message(self, payload: IncomingMessagePayload) -> asyncio.Task: + def _process_request_message( + self, payload: IncomingMessagePayload + ) -> asyncio.Task[None]: """ Creates an asyncio.Task which will process the enclosed Request in the background. @@ -480,9 +488,9 @@ def _process_request_message(self, payload: IncomingMessagePayload) -> asyncio.T Returns said task, after adding completion handlers to it. """ _log.debug("Parsing input of message id=%r as Request", payload.id) - req = Request.parse_obj(payload.payload) + req: Request[Any] = Request.parse_obj(payload.payload) - async def request_processor(): + async def request_processor() -> None: try: await self._run_request(req) @@ -500,7 +508,7 @@ async def request_processor(): ) await asyncio.shield(payload.nack()) - def on_processing_done(fut: Future): + def on_processing_done(fut: "Future[Any]") -> None: self._remove_running_task(req.id) exc = fut.exception() @@ -519,7 +527,7 @@ def on_processing_done(fut: Future): return task - def start_consuming(self): + def start_consuming(self) -> asyncio.Task[None]: if self._consume_task is not None: return self._consume_task @@ -527,7 +535,7 @@ def start_consuming(self): return self._consume_task - async def _start_consuming(self): + async def _start_consuming(self) -> None: queues = self.app.get_task_queue_names() @@ -560,11 +568,11 @@ async def _start_consuming(self): finally: _log.debug("Stopped consuming task queues") - async def add_waiting_task(self, task_id: UUID): + async def add_waiting_task(self, task_id: UUID) -> None: self._waiting_tasks.add(task_id) await self._adjust_prefetch() - async def remove_suspended_task(self, task_id: UUID): + async def remove_suspended_task(self, task_id: UUID) -> None: try: self._waiting_tasks.remove(task_id) except KeyError: @@ -572,10 +580,10 @@ async def remove_suspended_task(self, task_id: UUID): await self._adjust_prefetch() @property - def waiting_task_count(self): + def waiting_task_count(self) -> int: return len(self._waiting_tasks) - async def _emit_running_task_count_change(self): + async def _emit_running_task_count_change(self) -> None: for middleware in self._middleware: try: _log.debug("Calling 'on_running_task_count_changed' on %r", middleware) @@ -587,7 +595,7 @@ async def _emit_running_task_count_change(self): exc_info=mw_exc, ) - async def _adjust_prefetch(self): + async def _adjust_prefetch(self) -> None: if self._consume_task is None: _log.debug("Not adjusting prefetch because not consuming the queue") return @@ -644,14 +652,14 @@ class _RequestProcessorHolder: def __init__( self, incoming_message: IncomingMessagePayload, - request: Request, - task: asyncio.Task, + request: Request[Any], + task: asyncio.Task[Any], ) -> None: self.message = incoming_message self.request = request - self._task: Optional[asyncio.Task] = task + self._task: Optional[asyncio.Task[Any]] = task - async def cancel(self, *, message_action: MessageCancellationAction): + async def cancel(self, *, message_action: MessageCancellationAction) -> None: try: if message_action == MessageCancellationAction.ACK: diff --git a/poetry.lock b/poetry.lock index c0c5c85..0529707 100644 --- a/poetry.lock +++ b/poetry.lock @@ -136,14 +136,14 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "coverage" -version = "6.3.3" +version = "6.4.2" description = "Code coverage measurement for Python" category = "dev" optional = false python-versions = ">=3.7" [package.dependencies] -tomli = {version = "*", optional = true, markers = "extra == \"toml\""} +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} [package.extras] toml = ["tomli"] @@ -173,6 +173,34 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "flake8" +version = "4.0.1" +description = "the modular source code checker: pep8 pyflakes and co" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +mccabe = ">=0.6.0,<0.7.0" +pycodestyle = ">=2.8.0,<2.9.0" +pyflakes = ">=2.4.0,<2.5.0" + +[[package]] +name = "flake8-bugbear" +version = "22.7.1" +description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +attrs = ">=19.2.0" +flake8 = ">=3.0.0" + +[package.extras] +dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit"] + [[package]] name = "future" version = "0.18.2" @@ -295,11 +323,11 @@ python-versions = ">=3.7" [[package]] name = "mccabe" -version = "0.7.0" +version = "0.6.1" description = "McCabe checker, plugin for flake8" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = "*" [[package]] name = "mergedeep" @@ -425,6 +453,24 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "mypy" +version = "0.961" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +mypy-extensions = ">=0.4.3" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=3.10" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "0.4.3" @@ -495,6 +541,14 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "pycodestyle" +version = "2.8.0" +description = "Python style guide checker" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + [[package]] name = "pydantic" version = "1.9.1" @@ -510,6 +564,14 @@ typing-extensions = ">=3.7.4.3" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] +[[package]] +name = "pyflakes" +version = "2.4.0" +description = "passive checker of Python programs" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + [[package]] name = "pygments" version = "2.12.0" @@ -657,7 +719,7 @@ pyyaml = "*" [[package]] name = "redis" -version = "4.3.1" +version = "4.3.4" description = "Python client for Redis database and key-value store" category = "main" optional = false @@ -728,6 +790,22 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)"] doc = ["mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "mdx-include (>=1.4.1,<2.0.0)"] test = ["shellingham (>=1.3.0,<2.0.0)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "coverage (>=5.2,<6.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "mypy (==0.910)", "black (>=22.3.0,<23.0.0)", "isort (>=5.0.6,<6.0.0)"] +[[package]] +name = "types-redis" +version = "4.3.4" +description = "Typing stubs for redis" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "types-tabulate" +version = "0.8.11" +description = "Typing stubs for tabulate" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "4.2.0" @@ -782,7 +860,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "bdf23376d830298b4ecb4eae9cba31205922455bb0717246fb93b8178d077d16" +content-hash = "57196e45c63b185c8c20bffa358f4a2d55ff81de92c6e6f71215b64f91bd7260" [metadata.files] aio-pika = [ @@ -850,49 +928,7 @@ colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] -coverage = [ - {file = "coverage-6.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df32ee0f4935a101e4b9a5f07b617d884a531ed5666671ff6ac66d2e8e8246d8"}, - {file = "coverage-6.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:75b5dbffc334e0beb4f6c503fb95e6d422770fd2d1b40a64898ea26d6c02742d"}, - {file = "coverage-6.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:114944e6061b68a801c5da5427b9173a0dd9d32cd5fcc18a13de90352843737d"}, - {file = "coverage-6.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ab88a01cd180b5640ccc9c47232e31924d5f9967ab7edd7e5c91c68eee47a69"}, - {file = "coverage-6.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad8f9068f5972a46d50fe5f32c09d6ee11da69c560fcb1b4c3baea246ca4109b"}, - {file = "coverage-6.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4cd696aa712e6cd16898d63cf66139dc70d998f8121ab558f0e1936396dbc579"}, - {file = "coverage-6.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c1a9942e282cc9d3ed522cd3e3cab081149b27ea3bda72d6f61f84eaf88c1a63"}, - {file = "coverage-6.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c06455121a089252b5943ea682187a4e0a5cf0a3fb980eb8e7ce394b144430a9"}, - {file = "coverage-6.3.3-cp310-cp310-win32.whl", hash = "sha256:cb5311d6ccbd22578c80028c5e292a7ab9adb91bd62c1982087fad75abe2e63d"}, - {file = "coverage-6.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:6d4a6f30f611e657495cc81a07ff7aa8cd949144e7667c5d3e680d73ba7a70e4"}, - {file = "coverage-6.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:79bf405432428e989cad7b8bc60581963238f7645ae8a404f5dce90236cc0293"}, - {file = "coverage-6.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:338c417613f15596af9eb7a39353b60abec9d8ce1080aedba5ecee6a5d85f8d3"}, - {file = "coverage-6.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db094a6a4ae6329ed322a8973f83630b12715654c197dd392410400a5bfa1a73"}, - {file = "coverage-6.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1414e8b124611bf4df8d77215bd32cba6e3425da8ce9c1f1046149615e3a9a31"}, - {file = "coverage-6.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:93b16b08f94c92cab88073ffd185070cdcb29f1b98df8b28e6649145b7f2c90d"}, - {file = "coverage-6.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fbc86ae8cc129c801e7baaafe3addf3c8d49c9c1597c44bdf2d78139707c3c62"}, - {file = "coverage-6.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b5ba058610e8289a07db2a57bce45a1793ec0d3d11db28c047aae2aa1a832572"}, - {file = "coverage-6.3.3-cp37-cp37m-win32.whl", hash = "sha256:8329635c0781927a2c6ae068461e19674c564e05b86736ab8eb29c420ee7dc20"}, - {file = "coverage-6.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:e5af1feee71099ae2e3b086ec04f57f9950e1be9ecf6c420696fea7977b84738"}, - {file = "coverage-6.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e814a4a5a1d95223b08cdb0f4f57029e8eab22ffdbae2f97107aeef28554517e"}, - {file = "coverage-6.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:61f4fbf3633cb0713437291b8848634ea97f89c7e849c2be17a665611e433f53"}, - {file = "coverage-6.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3401b0d2ed9f726fadbfa35102e00d1b3547b73772a1de5508ef3bdbcb36afe7"}, - {file = "coverage-6.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8586b177b4407f988731eb7f41967415b2197f35e2a6ee1a9b9b561f6323c8e9"}, - {file = "coverage-6.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:892e7fe32191960da559a14536768a62e83e87bbb867e1b9c643e7e0fbce2579"}, - {file = "coverage-6.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:afb03f981fadb5aed1ac6e3dd34f0488e1a0875623d557b6fad09b97a942b38a"}, - {file = "coverage-6.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cbe91bc84be4e5ef0b1480d15c7b18e29c73bdfa33e07d3725da7d18e1b0aff2"}, - {file = "coverage-6.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:91502bf27cbd5c83c95cfea291ef387469f2387508645602e1ca0fd8a4ba7548"}, - {file = "coverage-6.3.3-cp38-cp38-win32.whl", hash = "sha256:c488db059848702aff30aa1d90ef87928d4e72e4f00717343800546fdbff0a94"}, - {file = "coverage-6.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6534fcdfb5c503affb6b1130db7b5bfc8a0f77fa34880146f7a5c117987d0"}, - {file = "coverage-6.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cc692c9ee18f0dd3214843779ba6b275ee4bb9b9a5745ba64265bce911aefd1a"}, - {file = "coverage-6.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:462105283de203df8de58a68c1bb4ba2a8a164097c2379f664fa81d6baf94b81"}, - {file = "coverage-6.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc972d829ad5ef4d4c5fcabd2bbe2add84ce8236f64ba1c0c72185da3a273130"}, - {file = "coverage-6.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:06f54765cdbce99901871d50fe9f41d58213f18e98b170a30ca34f47de7dd5e8"}, - {file = "coverage-6.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7835f76a081787f0ca62a53504361b3869840a1620049b56d803a8cb3a9eeea3"}, - {file = "coverage-6.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6f5fee77ec3384b934797f1873758f796dfb4f167e1296dc00f8b2e023ce6ee9"}, - {file = "coverage-6.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:baa8be8aba3dd1e976e68677be68a960a633a6d44c325757aefaa4d66175050f"}, - {file = "coverage-6.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4d06380e777dd6b35ee936f333d55b53dc4a8271036ff884c909cf6e94be8b6c"}, - {file = "coverage-6.3.3-cp39-cp39-win32.whl", hash = "sha256:f8cabc5fd0091976ab7b020f5708335033e422de25e20ddf9416bdce2b7e07d8"}, - {file = "coverage-6.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:9c9441d57b0963cf8340268ad62fc83de61f1613034b79c2b1053046af0c5284"}, - {file = "coverage-6.3.3-pp36.pp37.pp38-none-any.whl", hash = "sha256:d522f1dc49127eab0bfbba4e90fa068ecff0899bbf61bf4065c790ddd6c177fe"}, - {file = "coverage-6.3.3.tar.gz", hash = "sha256:2781c43bffbbec2b8867376d4d61916f5e9c4cc168232528562a61d1b4b01879"}, -] +coverage = [] deprecated = [ {file = "Deprecated-1.2.13-py2.py3-none-any.whl", hash = "sha256:64756e3e14c8c5eea9795d93c524551432a0be75629f8f29e67ab8caf076c76d"}, {file = "Deprecated-1.2.13.tar.gz", hash = "sha256:43ac5335da90c31c24ba028af536a91d41d53f9e6901ddb021bcc572ce44e38d"}, @@ -901,6 +937,8 @@ dill = [ {file = "dill-0.3.5.1-py2.py3-none-any.whl", hash = "sha256:33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302"}, {file = "dill-0.3.5.1.tar.gz", hash = "sha256:d75e41f3eff1eee599d738e76ba8f4ad98ea229db8b085318aa2b3333a208c86"}, ] +flake8 = [] +flake8-bugbear = [] future = [ {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"}, ] @@ -1057,8 +1095,8 @@ markupsafe = [ {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"}, ] mccabe = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, + {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, + {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, ] mergedeep = [ {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"}, @@ -1153,6 +1191,7 @@ multidict = [ {file = "multidict-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:4bae31803d708f6f15fd98be6a6ac0b6958fcf68fda3c77a048a4f9073704aae"}, {file = "multidict-6.0.2.tar.gz", hash = "sha256:5ff3bd75f38e4c43f1f470f2df7a4d430b821c4ce22be384e1459cb57d6bb013"}, ] +mypy = [] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, @@ -1181,6 +1220,7 @@ py = [ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] +pycodestyle = [] pydantic = [ {file = "pydantic-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8098a724c2784bf03e8070993f6d46aa2eeca031f8d8a048dff277703e6e193"}, {file = "pydantic-1.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c320c64dd876e45254bdd350f0179da737463eea41c43bacbee9d8c9d1021f11"}, @@ -1218,6 +1258,7 @@ pydantic = [ {file = "pydantic-1.9.1-py3-none-any.whl", hash = "sha256:4988c0f13c42bfa9ddd2fe2f569c9d54646ce84adc5de84228cfe83396f3bd58"}, {file = "pydantic-1.9.1.tar.gz", hash = "sha256:1ed987c3ff29fff7fd8c3ea3a3ea877ad310aae2ef9889a119e22d3f2db0691a"}, ] +pyflakes = [] pygments = [ {file = "Pygments-2.12.0-py3-none-any.whl", hash = "sha256:dc9c10fb40944260f6ed4c688ece0cd2048414940f1cea51b8b226318411c519"}, {file = "Pygments-2.12.0.tar.gz", hash = "sha256:5eb116118f9612ff1ee89ac96437bb6b49e8f04d8a13b514ba26f620208e26eb"}, @@ -1295,8 +1336,8 @@ pyyaml-env-tag = [ {file = "pyyaml_env_tag-0.1.tar.gz", hash = "sha256:70092675bda14fdec33b31ba77e7543de9ddc88f2e5b99160396572d11525bdb"}, ] redis = [ - {file = "redis-4.3.1-py3-none-any.whl", hash = "sha256:84316970995a7adb907a56754d2b92d88fc2d252963dc5ac34c88f0f1a22c25d"}, - {file = "redis-4.3.1.tar.gz", hash = "sha256:94b617b4cd296e94991146f66fc5559756fbefe9493604f0312e4d3298ac63e9"}, + {file = "redis-4.3.4-py3-none-any.whl", hash = "sha256:a52d5694c9eb4292770084fa8c863f79367ca19884b329ab574d5cb2036b3e54"}, + {file = "redis-4.3.4.tar.gz", hash = "sha256:ddf27071df4adf3821c4f2ca59d67525c3a82e5f268bed97b813cb4fabf87880"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -1318,6 +1359,8 @@ typer = [ {file = "typer-0.4.1-py3-none-any.whl", hash = "sha256:e8467f0ebac0c81366c2168d6ad9f888efdfb6d4e1d3d5b4a004f46fa444b5c3"}, {file = "typer-0.4.1.tar.gz", hash = "sha256:5646aef0d936b2c761a10393f0384ee6b5c7fe0bb3e5cd710b17134ca1d99cff"}, ] +types-redis = [] +types-tabulate = [] typing-extensions = [ {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, diff --git a/pyproject.toml b/pyproject.toml index 063274a..776c3fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,12 @@ Jinja2 = "<3.1.0" # See https://github.com/mkdocs/mkdocs/issues/2794 mkdocs-material = "^8.2.6" mkdocstrings = {version = "^0.18.1", extras = ["python-legacy"]} mkdocs-typer = "^0.0.2" +isort = "^5.10.1" +types-tabulate = "^0.8.11" +mypy = "^0.961" +types-redis = "^4.3.4" +flake8 = "^4.0.1" +flake8-bugbear = "^22.7.1" [build-system] # Allow the package to be installed with `pip install -e ...`, @@ -55,3 +61,32 @@ requires = [ "setuptools" ] build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 88 +target-version = ["py38"] +include = '\.pyi?$' + +[tool.isort] +py_version = 38 +profile = "black" +src_paths = ["mognet", "test"] +known_first_party = ["mognet"] + +[tool.mypy] +python_version = 3.8 +strict = true + +# Silence errors for third-party packages that don't have +# typings available. +[[tool.mypy.overrides]] +module = "pytest_asyncio.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "aiorun.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "treelib.*" +ignore_missing_imports = true \ No newline at end of file diff --git a/test/app_instance.py b/test/app_instance.py index e069a54..7de65bd 100644 --- a/test/app_instance.py +++ b/test/app_instance.py @@ -1,29 +1,36 @@ -from mognet.state.state_backend_config import ( - RedisStateBackendSettings, - StateBackendConfig, -) -from mognet.broker.broker_config import AmqpBrokerSettings, BrokerConfig +import os +from pathlib import Path + +from mognet.app.app import App +from mognet.app.app_config import AppConfig from mognet.backend.backend_config import ( RedisResultBackendSettings, ResultBackendConfig, ) -from mognet.app.app import App -from mognet.app.app_config import AppConfig +from mognet.broker.broker_config import AmqpBrokerSettings, BrokerConfig +from mognet.state.state_backend_config import ( + RedisStateBackendSettings, + StateBackendConfig, +) -config = AppConfig( - result_backend=ResultBackendConfig( - redis=RedisResultBackendSettings(url="redis://redis") - ), - broker=BrokerConfig(amqp=AmqpBrokerSettings(url="amqp://rabbitmq")), - state_backend=StateBackendConfig( - redis=RedisStateBackendSettings(url="redis://redis") - ), - task_routes={}, - minimum_concurrency=1, -) +def get_config(): + config_file_path = Path(os.getenv("MOGNET_CONFIG_FILE", "config.json")) + + if config_file_path.is_file(): + return AppConfig.parse_file(config_file_path) -config.imports = ["test.test_tasks"] + return AppConfig( + result_backend=ResultBackendConfig( + redis=RedisResultBackendSettings(url="redis://localhost:6379/0") + ), + broker=BrokerConfig(amqp=AmqpBrokerSettings(url="amqp://localhost:5672")), + state_backend=StateBackendConfig( + redis=RedisStateBackendSettings(url="redis://localhost:6379/0") + ), + task_routes={}, + minimum_concurrency=1, + ) -app = App(name="test", config=config) +app = App(name="test", config=get_config()) diff --git a/test/conftest.py b/test/conftest.py index 071b52b..410e670 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,12 +1,12 @@ import pytest -from mognet.testing.pytest_integration import create_app_fixture -from .app_instance import config, app +from mognet.testing.pytest_integration import create_app_fixture +from .app_instance import app, get_config test_app = create_app_fixture(app) @pytest.fixture def app_config(): - return config + return get_config() diff --git a/test/test_broker.py b/test/test_broker.py index 2b9d1b7..cebd850 100644 --- a/test/test_broker.py +++ b/test/test_broker.py @@ -1,8 +1,8 @@ -from mognet.exceptions.broker_exceptions import QueueNotFound -import pytest - from typing import TYPE_CHECKING +import pytest + +from mognet.exceptions.broker_exceptions import QueueNotFound if TYPE_CHECKING: from mognet import App @@ -12,6 +12,8 @@ async def test_broker_stats(test_app: "App"): stats = await test_app.broker.task_queue_stats("tasks") + assert stats is not None + @pytest.mark.asyncio async def test_broker_stats_fails_when_queue_not_found(test_app: "App"): @@ -21,3 +23,5 @@ async def test_broker_stats_fails_when_queue_not_found(test_app: "App"): # But subsequent calls should still work... stats = await test_app.broker.task_queue_stats("tasks") + + assert stats is not None diff --git a/test/test_execution_time.py b/test/test_execution_time.py index e69ff9d..bfbd3a4 100644 --- a/test/test_execution_time.py +++ b/test/test_execution_time.py @@ -1,13 +1,12 @@ -from datetime import timedelta, datetime, timezone -from mognet.exceptions.result_exceptions import Revoked -from mognet.model.result_state import ResultState import time -import pytest - +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING -from mognet import Request +import pytest +from mognet import Request +from mognet.exceptions.result_exceptions import Revoked +from mognet.model.result_state import ResultState if TYPE_CHECKING: from mognet import App diff --git a/test/test_pausing.py b/test/test_pausing.py index 7c22fce..b540ed9 100644 --- a/test/test_pausing.py +++ b/test/test_pausing.py @@ -1,8 +1,10 @@ -import pytest from typing import List -from mognet.primitives.request import Request + +import pytest + from mognet import App, Context, task from mognet.exceptions.task_exceptions import Pause +from mognet.primitives.request import Request @task(name="test.paused_sum") diff --git a/test/test_recursion.py b/test/test_recursion.py index 7caf1ec..b2422a7 100644 --- a/test/test_recursion.py +++ b/test/test_recursion.py @@ -1,6 +1,6 @@ import pytest -from mognet import task, Request, App, Context +from mognet import App, Context, Request, task @task(name="test.recursive_factorial") diff --git a/test/test_result_backend.py b/test/test_result_backend.py index 4527b9b..7ce25bf 100644 --- a/test/test_result_backend.py +++ b/test/test_result_backend.py @@ -1,6 +1,6 @@ -from typing import Type -from uuid import uuid4 from dataclasses import dataclass +from uuid import uuid4 + import pytest import pytest_asyncio diff --git a/test/test_retry.py b/test/test_retry.py index 4045eac..d360615 100644 --- a/test/test_retry.py +++ b/test/test_retry.py @@ -1,8 +1,10 @@ import os + +import pytest + +from mognet import App, Context, Request, task from mognet.exceptions.too_many_retries import TooManyRetries from mognet.model.result_state import ResultState -import pytest -from mognet import App, Request, task, Context @pytest.mark.asyncio diff --git a/test/test_revoke.py b/test/test_revoke.py index 8c0a827..5933af4 100644 --- a/test/test_revoke.py +++ b/test/test_revoke.py @@ -1,9 +1,11 @@ import asyncio import uuid + +import pytest + +from mognet import App, Context, Request, task from mognet.model.result import ResultFailed from mognet.model.result_state import ResultState -import pytest -from mognet import App, Request, Context, task @pytest.mark.asyncio diff --git a/test/test_sync_tasks.py b/test/test_sync_tasks.py index 75b6183..3ad19e5 100644 --- a/test/test_sync_tasks.py +++ b/test/test_sync_tasks.py @@ -1,10 +1,12 @@ import asyncio -import pytest import time -from mognet import task, Context, App, Request -from mognet.model.result_state import ResultState from test.test_tasks import add +import pytest + +from mognet import App, Context, Request, task +from mognet.model.result_state import ResultState + @task(name="test.sync_slow_add") def sync_slow_add(context: Context, n1: float, n2: float, delay: float) -> float: diff --git a/test/test_tasks.py b/test/test_tasks.py index 8d23b4d..1255bba 100644 --- a/test/test_tasks.py +++ b/test/test_tasks.py @@ -1,7 +1,6 @@ import asyncio from mognet import Context - from mognet.decorators.task_decorator import task diff --git a/test/test_validation.py b/test/test_validation.py index b339a78..b2bf0cc 100644 --- a/test/test_validation.py +++ b/test/test_validation.py @@ -1,10 +1,12 @@ from dataclasses import dataclass + +import pytest from pydantic.main import BaseModel + +from mognet import App, Context, Request, task from mognet.exceptions.task_exceptions import InvalidTaskArguments from mognet.model.result_state import ResultState from mognet.tasks.task_registry import UnknownTask -import pytest -from mognet import App, Request, task, Context @pytest.mark.asyncio