Skip to content

Commit 7d014d3

Browse files
authored
Merge branch 'main' into always-use-barrier
2 parents 193d8ad + 9fddc1b commit 7d014d3

10 files changed

+210
-33
lines changed

README.md

+44
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,50 @@ workflow = Workflow(
153153
In this example, Task 2 will run roughly 1 second after Task 1 finishes, and
154154
Task 3 and will run 2 seconds after Task 2 finishes.
155155

156+
### Large Workflows
157+
158+
Because of how `dramatiq-workflow` is implemented, each task in a workflow has
159+
to know about the remaining tasks in the workflow that could potentially run
160+
after it. When a workflow has a large number of tasks, this can lead to an
161+
increase of memory usage in the broker and increased network traffic between
162+
the broker and the workers, especially when using `Group` tasks: Each task in a
163+
`Group` can potentially be the last one to finish, so each task has to retain a
164+
copy of the remaining tasks that run after the `Group`.
165+
166+
There are a few things you can do to alleviate this issue:
167+
168+
- Minimize the usage of parameters in the `message` method. Instead, consider
169+
using a database to store data that is required by your tasks.
170+
- Limit the size of groups to a reasonable number of tasks. Instead of
171+
scheduling one task with 1000 tasks in a group, consider scheduling 10 groups
172+
with 100 tasks each and chaining them together.
173+
- Consider breaking down large workflows into smaller partial workflows that
174+
then schedule a subsequent workflow at the very end of the outermost `Chain`.
175+
176+
Lastly, you can use compression to reduce the size of the messages in your
177+
queue. While dramatiq does not provide a compression implementation by default,
178+
one can be added with just a few lines of code. For example:
179+
180+
```python
181+
import dramatiq
182+
from dramatiq.encoder import JSONEncoder, MessageData
183+
import lz4.frame
184+
185+
class DramatiqLz4JSONEncoder(JSONEncoder):
186+
def encode(self, data: MessageData) -> bytes:
187+
return lz4.frame.compress(super().encode(data))
188+
189+
def decode(self, data: bytes) -> MessageData:
190+
try:
191+
decompressed = lz4.frame.decompress(data)
192+
except RuntimeError:
193+
# Uncompressed data from before the switch to lz4
194+
decompressed = data
195+
return super().decode(decompressed)
196+
197+
dramatiq.set_encoder(DramatiqLz4JSONEncoder())
198+
```
199+
156200
### Barrier
157201

158202
`dramatiq-workflow` uses a barrier mechanism to keep track of the current state

dramatiq_workflow/_barrier.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, backend, key, *args, ttl=900000):
2020
self.ran_key = f"{key}_ran"
2121

2222
def create(self, parties):
23-
self.backend.add(self.ran_key, 0, self.ttl)
23+
self.backend.add(self.ran_key, -1, self.ttl)
2424
return super().create(parties)
2525

2626
def wait(self, *args, block=True, timeout=None):
@@ -32,7 +32,7 @@ def wait(self, *args, block=True, timeout=None):
3232

3333
released = super().wait(*args, block=False)
3434
if released:
35-
never_released = self.backend.incr(self.ran_key, 1, 1, self.ttl)
35+
never_released = self.backend.incr(self.ran_key, 1, 0, self.ttl)
3636
return never_released
3737

3838
return False

