Skip to content

Commit 9f138cc

Browse files
committed
Improve typing of middleware functions
1 parent 5c492e4 commit 9f138cc

File tree

8 files changed

+68
-44
lines changed

8 files changed

+68
-44
lines changed

dramatiq/middleware/age_limit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from __future__ import annotations
1919

20+
from typing import Optional
21+
2022
from ..common import current_millis
2123
from ..logging import get_logger
2224
from .middleware import Middleware, SkipMessage
@@ -32,7 +34,7 @@ class AgeLimit(Middleware):
3234
indefinitely.
3335
"""
3436

35-
def __init__(self, *, max_age=None):
37+
def __init__(self, *, max_age: Optional[int] = None) -> None:
3638
self.logger = get_logger(__name__, type(self))
3739
self.max_age = max_age
3840

dramatiq/middleware/asyncio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
class AsyncIO(Middleware):
2626
"""This middleware manages the event loop thread for async actors."""
2727

28-
def __init__(self):
28+
def __init__(self) -> None:
2929
self.logger = get_logger(__name__, type(self))
3030

3131
def before_worker_boot(self, broker, worker):

dramatiq/middleware/group_callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
import os
2121

22-
from ..rate_limits import Barrier
22+
from ..rate_limits import Barrier, RateLimiterBackend
2323
from .middleware import Middleware
2424

2525
GROUP_CALLBACK_BARRIER_TTL = int(os.getenv("dramatiq_group_callback_barrier_ttl", "86400000"))
2626

2727

2828
class GroupCallbacks(Middleware):
29-
def __init__(self, rate_limiter_backend):
29+
def __init__(self, rate_limiter_backend: RateLimiterBackend) -> None:
3030
self.rate_limiter_backend = rate_limiter_backend
3131

3232
def after_process_message(self, broker, message, *, result=None, exception=None):

dramatiq/middleware/middleware.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818

1919
from __future__ import annotations
2020

21+
from typing import TYPE_CHECKING, Any, Callable, Optional
22+
23+
if TYPE_CHECKING:
24+
from ..actor import Actor
25+
from ..broker import Broker, MessageProxy
26+
from ..message import Message
27+
from ..worker import Worker, _ConsumerThread, _WorkerThread
28+
2129

2230
class MiddlewareError(Exception):
2331
"""Base class for middleware errors."""
@@ -36,39 +44,39 @@ class Middleware:
3644
"""
3745

3846
@property
39-
def actor_options(self):
47+
def actor_options(self) -> set[str]:
4048
"""The set of options that may be configured on each actor."""
4149
return set()
4250

4351
@property
44-
def forks(self):
52+
def forks(self) -> list[Callable[[], int]]:
4553
"""A list of functions to run in separate forks of the main
4654
process.
4755
"""
4856
return []
4957

50-
def before_ack(self, broker, message):
58+
def before_ack(self, broker: Broker, message: MessageProxy) -> None:
5159
"""Called before a message is acknowledged."""
5260

53-
def after_ack(self, broker, message):
61+
def after_ack(self, broker: Broker, message: MessageProxy) -> None:
5462
"""Called after a message has been acknowledged."""
5563

56-
def before_nack(self, broker, message):
64+
def before_nack(self, broker: Broker, message: MessageProxy) -> None:
5765
"""Called before a message is rejected."""
5866

59-
def after_nack(self, broker, message):
67+
def after_nack(self, broker: Broker, message: MessageProxy) -> None:
6068
"""Called after a message has been rejected."""
6169

62-
def before_declare_actor(self, broker, actor):
70+
def before_declare_actor(self, broker: Broker, actor: Actor) -> None:
6371
"""Called before an actor is declared."""
6472

65-
def after_declare_actor(self, broker, actor):
73+
def after_declare_actor(self, broker: Broker, actor: Actor) -> None:
6674
"""Called after an actor has been declared."""
6775

68-
def before_declare_queue(self, broker, queue_name):
76+
def before_declare_queue(self, broker: Broker, queue_name: str) -> None:
6977
"""Called before a queue is declared."""
7078

71-
def after_declare_queue(self, broker, queue_name):
79+
def after_declare_queue(self, broker: Broker, queue_name: str) -> None:
7280
"""Called after a queue has been declared.
7381
7482
This signals that the queue has been registered with the
@@ -78,19 +86,19 @@ def after_declare_queue(self, broker, queue_name):
7886
them until messages are enqueued or consumed.
7987
"""
8088

