Skip to content

Commit 9fddc1b

Browse files
authored
More efficient completion callback serialization (#5)
Before this change, running a group of workflows serialized the `completion_callbacks` separately for each member of the group. With large workflows, this could be quite slow. Additionally, there was a long-standing TODO in the middleware to avoid unserializing the callbacks, since they just had to be serialized again immediately afterwards. This PR aims to solve both of these issues. We never actually need an unserialized version of the callbacks, so we now pass around only the serialized version. The middleware now unserializes only the workflow that needs to be run, and only when it is needed.
1 parent 0e6ce41 commit 9fddc1b

File tree

7 files changed

+125
-32
lines changed

7 files changed

+125
-32
lines changed

dramatiq_workflow/_base.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from ._constants import CALLBACK_BARRIER_TTL, OPTION_KEY_CALLBACKS
99
from ._helpers import workflow_with_completion_callbacks
1010
from ._middleware import WorkflowMiddleware, workflow_noop
11-
from ._models import Barrier, Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
12-
from ._serialize import serialize_callbacks, serialize_workflow
11+
from ._models import Barrier, Chain, Group, Message, SerializedCompletionCallbacks, WithDelay, WorkflowType
12+
from ._serialize import serialize_workflow
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -91,15 +91,15 @@ def __init__(
9191
self.broker = broker or dramatiq.get_broker()
9292

9393
self._delay = None
94-
self._completion_callbacks = []
94+
self._completion_callbacks: SerializedCompletionCallbacks | None = None
9595

9696
while isinstance(self.workflow, WithDelay):
9797
self._delay = (self._delay or 0) + self.workflow.delay
9898
self.workflow = self.workflow.task
9999

100100
def run(self):
101101
current = self.workflow
102-
completion_callbacks = self._completion_callbacks.copy()
102+
completion_callbacks = self._completion_callbacks or []
103103

104104
if isinstance(current, Message):
105105
current = self.__augment_message(current, completion_callbacks)
@@ -115,7 +115,10 @@ def run(self):
115115
task = tasks.pop(0)
116116
if tasks:
117117
completion_id = self.__create_barrier(1)
118-
completion_callbacks.append((completion_id, Chain(*tasks), False))
118+
completion_callbacks = [
119+
*completion_callbacks,
120+
(completion_id, serialize_workflow(Chain(*tasks)), False),
121+
]
119122
self.__workflow_with_completion_callbacks(task, completion_callbacks).run()
120123
return
121124

@@ -126,7 +129,7 @@ def run(self):
126129
return
127130

128131
completion_id = self.__create_barrier(len(tasks))
129-
completion_callbacks.append((completion_id, None, True))
132+
completion_callbacks = [*completion_callbacks, (completion_id, None, True)]
130133
for task in tasks:
131134
self.__workflow_with_completion_callbacks(task, completion_callbacks).run()
132135
return
@@ -141,18 +144,22 @@ def __workflow_with_completion_callbacks(self, task, completion_callbacks) -> "W
141144
delay=self._delay,
142145
)
143146

144-
def __schedule_noop(self, completion_callbacks: CompletionCallbacks):
147+
def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks):
145148
noop_message = workflow_noop.message()
146149
noop_message = self.__augment_message(noop_message, completion_callbacks)
147150
self.broker.enqueue(noop_message, delay=self._delay)
148151

149-
def __augment_message(self, message: Message, completion_callbacks: CompletionCallbacks) -> Message:
152+
def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message:
153+
options = {}
154+
if completion_callbacks:
155+
options = {OPTION_KEY_CALLBACKS: completion_callbacks}
156+
150157
return message.copy(
151158
# We reset the message timestamp to better represent the time the
152159
# message was actually enqueued. This is to avoid tripping the max_age
153160
# check in the broker.
154161
message_timestamp=time.time() * 1000,
155-
options={OPTION_KEY_CALLBACKS: serialize_callbacks(completion_callbacks)},
162+
options=options,
156163
)
157164

158165
@property
@@ -170,7 +177,7 @@ def __rate_limiter_backend(self):
170177
)
171178
return self.__cached_rate_limiter_backend
172179

173-
def __create_barrier(self, count: int):
180+
def __create_barrier(self, count: int) -> str | None:
174181
if count == 1:
175182
# No need to create a distributed barrier if there is only one task
176183
return None

dramatiq_workflow/_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import dramatiq
22

3-
from ._models import CompletionCallbacks, WorkflowType
3+
from ._models import SerializedCompletionCallbacks, WorkflowType
44

55

66
def workflow_with_completion_callbacks(
77
workflow: WorkflowType,
88
broker: dramatiq.Broker,
9-
completion_callbacks: CompletionCallbacks,
9+
completion_callbacks: SerializedCompletionCallbacks,
1010
delay: int | None = None,
1111
):
1212
from ._base import Workflow