dramatiq_workflow/_base.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from ._constants import CALLBACK_BARRIER_TTL, OPTION_KEY_CALLBACKS
99
from ._helpers import workflow_with_completion_callbacks
1010
from ._middleware import WorkflowMiddleware, workflow_noop
11-
from ._models import Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
12-
from ._serialize import serialize_callbacks, serialize_workflow
11+
from ._models import Chain, Group, Message, SerializedCompletionCallbacks, WithDelay, WorkflowType
12+
from ._serialize import serialize_workflow
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -91,15 +91,15 @@ def __init__(
9191
self.broker = broker or dramatiq.get_broker()
9292

9393
self._delay = None
94-
self._completion_callbacks = []
94+
self._completion_callbacks: SerializedCompletionCallbacks | None = None
9595

9696
while isinstance(self.workflow, WithDelay):
9797
self._delay = (self._delay or 0) + self.workflow.delay
9898
self.workflow = self.workflow.task
9999

100100
def run(self):
101101
current = self.workflow
102-
completion_callbacks = self._completion_callbacks.copy()
102+
completion_callbacks = self._completion_callbacks or []
103103

104104
if isinstance(current, Message):
105105
current = self.__augment_message(current, completion_callbacks)
@@ -115,7 +115,10 @@ def run(self):
115115
task = tasks.pop(0)
116116
if tasks:
117117
completion_id = self.__create_barrier(1)
118-
completion_callbacks.append((completion_id, Chain(*tasks), False))
118+
completion_callbacks = [
119+
*completion_callbacks,
120+
(completion_id, serialize_workflow(Chain(*tasks)), False),
121+
]
119122
self.__workflow_with_completion_callbacks(task, completion_callbacks).run()
120123
return
121124

@@ -126,7 +129,7 @@ def run(self):
126129
return
127130

128131
completion_id = self.__create_barrier(len(tasks))
129-
completion_callbacks.append((completion_id, None, True))
132+
completion_callbacks = [*completion_callbacks, (completion_id, None, True)]
130133
for task in tasks:
131134
self.__workflow_with_completion_callbacks(task, completion_callbacks).run()
132135
return
@@ -141,18 +144,22 @@ def __workflow_with_completion_callbacks(self, task, completion_callbacks) -> "W
141144
delay=self._delay,
142145
)
143146

144-
def __schedule_noop(self, completion_callbacks: CompletionCallbacks):
147+
def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks):
145148
noop_message = workflow_noop.message()
146149
noop_message = self.__augment_message(noop_message, completion_callbacks)
147150
self.broker.enqueue(noop_message, delay=self._delay)
148151

149-
def __augment_message(self, message: Message, completion_callbacks: CompletionCallbacks) -> Message:
152+
def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message:
153+
options = {}
154+
if completion_callbacks:
155+
options = {OPTION_KEY_CALLBACKS: completion_callbacks}
156+
150157
return message.copy(
151158
# We reset the message timestamp to better represent the time the
152159
# message was actually enqueued. This is to avoid tripping the max_age
153160
# check in the broker.
154161
message_timestamp=time.time() * 1000,
155-
options={OPTION_KEY_CALLBACKS: serialize_callbacks(completion_callbacks)},
162+
options=options,
156163
)
157164

158165
@property
@@ -178,7 +185,7 @@ def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend:
178185
def __barrier(self) -> type[dramatiq.rate_limits.Barrier]:
179186
return self.__middleware.barrier
180187

