Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default_language_version:
python: "3"
repos:
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v4.0.0
rev: v4.1.0
hooks:
- id: conventional-pre-commit
stages: [commit-msg]
Expand All @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.11.5"
rev: "v0.11.8"
hooks:
# Run the linter.
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion litestar_saq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CronJob(SaqCronJob):
"""Cron Job Details"""

function: "Union[Function, str]" # type: ignore[assignment]
meta: "dict[str, Any]" = field(default_factory=dict) # pyright: ignore # noqa: PGH003
meta: "dict[str, Any]" = field(default_factory=dict) # pyright: ignore

def __post_init__(self) -> None:
self.function = self._get_or_import_function(self.function) # pyright: ignore[reportIncompatibleMethodOverride]
Expand Down
24 changes: 16 additions & 8 deletions litestar_saq/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def build_cli_app() -> "Group": # noqa: C901
from typing import cast

from click import IntRange, group, option
from litestar.cli._utils import LitestarGroup, console # pyright: ignore # noqa: PGH003
from litestar.cli._utils import LitestarGroup, console # pyright: ignore

@group(cls=LitestarGroup, name="workers", no_args_is_help=True)
def background_worker_group() -> None:
Expand Down Expand Up @@ -66,7 +66,7 @@ def run_worker( # pyright: ignore[reportUnusedFunction]
limited_start_up(plugin, queue_list)
show_saq_info(app, workers, plugin)
managed_workers = list(plugin.get_workers().values())
_processes: list[multiprocessing.Process] = []
processes: list[multiprocessing.Process] = []
if workers > 1:
for _ in range(workers - 1):
for worker in managed_workers:
Expand All @@ -78,13 +78,13 @@ def run_worker( # pyright: ignore[reportUnusedFunction]
),
)
p.start()
_processes.append(p)
processes.append(p)

if len(managed_workers) > 1:
for _ in range(len(managed_workers) - 1):
p = multiprocessing.Process(target=run_saq_worker, args=(managed_workers[_], app.logging_config))
for j in range(len(managed_workers) - 1):
p = multiprocessing.Process(target=run_saq_worker, args=(managed_workers[j + 1], app.logging_config))
p.start()
_processes.append(p)
processes.append(p)