81-
def after_declare_delay_queue(self, broker, queue_name):
89+
def after_declare_delay_queue(self, broker: Broker, queue_name: str) -> None:
8290
"""Called after a delay queue has been declared."""
8391

84-
def before_enqueue(self, broker, message, delay):
92+
def before_enqueue(self, broker: Broker, message: Message, delay: int) -> None:
8593
"""Called before a message is enqueued."""
8694

87-
def after_enqueue(self, broker, message, delay):
95+
def after_enqueue(self, broker: Broker, message: Message, delay: int) -> None:
8896
"""Called after a message has been enqueued."""
8997

90-
def before_delay_message(self, broker, message):
98+
def before_delay_message(self, broker: Broker, message: MessageProxy) -> None:
9199
"""Called before a message has been delayed in worker memory."""
92100

93-
def before_process_message(self, broker, message):
101+
def before_process_message(self, broker: Broker, message: MessageProxy) -> None:
94102
"""Called before a message is processed.
95103
96104
Raises:
@@ -99,42 +107,49 @@ def before_process_message(self, broker, message):
99107
of ``after_process_message``.
100108
"""
101109

102-
def after_process_message(self, broker, message, *, result=None, exception=None):
110+
def after_process_message(
111+
self,
112+
broker: Broker,
113+
message: MessageProxy,
114+
*,
115+
result: Optional[Any] = None,
116+
exception: Optional[BaseException] = None,
117+
) -> None:
103118
"""Called after a message has been processed."""
104119

105-
def after_skip_message(self, broker, message):
120+
def after_skip_message(self, broker: Broker, message: MessageProxy) -> None:
106121
"""Called instead of ``after_process_message`` after a message
107122
has been skipped.
108123
"""
109124

110-
def after_process_boot(self, broker):
125+
def after_process_boot(self, broker: Broker) -> None:
111126
"""Called immediately after subprocess start up."""
112127

113-
def before_worker_boot(self, broker, worker):
128+
def before_worker_boot(self, broker: Broker, worker: Worker) -> None:
114129
"""Called before the worker process starts up."""
115130

116-
def after_worker_boot(self, broker, worker):
131+
def after_worker_boot(self, broker: Broker, worker: Worker) -> None:
117132
"""Called after the worker process has started up."""
118133

119-
def before_worker_shutdown(self, broker, worker):
134+
def before_worker_shutdown(self, broker: Broker, worker: Worker) -> None:
120135
"""Called before the worker process shuts down."""
121136

122-
def after_worker_shutdown(self, broker, worker):
137+
def after_worker_shutdown(self, broker: Broker, worker: Worker) -> None:
123138
"""Called after the worker process shuts down."""
124139

125-
def after_consumer_thread_boot(self, broker, thread):
140+
def after_consumer_thread_boot(self, broker: Broker, thread: _ConsumerThread) -> None:
126141
"""Called from a consumer thread after it starts but before it starts its run loop."""
127142

128-
def before_consumer_thread_shutdown(self, broker, thread):
143+
def before_consumer_thread_shutdown(self, broker: Broker, thread: _ConsumerThread) -> None:
129144
"""Called before a consumer thread shuts down. This may be
130145
used to clean up thread-local resources (such as Django
131146
database connections).
132147
"""
133148

134-
def after_worker_thread_boot(self, broker, thread):
149+
def after_worker_thread_boot(self, broker: Broker, thread: _WorkerThread) -> None:
135150
"""Called from a worker thread after it starts but before it starts its run loop."""
136151

137-
def before_worker_thread_shutdown(self, broker, thread):
152+
def before_worker_thread_shutdown(self, broker: Broker, thread: _WorkerThread) -> None:
138153
"""Called before a worker thread shuts down. This may be used
139154
to clean up thread-local resources (such as Django database
140155
connections).

dramatiq/middleware/prometheus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ class Prometheus(Middleware):
4545
.. _Prometheus: https://prometheus.io
4646
"""
4747

48-
def __init__(self):
48+
def __init__(self) -> None:
4949
self.logger = get_logger(__name__, type(self))
50-
self.delayed_messages = set()
51-
self.message_start_times = {}
50+
self.delayed_messages: set[str] = set()
51+
self.message_start_times: dict[str, int] = {}
5252

