Skip to content

Commit 2aaf564

Browse files
committed
Add tests
1 parent 7294e58 commit 2aaf564

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

dramatiq_workflow/_base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,16 @@ def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks):
150150
self.broker.enqueue(noop_message, delay=self._delay)
151151

152152
def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message:
153+
options = {}
154+
if completion_callbacks:
155+
options = {OPTION_KEY_CALLBACKS: completion_callbacks}
156+
153157
return message.copy(
154158
# We reset the message timestamp to better represent the time the
155159
# message was actually enqueued. This is to avoid tripping the max_age
156160
# check in the broker.
157161
message_timestamp=time.time() * 1000,
158-
options={OPTION_KEY_CALLBACKS: completion_callbacks},
162+
options=options,
159163
)
160164

161165
@property
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import unittest
2+
from unittest import mock
3+
4+
import dramatiq
5+
from dramatiq.broker import Broker
6+
from dramatiq.rate_limits.backends import StubBackend
7+
8+
from dramatiq_workflow import Chain, WorkflowMiddleware
9+
from dramatiq_workflow._barrier import AtMostOnceBarrier
10+
from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS
11+
from dramatiq_workflow._serialize import serialize_workflow
12+
13+
14+
class WorkflowMiddlewareTests(unittest.TestCase):
15+
def setUp(self):
16+
# Initialize common mocks and the middleware instance for each test
17+
self.rate_limiter_backend = StubBackend()
18+
self.middleware = WorkflowMiddleware(self.rate_limiter_backend)
19+
20+
self.broker = mock.MagicMock(spec=Broker)
21+
22+
def _make_message(
23+
self, message_options: dict | None = None, message_timestamp: int = 1717526084640
24+
) -> dramatiq.broker.MessageProxy:
25+
"""
26+
Creates a dramatiq MessageProxy object with given options.
27+
"""
28+
message_id = 1 # Simplistic message ID for testing
29+
message = dramatiq.Message(
30+
message_id=str(message_id),
31+
message_timestamp=message_timestamp,
32+
queue_name="default",
33+
actor_name="test_task",
34+
args=(),
35+
kwargs={},
36+
options=message_options or {},
37+
)
38+
return dramatiq.broker.MessageProxy(message)
39+
40+
def _create_serialized_workflow(self) -> dict | None:
41+
"""
42+
Creates and serializes a simple workflow for testing.
43+
"""
44+
# Define a simple workflow (Chain with a single task)
45+
workflow = Chain(self._make_message()._message)
46+
serialized = serialize_workflow(workflow)
47+
return serialized
48+
49+
def test_after_process_message_without_callbacks(self):
50+
message = self._make_message()
51+
52+
self.middleware.after_process_message(self.broker, message)
53+
54+
self.broker.enqueue.assert_not_called()
55+
56+
def test_after_process_message_with_exception(self):
57+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
58+
59+
self.middleware.after_process_message(self.broker, message, exception=Exception("Test exception"))
60+
61+
self.broker.enqueue.assert_not_called()
62+
63+
def test_after_process_message_with_failed_message(self):
64+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
65+
message.failed = True
66+
67+
self.middleware.after_process_message(self.broker, message)
68+
69+
self.broker.enqueue.assert_not_called()
70+
71+
@mock.patch("dramatiq_workflow._base.time.time")
72+
def test_after_process_message_with_workflow(self, mock_time):
73+
mock_time.return_value = 1337
74+
message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]})
75+
76+
self.middleware.after_process_message(self.broker, message)
77+
78+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)
79+
80+
@mock.patch("dramatiq_workflow._base.time.time")
81+
def test_after_process_message_with_barriered_workflow(self, mock_time):
82+
mock_time.return_value = 1337
83+
barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1")
84+
barrier.create(2)
85+
message = self._make_message({OPTION_KEY_CALLBACKS: [(barrier.key, self._create_serialized_workflow(), True)]})
86+
87+
self.middleware.after_process_message(self.broker, message)
88+
self.broker.enqueue.assert_not_called()
89+
90+
# Calling again, barrier should be completed now
91+
self.middleware.after_process_message(self.broker, message)
92+
self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)

dramatiq_workflow/tests/test_workflow.py

-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,6 @@ def test_nested_delays(self, time_mock):
377377
self.__make_message(
378378
1,
379379
message_timestamp=updated_timestamp,
380-
message_options={"workflow_completion_callbacks": []},
381380
),
382381
delay=20,
383382
)

0 commit comments

Comments
 (0)