Skip to content

Commit aa91cdf

Browse files
authored
Merge pull request #804 from LincolnPuzey/more_actor_type_fixes
Fixes to type hints for `@actor` decorator and `Actor` class
2 parents 4a5845b + 9c1e378 commit aa91cdf

File tree

3 files changed

+134
-11
lines changed

3 files changed

+134
-11
lines changed

dramatiq/actor.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Actor(Generic[P, R]):
6363

6464
def __init__(
6565
self,
66-
fn: Callable[P, Union[R, Awaitable[R]]],
66+
fn: Union[Callable[P, Awaitable[R]], Callable[P, R]],
6767
*,
6868
broker: Broker,
6969
actor_name: str,
@@ -75,7 +75,7 @@ def __init__(
7575
raise ValueError(f"An actor named {actor_name!r} is already registered.")
7676

7777
self.logger = get_logger(fn.__module__ or "_", actor_name)
78-
self.fn = async_to_sync(fn) if iscoroutinefunction(fn) else fn
78+
self.fn: Callable[P, R] = async_to_sync(fn) if iscoroutinefunction(fn) else fn # type: ignore[assignment]
7979
self.broker = broker
8080
self.actor_name = actor_name
8181
self.queue_name = queue_name
@@ -178,7 +178,7 @@ def send_with_options(
178178
message = self.message_with_options(args=args, kwargs=kwargs, **options)
179179
return self.broker.enqueue(message, delay=delay)
180180

181-
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any | R | Awaitable[R]:
181+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
182182
"""Synchronously call this actor.
183183
184184
Parameters:
@@ -210,21 +210,29 @@ def __call__(self, fn: Callable[P, Awaitable[R]]) -> Actor[P, R]: ...
210210
@overload
211211
def __call__(self, fn: Callable[P, R]) -> Actor[P, R]: ...
212212

213-
def __call__(self, fn: Callable[P, Union[Awaitable[R], R]]) -> Actor[P, R]: ...
213+
def __call__(self, fn: Union[Callable[P, Awaitable[R]], Callable[P, R]]) -> Actor[P, R]: ...
214214

215215

216216
@overload
217-
def actor(fn: Callable[P, Awaitable[R]], **kwargs) -> Actor[P, R]:
217+
def actor(fn: Callable[P, Awaitable[R]]) -> Actor[P, R]:
218218
pass
219219

220220

221221
@overload
222-
def actor(fn: Callable[P, R], **kwargs) -> Actor[P, R]:
222+
def actor(fn: Callable[P, R]) -> Actor[P, R]:
223223
pass
224224

225225

226226
@overload
227-
def actor(fn: None = None, **kwargs) -> ActorDecorator:
227+
def actor(
228+
*,
229+
queue_name: str = "default",
230+
priority: int = 0,
231+
actor_name: Optional[str] = None,
232+
broker: Optional[Broker] = None,
233+
actor_class: Callable[..., Actor[Any, Any]] = Actor,
234+
**options: Any,
235+
) -> ActorDecorator:
228236
pass
229237

230238

@@ -236,7 +244,7 @@ def actor(
236244
queue_name: str = "default",
237245
priority: int = 0,
238246
broker: Optional[Broker] = None,
239-
**options,
247+
**options: Any,
240248
) -> Union[Actor[P, R], Callable]:
241249
"""Declare an actor.
242250

dramatiq/asyncio.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import functools
2323
import logging
2424
import threading
25-
from typing import Awaitable, Callable, Optional, TypeVar
25+
from typing import Awaitable, Callable, Optional, ParamSpec, TypeVar
2626

2727
from .threading import Interrupt
2828

@@ -34,6 +34,7 @@
3434
]
3535

3636
R = TypeVar("R")
37+
P = ParamSpec("P")
3738

3839
_event_loop_thread = None
3940

@@ -53,13 +54,13 @@ def set_event_loop_thread(thread: Optional[EventLoopThread]) -> None:
5354
_event_loop_thread = thread
5455

5556

56-
def async_to_sync(async_fn: Callable[..., Awaitable[R]]) -> Callable[..., R]:
57+
def async_to_sync(async_fn: Callable[P, Awaitable[R]]) -> Callable[P, R]:
5758
"""Wrap an async function to run it on the event loop thread and
5859
synchronously wait for its result on the calling thread.
5960
"""
6061

6162
@functools.wraps(async_fn)
62-
def wrapper(*args, **kwargs) -> R:
63+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
6364
event_loop_thread = get_event_loop_thread()
6465
if event_loop_thread is None:
6566
raise RuntimeError(

tests/test_types.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Tests for dramatiq's types.
2+
3+
Unlike other test files which are run with pytest,
4+
this test file should be type-checked with mypy,
5+
to test that Dramatiq's types can be "consumed" by user code without type errors.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import TYPE_CHECKING, ParamSpec, TypeVar
11+
12+
if TYPE_CHECKING:
13+
from typing_extensions import assert_type
14+
15+
import dramatiq
16+
17+
P = ParamSpec("P")
18+
R = TypeVar("R")
19+
20+
broker = dramatiq.get_broker()
21+
22+
23+
class ArgType:
24+
pass
25+
26+
27+
class ReturnType:
28+
pass
29+
30+
31+
class CustomActor(dramatiq.Actor[P, R]):
32+
pass
33+
34+
35+
# # # Tests for @actor decorator # # #
36+
37+
38+
@dramatiq.actor
39+
def actor(arg: ArgType) -> ReturnType:
40+
return ReturnType()
41+
42+
43+
@dramatiq.actor()
44+
def actor_no_options(arg: ArgType) -> ReturnType:
45+
return ReturnType()
46+
47+
48+
@dramatiq.actor(
49+
actor_name="actor_with_options",
50+
queue_name="some_queue",
51+
priority=2,
52+
broker=broker,
53+
max_age=1,
54+
)
55+
def actor_with_options(arg: ArgType) -> ReturnType:
56+
return ReturnType()
57+
58+
59+
@dramatiq.actor(
60+
actor_class=CustomActor,
61+
actor_name="actor_with_custom_actor_class",
62+
queue_name="some_queue",
63+
priority=2,
64+
broker=broker,
65+
max_age=1,
66+
)
67+
def actor_with_custom_actor_class(arg: ArgType) -> ReturnType:
68+
return ReturnType()
69+
70+
71+
@dramatiq.actor
72+
async def async_actor(arg: ArgType) -> ReturnType:
73+
return ReturnType()
74+
75+
76+
@dramatiq.actor()
77+
async def async_actor_no_options(arg: ArgType) -> ReturnType:
78+
return ReturnType()
79+
80+
81+
@dramatiq.actor(
82+
actor_name="async_actor_with_options",
83+
queue_name="some_queue",
84+
priority=2,
85+
broker=broker,
86+
max_age=1,
87+
)
88+
async def async_actor_with_options(arg: ArgType) -> ReturnType:
89+
return ReturnType()
90+
91+
92+
@dramatiq.actor(
93+
actor_class=CustomActor,
94+
actor_name="async_actor_with_custom_actor_class",
95+
queue_name="some_queue",
96+
priority=2,
97+
broker=broker,
98+
max_age=1,
99+
)
100+
async def async_actor_with_custom_actor_class(arg: ArgType) -> ReturnType:
101+
return ReturnType()
102+
103+
104+
# # # Test that calling actors has correct arg/return type
105+
def _calling_actors_type_check() -> None:
106+
assert_type(actor(ArgType()), ReturnType)
107+
assert_type(actor_no_options(ArgType()), ReturnType)
108+
assert_type(actor_with_options(ArgType()), ReturnType)
109+
assert_type(actor_with_custom_actor_class(ArgType()), ReturnType)
110+
111+
assert_type(async_actor(ArgType()), ReturnType)
112+
assert_type(async_actor_no_options(ArgType()), ReturnType)
113+
assert_type(async_actor_with_options(ArgType()), ReturnType)
114+
assert_type(async_actor_with_custom_actor_class(ArgType()), ReturnType)

0 commit comments

Comments
 (0)