5353
@property
5454
def forks(self):

dramatiq/middleware/retries.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import time
2121
import traceback
22+
from typing import Callable, Optional
2223

2324
from ..common import compute_backoff
2425
from ..errors import Retry
@@ -60,15 +61,22 @@ class Retries(Middleware):
6061
apply to retried tasks. Defaults to 15 seconds.
6162
max_backoff(int): The maximum amount of backoff milliseconds to
6263
apply to retried tasks. Defaults to 7 days.
63-
retry_when(Callable[[int, Exception], bool]): An optional
64+
retry_when(Callable[[int, BaseException], bool]): An optional
6465
predicate that can be used to programmatically determine
6566
whether a task should be retried or not. This takes
6667
precedence over `max_retries` when set.
6768
on_retry_exhausted(str): Name of an actor to send a message to when
6869
message is failed due to retries being exceeded.
6970
"""
7071

71-
def __init__(self, *, max_retries=20, min_backoff=None, max_backoff=None, retry_when=None):
72+
def __init__(
73+
self,
74+
*,
75+
max_retries: int = 20,
76+
min_backoff: Optional[int] = None,
77+
max_backoff: Optional[int] = None,
78+
retry_when: Optional[Callable[[int, BaseException], bool]] = None,
79+
) -> None:
7280
self.logger = get_logger(__name__, type(self))
7381
self.max_retries = max_retries
7482
self.min_backoff = min_backoff or DEFAULT_MIN_BACKOFF

dramatiq/middleware/shutdown.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import threading
2121
import warnings
22-
from typing import Optional, Type
2322

2423
from ..logging import get_logger
2524
from .middleware import Middleware
@@ -55,9 +54,11 @@ class ShutdownNotifications(Middleware):
5554
Defaults to False, meaning actors will not be interrupted, and allowed to finish.
5655
"""
5756

58-
def __init__(self, notify_shutdown=False):
57+
def __init__(self, notify_shutdown: bool = False) -> None:
5958
self.logger = get_logger(__name__, type(self))
6059
self.notify_shutdown = notify_shutdown
60+
61+
self.manager: _ShutdownManager
6162
if is_gevent_active():
6263
self.manager = _GeventShutdownManager(self.logger)
6364
else:
@@ -132,11 +133,10 @@ def shutdown(self):
132133
raise_thread_exception(thread_id, Shutdown)
133134

134135

135-
_GeventShutdownManager: Optional[Type[_ShutdownManager]] = None
136136
if is_gevent_active():
137137
from gevent import getcurrent
138138

139-
class __GeventShutdownManager(_ShutdownManager):
139+
class _GeventShutdownManager(_ShutdownManager):
140140

141141
def __init__(self, logger=None):
142142
self.logger = logger or get_logger(__name__, type(self))
@@ -161,5 +161,3 @@ def shutdown(self):
161161
thread_id,
162162
)
163163
greenlet.kill(Shutdown, block=False)
164-
165-
_GeventShutdownManager = __GeventShutdownManager

dramatiq/middleware/time_limit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import warnings
2222
from threading import Thread
2323
from time import monotonic, sleep
24-
from typing import TYPE_CHECKING, Optional
24+
from typing import TYPE_CHECKING, Optional, Union
2525

2626
from ..logging import get_logger
2727
from .middleware import Middleware
@@ -64,10 +64,11 @@ class TimeLimit(Middleware):
6464
Defaults to 1 second (1,000 milliseconds).
6565
"""
6666

67-
def __init__(self, *, time_limit=600000, interval=1000):
67+
def __init__(self, *, time_limit: float = 600000, interval: int = 1000) -> None:
6868
self.logger = get_logger(__name__, type(self))
6969
self.time_limit = time_limit
7070

71+
self.manager: Union[_GeventTimeoutManager, _CtypesTimeoutManager]
7172
if is_gevent_active():
7273
self.manager = _GeventTimeoutManager(logger=self.logger)
7374
else:
@@ -162,7 +163,7 @@ def remove_timeout(self, thread_id):
162163
timer.close()
163164

164165

165-
_GeventTimeout: Optional["gevent.Timeout"] = None
166+
_GeventTimeout: Optional[gevent.Timeout] = None
166167
if is_gevent_active():
167168
from gevent import Timeout
168169

0 commit comments

Comments
 (0)