181-
def __create_barrier(self, count: int):
188+
def __create_barrier(self, count: int) -> str:
182189
completion_uuid = str(uuid4())
183190
completion_barrier = self.__barrier(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
184191
completion_barrier.create(count)

dramatiq_workflow/_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import dramatiq
22

3-
from ._models import CompletionCallbacks, WorkflowType
3+
from ._models import SerializedCompletionCallbacks, WorkflowType
44

55

66
def workflow_with_completion_callbacks(
77
workflow: WorkflowType,
88
broker: dramatiq.Broker,
9-
completion_callbacks: CompletionCallbacks,
9+
completion_callbacks: SerializedCompletionCallbacks,
1010
delay: int | None = None,
1111
):
1212
from ._base import Workflow

dramatiq_workflow/_middleware.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from ._barrier import AtMostOnceBarrier
77
from ._constants import OPTION_KEY_CALLBACKS
88
from ._helpers import workflow_with_completion_callbacks
9-
from ._serialize import unserialize_callbacks, unserialize_workflow
9+
from ._models import SerializedCompletionCallbacks
10+
from ._serialize import unserialize_workflow
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -33,7 +34,7 @@ def after_process_message(
3334
if message.failed:
3435
return
3536

36-
completion_callbacks: list[dict] | None = message.options.get(OPTION_KEY_CALLBACKS)
37+
completion_callbacks: SerializedCompletionCallbacks | None = message.options.get(OPTION_KEY_CALLBACKS)
3738
if completion_callbacks is None:
3839
return
3940

@@ -53,9 +54,7 @@ def after_process_message(
5354
workflow_with_completion_callbacks(
5455
unserialize_workflow(remaining_workflow),
5556
broker,
56-
# TODO: This is somewhat inefficient because we're unserializing all callbacks
57-
# even though we are just going to serialize them again.
58-
unserialize_callbacks(completion_callbacks),
57+
completion_callbacks,
5958
).run()
6059

6160
if not propagate:

dramatiq_workflow/_models.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,5 @@ def __eq__(self, other):
3939
Message = dramatiq.Message
4040
WorkflowType = Message | Chain | Group | WithDelay
4141

42-
# CompletionCallback is a tuple containing: ID, WorkflowType, propagate
43-
CompletionCallback = tuple[str, WorkflowType | None, bool]
44-
CompletionCallbacks = list[CompletionCallback]
42+
SerializedCompletionCallback = tuple[str | None, dict | None, bool]
43+
SerializedCompletionCallbacks = list[SerializedCompletionCallback]

dramatiq_workflow/_serialize.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import typing
22

3-
from ._models import Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
3+
from ._models import (
4+
Chain,
5+
Group,
6+
Message,
7+
WithDelay,
8+
WorkflowType,
9+
)
410

511

6-
def serialize_workflow(workflow: WorkflowType | None) -> typing.Any:
12+
def serialize_workflow(workflow: WorkflowType | None) -> dict | None:
713
"""
814
Return a serialized version of the workflow that can be JSON-encoded.
915
"""
@@ -69,11 +75,3 @@ def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None:
6975
if workflow is None:
7076
return None
7177
return unserialize_workflow(workflow)
72-
73-
74-
def serialize_callbacks(callbacks: CompletionCallbacks) -> list[tuple[str, typing.Any, bool]]:
75-
return [(id, serialize_workflow(g), propagate) for id, g, propagate in callbacks]
76-
77-
78-
def unserialize_callbacks(callbacks: list[dict]) -> CompletionCallbacks:
79-
return [(id, unserialize_workflow_or_none(g), propagate) for id, g, propagate in callbacks]
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
3+
from dramatiq.rate_limits.backends import StubBackend
4+
5+
from .._barrier import AtMostOnceBarrier
6+
7+
8+
class AtMostOnceBarrierTests(unittest.TestCase):
9+
def setUp(self):
10+
self.backend = StubBackend()
11+
self.key = "test_barrier"
12+
self.parties = 3
13+
self.ttl = 900000
14+
self.barrier = AtMostOnceBarrier(self.backend, self.key, ttl=self.ttl)
15+
16+
def test_wait_block_true_raises(self):
17+
with self.assertRaises(ValueError) as context:
18+
self.barrier.wait(block=True)
19+
self.assertEqual(str(context.exception), "Blocking is not supported by AtMostOnceBarrier")
20+
21+
def test_wait_releases_once(self):
22+
self.barrier.create(self.parties)
23+
for _ in range(self.parties - 1):
24+
result = self.barrier.wait(block=False)
25+
self.assertFalse(result)
26+
result = self.barrier.wait(block=False)
27+
self.assertTrue(result)
28+
result = self.barrier.wait(block=False)
29+
self.assertFalse(result)
30+
31+
def test_wait_does_not_release_when_db_emptied(self):
32+
"""
33+
If the store is emptied, the barrier should not be released.
34+
"""
35+
self.barrier.create(self.parties)
36+
self.backend.db = {}
37+
for _ in range(self.parties):
38+
result = self.barrier.wait(block=False)
39+
self.assertFalse(result)
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import unittest
2+
from unittest import mock
3+
4+
import dramatiq
5+
from dramatiq.broker import Broker
6+
from dramatiq.rate_limits.backends import StubBackend
7+
8+
from dramatiq_workflow import Chain, WorkflowMiddleware
9+
from dramatiq_workflow._barrier import AtMostOnceBarrier
10+
from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS
11+
from dramatiq_workflow._serialize import serialize_workflow
12+
13+
14+
class WorkflowMiddlewareTests(unittest.TestCase):
15+
def setUp(self):
16+
# Initialize common mocks and the middleware instance for each test
17+
self.rate_limiter_backend = StubBackend()
18+
self.middleware = WorkflowMiddleware(self.rate_limiter_backend)
19+
20+
self.broker = mock.MagicMock(spec=Broker)
21+
22+
def _make_message(
23+
self, message_options: dict | None = None, message_timestamp: int = 1717526084640
24+
) -> dramatiq.broker.MessageProxy:
25+
"""
26+
Creates a dramatiq MessageProxy object with given options.
27+
"""
28+
message_id = 1 # Simplistic message ID for testing
29+
message = dramatiq.Message(
30+
message_id=str(message_id),
31+
message_timestamp=message_timestamp,
32+
queue_name="default",
33+
actor_name="test_task",
34+
args=(),
35+
kwargs={},
36+
options=message_options or {},
37+
)
38+
return dramatiq.broker.MessageProxy(message)
39+
40+
def _create_serialized_workflow(self) -> dict | None:
41+
"""
42+
Creates and serializes a simple workflow for testing.
43+
"""
44+
# Define a simple workflow (Chain with a single task)
45+
workflow = Chain(self._make_message()._message)
46+
serialized = serialize_workflow(workflow)
47+
return serialized
48+
49+
def test_after_process_message_without_callbacks(self):
50+
message = self._make_message()
51+
52+
self.middleware.after_process_message(self.broker, message)
53+
54+
self.broker.enqueue.assert_not_called()
55+
56+
def test_after_process_message_with_exception(self):
57+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
58+
59+
self.middleware.after_process_message(self.broker, message, exception=Exception("Test exception"))
60+
61+
self.broker.enqueue.assert_not_called()
62+
63+
def test_after_process_message_with_failed_message(self):
64+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
65+
message.failed = True
66+
67+
self.middleware.after_process_message(self.broker, message)
68+
69+
self.broker.enqueue.assert_not_called()
70+
71+
@mock.patch("dramatiq_workflow._base.time.time")
72+
def test_after_process_message_with_workflow(self, mock_time):
73+
mock_time.return_value = 1337
74+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
75+
76+
self.middleware.after_process_message(self.broker, message)
77+
78+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)
79+
80+
@mock.patch("dramatiq_workflow._base.time.time")
81+
def test_after_process_message_with_barriered_workflow(self, mock_time):
82+
mock_time.return_value = 1337
83+
barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1")
84+
barrier.create(2)
85+
message = self._make_message({OPTION_KEY_CALLBACKS: [(barrier.key, self._create_serialized_workflow(), True)]})
86+
87+
self.middleware.after_process_message(self.broker, message)
88+
self.broker.enqueue.assert_not_called()
89+
90+
# Calling again, barrier should be completed now
91+
self.middleware.after_process_message(self.broker, message)
92+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)

dramatiq_workflow/tests/test_workflow.py

-1
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ def test_nested_delays(self, time_mock):
386386
self.__make_message(
387387
1,
388388
message_timestamp=updated_timestamp,
389-
message_options={"workflow_completion_callbacks": []},
390389
),
391390
delay=20,
392391
)

0 commit comments

Comments
 (0)