8
8
from ._constants import CALLBACK_BARRIER_TTL , OPTION_KEY_CALLBACKS
9
9
from ._helpers import workflow_with_completion_callbacks
10
10
from ._middleware import WorkflowMiddleware , workflow_noop
11
- from ._models import Barrier , Chain , CompletionCallbacks , Group , Message , WithDelay , WorkflowType
12
- from ._serialize import serialize_callbacks , serialize_workflow
11
+ from ._models import Barrier , Chain , Group , Message , SerializedCompletionCallbacks , WithDelay , WorkflowType
12
+ from ._serialize import serialize_workflow
13
13
14
14
logger = logging .getLogger (__name__ )
15
15
@@ -91,15 +91,15 @@ def __init__(
91
91
self .broker = broker or dramatiq .get_broker ()
92
92
93
93
self ._delay = None
94
- self ._completion_callbacks = []
94
+ self ._completion_callbacks : SerializedCompletionCallbacks | None = None
95
95
96
96
while isinstance (self .workflow , WithDelay ):
97
97
self ._delay = (self ._delay or 0 ) + self .workflow .delay
98
98
self .workflow = self .workflow .task
99
99
100
100
def run (self ):
101
101
current = self .workflow
102
- completion_callbacks = self ._completion_callbacks . copy ()
102
+ completion_callbacks = self ._completion_callbacks or []
103
103
104
104
if isinstance (current , Message ):
105
105
current = self .__augment_message (current , completion_callbacks )
@@ -115,7 +115,10 @@ def run(self):
115
115
task = tasks .pop (0 )
116
116
if tasks :
117
117
completion_id = self .__create_barrier (1 )
118
- completion_callbacks .append ((completion_id , Chain (* tasks ), False ))
118
+ completion_callbacks = [
119
+ * completion_callbacks ,
120
+ (completion_id , serialize_workflow (Chain (* tasks )), False ),
121
+ ]
119
122
self .__workflow_with_completion_callbacks (task , completion_callbacks ).run ()
120
123
return
121
124
@@ -126,7 +129,7 @@ def run(self):
126
129
return
127
130
128
131
completion_id = self .__create_barrier (len (tasks ))
129
- completion_callbacks . append (( completion_id , None , True ))
132
+ completion_callbacks = [ * completion_callbacks , ( completion_id , None , True )]
130
133
for task in tasks :
131
134
self .__workflow_with_completion_callbacks (task , completion_callbacks ).run ()
132
135
return
@@ -141,18 +144,22 @@ def __workflow_with_completion_callbacks(self, task, completion_callbacks) -> "W
141
144
delay = self ._delay ,
142
145
)
143
146
144
- def __schedule_noop (self , completion_callbacks : CompletionCallbacks ):
147
+ def __schedule_noop (self , completion_callbacks : SerializedCompletionCallbacks ):
145
148
noop_message = workflow_noop .message ()
146
149
noop_message = self .__augment_message (noop_message , completion_callbacks )
147
150
self .broker .enqueue (noop_message , delay = self ._delay )
148
151
149
- def __augment_message (self , message : Message , completion_callbacks : CompletionCallbacks ) -> Message :
152
+ def __augment_message (self , message : Message , completion_callbacks : SerializedCompletionCallbacks ) -> Message :
153
+ options = {}
154
+ if completion_callbacks :
155
+ options = {OPTION_KEY_CALLBACKS : completion_callbacks }
156
+
150
157
return message .copy (
151
158
# We reset the message timestamp to better represent the time the
152
159
# message was actually enqueued. This is to avoid tripping the max_age
153
160
# check in the broker.
154
161
message_timestamp = time .time () * 1000 ,
155
- options = { OPTION_KEY_CALLBACKS : serialize_callbacks ( completion_callbacks )} ,
162
+ options = options ,
156
163
)
157
164
158
165
@property
@@ -170,7 +177,7 @@ def __rate_limiter_backend(self):
170
177
)
171
178
return self .__cached_rate_limiter_backend
172
179
173
- def __create_barrier (self , count : int ):
180
+ def __create_barrier (self , count : int ) -> str | None :
174
181
if count == 1 :
175
182
# No need to create a distributed barrier if there is only one task
176
183
return None
0 commit comments