try:
run_saq_worker(
Expand Down Expand Up @@ -131,8 +131,16 @@ def get_saq_plugin(app: "Litestar") -> "SAQPlugin":

This function attempts to find a SAQ plugin instance.
If plugin is not found, it raises an ImproperlyConfiguredException.
"""

Args:
app: The Litestar application instance.

Returns:
The SAQ plugin instance.

Raises:
ImproperConfigurationError: If the SAQ plugin is not found.
"""
from contextlib import suppress

from litestar_saq.exceptions import ImproperConfigurationError
Expand All @@ -149,7 +157,7 @@ def get_saq_plugin(app: "Litestar") -> "SAQPlugin":
def show_saq_info(app: "Litestar", workers: int, plugin: "SAQPlugin") -> None: # pragma: no cover
"""Display basic information about the application and its configuration."""

from litestar.cli._utils import _format_is_enabled, console # pyright: ignore # noqa: PGH003
from litestar.cli._utils import _format_is_enabled, console # pyright: ignore
from rich.table import Table
from saq import __version__ as saq_version

Expand Down
73 changes: 54 additions & 19 deletions litestar_saq/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_static_files() -> Path:
class TaskQueues:
"""Task queues."""

queues: "Mapping[str, Queue]" = field(default_factory=dict) # pyright: ignore # noqa: PGH003
queues: "Mapping[str, Queue]" = field(default_factory=dict) # pyright: ignore

def get(self, name: str) -> "Queue":
"""Get a queue by name.
Expand All @@ -69,7 +69,7 @@ def get(self, name: str) -> "Queue":
class SAQConfig:
"""SAQ Configuration."""

queue_configs: "Collection[QueueConfig]" = field(default_factory=list) # pyright: ignore # noqa: PGH003
queue_configs: "Collection[QueueConfig]" = field(default_factory=list) # pyright: ignore
"""Configuration for Queues"""

queue_instances: "Optional[Mapping[str, Queue]]" = None
Expand Down Expand Up @@ -119,7 +119,11 @@ def signature_namespace(self) -> "dict[str, Any]":
}

async def provide_queues(self) -> "AsyncGenerator[TaskQueues, None]":
"""Provide the configured job queues."""
"""Provide the configured job queues.

Yields:
The configured job queues.
"""
queues = self.get_queues()
for queue in queues.queues.values():
await queue.connect()
Expand All @@ -132,23 +136,27 @@ def filter_delete_queues(self, queues: "list[str]") -> None:
if self.queue_instances is not None:
for queue_name in dict(self.queue_instances):
if queue_name not in queues:
del self.queue_instances[queue_name] # type: ignore # noqa: PGH003
del self.queue_instances[queue_name] # type: ignore

def get_queues(self) -> "TaskQueues":
"""Get the configured SAQ queues."""
"""Get the configured SAQ queues.

Returns:
The configured job queues.
"""
if self.queue_instances is not None:
return TaskQueues(queues=self.queue_instances)

self.queue_instances = {}
for c in self.queue_configs:
self.queue_instances[c.name] = c.queue_class( # type: ignore # noqa: PGH003
self.queue_instances[c.name] = c.queue_class( # type: ignore
c.get_broker(),
name=c.name, # pyright: ignore[reportCallIssue]
dump=self.json_serializer,
load=self.json_deserializer,
**c.broker_options, # pyright: ignore[reportArgumentType]
)
self.queue_instances[c.name]._is_pool_provided = False # type: ignore # noqa: PGH003, SLF001
self.queue_instances[c.name]._is_pool_provided = False # type: ignore # noqa: SLF001
return TaskQueues(queues=self.queue_instances)


Expand All @@ -170,12 +178,12 @@ class PostgresQueueOptions(TypedDict, total=False):
stats_table: NotRequired[str]
min_size: NotRequired[int]
max_size: NotRequired[int]
poll_interval: NotRequired[float]
saq_lock_keyspace: NotRequired[int]
job_lock_keyspace: NotRequired[int]
job_lock_sweep: NotRequired[bool]
priorities: NotRequired[tuple[int, int]]
swept_error_message: NotRequired[str]
manage_pool_lifecycle: NotRequired[bool]


@dataclass
Expand All @@ -195,11 +203,11 @@ class QueueConfig:
"""The name of the queue to create."""
concurrency: int = 10
"""Number of jobs to process concurrently."""
broker_options: "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]" = field(default_factory=dict) # pyright: ignore # noqa: PGH003
broker_options: "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]" = field(default_factory=dict) # pyright: ignore
"""Broker-specific options. For Redis or Postgres backends."""
tasks: "Collection[Union[ReceivesContext, tuple[str, Function], str]]" = field(default_factory=list) # pyright: ignore # noqa: PGH003
tasks: "Collection[Union[ReceivesContext, tuple[str, Function], str]]" = field(default_factory=list) # pyright: ignore
"""Allowed list of functions to execute in this queue."""
scheduled_tasks: "Collection[CronJob]" = field(default_factory=list) # pyright: ignore # noqa: PGH003
scheduled_tasks: "Collection[CronJob]" = field(default_factory=list) # pyright: ignore
"""Scheduled cron jobs to execute in this queue."""
cron_tz: "tzinfo" = timezone.utc
"""Timezone for cron jobs."""
Expand Down Expand Up @@ -232,7 +240,6 @@ class QueueConfig:
Set it False to execute within the Litestar application."""

def __post_init__(self) -> None:
"""Post initialization."""
if self.dsn and self.broker_instance:
msg = "Cannot specify both `dsn` and `broker_instance`"
raise ImproperlyConfiguredException(msg)
Expand All @@ -248,16 +255,19 @@ def __post_init__(self) -> None:
self.before_process = [self.before_process]
if self.after_process is not None and not isinstance(self.after_process, Collection):
self.after_process = [self.after_process]
self.startup = [self._get_or_import_task(task) for task in self.startup or []] # pyright: ignore # noqa: PGH003
self.shutdown = [self._get_or_import_task(task) for task in self.shutdown or []] # pyright: ignore # noqa: PGH003
self.before_process = [self._get_or_import_task(task) for task in self.before_process or []] # pyright: ignore # noqa: PGH003
self.after_process = [self._get_or_import_task(task) for task in self.after_process or []] # pyright: ignore # noqa: PGH003
self.startup = [self._get_or_import_task(task) for task in self.startup or []] # pyright: ignore
self.shutdown = [self._get_or_import_task(task) for task in self.shutdown or []] # pyright: ignore
self.before_process = [self._get_or_import_task(task) for task in self.before_process or []] # pyright: ignore
self.after_process = [self._get_or_import_task(task) for task in self.after_process or []] # pyright: ignore
self._broker_type: Optional[Literal["redis", "postgres", "http"]] = None
self._queue_class: Optional[type[Queue]] = None

def get_broker(self) -> "Any":
"""Get the configured Broker connection.

Raises:
ImproperlyConfiguredException: If the broker type is invalid.

Returns:
Dictionary of queues.
"""
Expand Down Expand Up @@ -292,7 +302,14 @@ def get_broker(self) -> "Any":

@property
def broker_type(self) -> 'Literal["redis", "postgres", "http"]':
"""Type of broker to use."""
"""Type of broker to use.

Raises:
ImproperlyConfiguredException: If the broker type is invalid.

Returns:
The broker type.
"""
if self._broker_type is None and self.broker_instance is not None:
if self.broker_instance.__class__.__name__ == "AsyncConnectionPool":
self._broker_type = "postgres"
Expand All @@ -307,9 +324,27 @@ def broker_type(self) -> 'Literal["redis", "postgres", "http"]':
raise ImproperlyConfiguredException(msg)
return self._broker_type

@property
def _broker_options(self) -> "Union[RedisQueueOptions, PostgresQueueOptions, dict[str, Any]]":
"""Broker-specific options.

Returns:
The broker options.
"""
if self._broker_type == "postgres" and "manage_pool_lifecycle" not in self.broker_options:
self.broker_options["manage_pool_lifecycle"] = True # type: ignore[typeddict-unknown-key]
return self.broker_options

@property
def queue_class(self) -> "type[Queue]":
"""Type of queue to use."""
"""Type of queue to use.

Raises:
ImproperlyConfiguredException: If the queue class is invalid.

Returns:
The queue class.
"""
if self._queue_class is None and self.broker_instance is not None:
if self.broker_instance.__class__.__name__ == "AsyncConnectionPool":
from saq.queue.postgres import PostgresQueue
Expand Down Expand Up @@ -345,5 +380,5 @@ def _get_or_import_task(
if isinstance(task_or_import_string, str):
return cast("ReceivesContext", import_string(task_or_import_string))
if isinstance(task_or_import_string, tuple):
return task_or_import_string[1] # pyright: ignore # noqa: PGH003
return task_or_import_string[1] # pyright: ignore
return task_or_import_string
73 changes: 66 additions & 7 deletions litestar_saq/controllers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: PLR6301
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, cast

Expand Down Expand Up @@ -46,7 +47,14 @@ class SAQController(Controller):
description="List configured worker queues.",
)
async def queue_list(self, task_queues: "TaskQueues") -> "dict[str, list[QueueInfo]]":
"""Get Worker queues."""
"""Get Worker queues.

Args:
task_queues: The task queues.

Returns:
The worker queues.
"""
return {"queues": [await queue.info() for queue in task_queues.queues.values()]}

@get(
Expand All @@ -59,7 +67,18 @@ async def queue_list(self, task_queues: "TaskQueues") -> "dict[str, list[QueueIn
description="List queue details.",
)
async def queue_detail(self, task_queues: "TaskQueues", queue_id: str) -> "dict[str, QueueInfo]":
"""Get queue information."""
"""Get queue information.

Args:
task_queues: The task queues.
queue_id: The queue ID.

Raises:
NotFoundException: If the queue is not found.

Returns:
The queue information.
"""
queue = task_queues.get(queue_id)
if not queue:
msg = f"Could not find the {queue_id} queue"
Expand All @@ -78,7 +97,19 @@ async def queue_detail(self, task_queues: "TaskQueues", queue_id: str) -> "dict[
async def job_detail(
self, task_queues: "TaskQueues", queue_id: str, job_id: str
) -> "dict[str, dict[str, Any]]":
"""Get job information."""
"""Get job information.

Args:
task_queues: The task queues.
queue_id: The queue ID.
job_id: The job ID.

Raises:
NotFoundException: If the queue or job is not found.

Returns:
The job information.
"""
queue = task_queues.get(queue_id)
if not queue:
msg = f"Could not find the {queue_id} queue"
Expand All @@ -102,7 +133,19 @@ async def job_detail(
status_code=HTTP_202_ACCEPTED,
)
async def job_retry(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]":
"""Retry job."""
"""Retry job.

Args:
task_queues: The task queues.
queue_id: The queue ID.
job_id: The job ID.

Raises:
NotFoundException: If the queue or job is not found.

Returns:
The job information.
"""
queue = task_queues.get(queue_id)
if not queue:
msg = f"Could not find the {queue_id} queue"
Expand All @@ -122,7 +165,19 @@ async def job_retry(self, task_queues: "TaskQueues", queue_id: str, job_id: str)
status_code=HTTP_202_ACCEPTED,
)
async def job_abort(self, task_queues: "TaskQueues", queue_id: str, job_id: str) -> "dict[str, str]":
"""Abort job."""
"""Abort job.

Args:
task_queues: The task queues.
queue_id: The queue ID.
job_id: The job ID.

Raises:
NotFoundException: If the queue or job is not found.

Returns:
The job information.
"""
queue = task_queues.get(queue_id)
if not queue:
msg = f"Could not find the {queue_id} queue"
Expand All @@ -134,7 +189,7 @@ async def job_abort(self, task_queues: "TaskQueues", queue_id: str, job_id: str)
# static site
@get(
[
url_base,
f"{url_base}/",
f"{url_base}/queues/{{queue_id:str}}",
f"{url_base}/queues/{{queue_id:str}}/jobs/{{job_id:str}}",
],
Expand All @@ -144,7 +199,11 @@ async def job_abort(self, task_queues: "TaskQueues", queue_id: str, job_id: str)
include_in_schema=False,
)
async def index(self) -> str:
"""Serve site root."""
"""Serve site root.

Returns:
The site root.
"""
return f"""
<!DOCTYPE html>
<html>
Expand Down
Loading