dramatiq_workflow/_middleware.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from ._constants import OPTION_KEY_CALLBACKS
77
from ._helpers import workflow_with_completion_callbacks
8-
from ._models import Barrier
9-
from ._serialize import unserialize_callbacks, unserialize_workflow
8+
from ._models import Barrier, SerializedCompletionCallbacks
9+
from ._serialize import unserialize_workflow
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -28,7 +28,7 @@ def after_process_message(
2828
if message.failed:
2929
return
3030

31-
completion_callbacks: list[dict] | None = message.options.get(OPTION_KEY_CALLBACKS)
31+
completion_callbacks: SerializedCompletionCallbacks | None = message.options.get(OPTION_KEY_CALLBACKS)
3232
if completion_callbacks is None:
3333
return
3434

@@ -48,9 +48,7 @@ def after_process_message(
4848
workflow_with_completion_callbacks(
4949
unserialize_workflow(remaining_workflow),
5050
broker,
51-
# TODO: This is somewhat inefficient because we're unserializing all callbacks
52-
# even though we are just going to serialize them again.
53-
unserialize_callbacks(completion_callbacks),
51+
completion_callbacks,
5452
).run()
5553

5654
if not propagate:

dramatiq_workflow/_models.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,5 @@ def __eq__(self, other):
4242
Message = dramatiq.Message
4343
WorkflowType = Message | Chain | Group | WithDelay
4444

45-
# CompletionCallback is a tuple containing: ID, WorkflowType, propagate
46-
CompletionCallback = tuple[str, WorkflowType | None, bool]
47-
CompletionCallbacks = list[CompletionCallback]
45+
SerializedCompletionCallback = tuple[str | None, dict | None, bool]
46+
SerializedCompletionCallbacks = list[SerializedCompletionCallback]

dramatiq_workflow/_serialize.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import typing
22

3-
from ._models import Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
3+
from ._models import (
4+
Chain,
5+
Group,
6+
Message,
7+
WithDelay,
8+
WorkflowType,
9+
)
410

511

6-
def serialize_workflow(workflow: WorkflowType | None) -> typing.Any:
12+
def serialize_workflow(workflow: WorkflowType | None) -> dict | None:
713
"""
814
Return a serialized version of the workflow that can be JSON-encoded.
915
"""
@@ -69,11 +75,3 @@ def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None:
6975
if workflow is None:
7076
return None
7177
return unserialize_workflow(workflow)
72-
73-
74-
def serialize_callbacks(callbacks: CompletionCallbacks) -> list[tuple[str, typing.Any, bool]]:
75-
return [(id, serialize_workflow(g), propagate) for id, g, propagate in callbacks]
76-
77-
78-
def unserialize_callbacks(callbacks: list[dict]) -> CompletionCallbacks:
79-
return [(id, unserialize_workflow_or_none(g), propagate) for id, g, propagate in callbacks]
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import unittest
2+
from unittest import mock
3+
4+
import dramatiq
5+
from dramatiq.broker import Broker
6+
from dramatiq.rate_limits.backends import StubBackend
7+
8+
from dramatiq_workflow import Chain, WorkflowMiddleware
9+
from dramatiq_workflow._barrier import AtMostOnceBarrier
10+
from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS
11+
from dramatiq_workflow._serialize import serialize_workflow
12+
13+
14+
class WorkflowMiddlewareTests(unittest.TestCase):
15+
def setUp(self):
16+
# Initialize common mocks and the middleware instance for each test
17+
self.rate_limiter_backend = StubBackend()
18+
self.middleware = WorkflowMiddleware(self.rate_limiter_backend)
19+
20+
self.broker = mock.MagicMock(spec=Broker)
21+
22+
def _make_message(
23+
self, message_options: dict | None = None, message_timestamp: int = 1717526084640
24+
) -> dramatiq.broker.MessageProxy:
25+
"""
26+
Creates a dramatiq MessageProxy object with given options.
27+
"""
28+
message_id = 1 # Simplistic message ID for testing
29+
message = dramatiq.Message(
30+
message_id=str(message_id),
31+
message_timestamp=message_timestamp,
32+
queue_name="default",
33+
actor_name="test_task",
34+
args=(),
35+
kwargs={},
36+
options=message_options or {},
37+
)
38+
return dramatiq.broker.MessageProxy(message)
39+
40+
def _create_serialized_workflow(self) -> dict | None:
41+
"""
42+
Creates and serializes a simple workflow for testing.
43+
"""
44+
# Define a simple workflow (Chain with a single task)
45+
workflow = Chain(self._make_message()._message)
46+
serialized = serialize_workflow(workflow)
47+
return serialized
48+
49+
def test_after_process_message_without_callbacks(self):
50+
message = self._make_message()
51+
52+
self.middleware.after_process_message(self.broker, message)
53+
54+
self.broker.enqueue.assert_not_called()
55+
56+
def test_after_process_message_with_exception(self):
57+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
58+
59+
self.middleware.after_process_message(self.broker, message, exception=Exception("Test exception"))
60+
61+
self.broker.enqueue.assert_not_called()
62+
63+
def test_after_process_message_with_failed_message(self):
64+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
65+
message.failed = True
66+
67+
self.middleware.after_process_message(self.broker, message)
68+
69+
self.broker.enqueue.assert_not_called()
70+
71+
@mock.patch("dramatiq_workflow._base.time.time")
72+
def test_after_process_message_with_workflow(self, mock_time):
73+
mock_time.return_value = 1337
74+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
75+
76+
self.middleware.after_process_message(self.broker, message)
77+
78+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)
79+
80+
@mock.patch("dramatiq_workflow._base.time.time")
81+
def test_after_process_message_with_barriered_workflow(self, mock_time):
82+
mock_time.return_value = 1337
83+
barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1")
84+
barrier.create(2)
85+
message = self._make_message({OPTION_KEY_CALLBACKS: [(barrier.key, self._create_serialized_workflow(), True)]})
86+
87+
self.middleware.after_process_message(self.broker, message)
88+
self.broker.enqueue.assert_not_called()
89+
90+
# Calling again, barrier should be completed now
91+
self.middleware.after_process_message(self.broker, message)
92+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)

dramatiq_workflow/tests/test_workflow.py

-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,6 @@ def test_nested_delays(self, time_mock):
377377
self.__make_message(
378378
1,
379379
message_timestamp=updated_timestamp,
380-
message_options={"workflow_completion_callbacks": []},
381380
),
382381
delay=20,
383382
)

0 commit comments

Comments
 (0)