@@ -186,7 +186,7 @@ def test_workforce_callback_registration_and_metrics_handling():
186186 Workforce ("CB Test - Invalid" , callbacks = [object ()])
187187
188188
189- def assert_event_sequence (events : list [str ], min_worker_count : int ):
189+ def assert_event_sequence (events : list [WorkforceEvent ], min_worker_count : int ):
190190 """
191191 Validate that the given event sequence follows the expected logical order.
192192 This version is flexible to handle:
@@ -198,7 +198,7 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
198198
199199 # 1. Expect at least min_worker_count WorkerCreatedEvent events first
200200 initial_worker_count = 0
201- while idx < n and events [idx ] == " WorkerCreatedEvent" :
201+ while idx < n and events [idx ] == WorkerCreatedEvent :
202202 initial_worker_count += 1
203203 idx += 1
204204 assert initial_worker_count >= min_worker_count , (
@@ -207,8 +207,8 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
207207 )
208208
209209 # 2. Expect one main TaskCreatedEvent
210- assert idx < n and events [idx ] == " TaskCreatedEvent" , (
211- f"Event { idx } should be TaskCreatedEvent, got "
210+ assert idx < n and events [idx ] == TaskCreatedEvent , (
211+ f"Event { idx } should be { TaskCreatedEvent . __name__ } , got "
212212 f"{ events [idx ] if idx < n else 'END' } "
213213 )
214214 idx += 1
@@ -217,49 +217,61 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
217217 # (depends on coordinator behavior)
218218 # If the coordinator can't parse stub responses, it may skip
219219 # decomposition
220- has_decomposition = idx < n and events [idx ] == " TaskDecomposedEvent"
220+ has_decomposition = idx < n and events [idx ] == TaskDecomposedEvent
221221 if has_decomposition :
222222 idx += 1
223223
224224 # 4. Count all event types in the remaining events
225225 all_events = events [idx :]
226- task_assigned_count = all_events .count ("TaskAssignedEvent" )
227- task_started_count = all_events .count ("TaskStartedEvent" )
228- task_completed_count = all_events .count ("TaskCompletedEvent" )
229- all_tasks_completed_count = all_events .count ("AllTasksCompletedEvent" )
226+ task_assigned_count = sum (
227+ isinstance (e , TaskAssignedEvent ) for e in all_events
228+ )
229+ task_started_count = sum (
230+ isinstance (e , TaskStartedEvent ) for e in all_events
231+ )
232+ task_completed_count = sum (
233+ isinstance (e , TaskCompletedEvent ) for e in all_events
234+ )
235+ all_tasks_completed_count = sum (
236+ isinstance (e , AllTasksCompletedEvent ) for e in all_events
237+ )
230238
231239 # 5. Validate basic invariants
232240 # At minimum, the main task should be assigned and processed
233- assert (
234- task_assigned_count >= 1
235- ), f"Expected at least 1 TaskAssignedEvent, got { task_assigned_count } "
236- assert (
237- task_started_count >= 1
238- ), f"Expected at least 1 TaskStartedEvent, got { task_started_count } "
239- assert (
240- task_completed_count >= 1
241- ), f"Expected at least 1 TaskCompletedEvent, got { task_completed_count } "
241+ assert task_assigned_count >= 1 , (
242+ f"Expected at least 1 { TaskAssignedEvent .__name__ } , "
243+ f"got { task_assigned_count } "
244+ )
245+ assert task_started_count >= 1 , (
246+ f"Expected at least 1 { TaskStartedEvent .__name__ } , "
247+ f"got { task_started_count } "
248+ )
249+ assert task_completed_count >= 1 , (
250+ f"Expected at least 1 { TaskCompletedEvent .__name__ } , "
251+ f"got { task_completed_count } "
252+ )
242253
243254 # 6. Expect exactly one AllTasksCompletedEvent at the end
244255 assert all_tasks_completed_count == 1 , (
245- f"Expected exactly 1 AllTasksCompletedEvent, got "
256+ f"Expected exactly 1 { AllTasksCompletedEvent . __name__ } , got "
246257 f"{ all_tasks_completed_count } "
247258 )
248259 assert (
249- events [- 1 ] == " AllTasksCompletedEvent"
250- ), "Last event should be AllTasksCompletedEvent"
260+ events [- 1 ] == AllTasksCompletedEvent
261+ ), f "Last event should be { AllTasksCompletedEvent . __name__ } "
251262
252263 # 7. All events should be of expected types
253264 allowed_events = {
254- "WorkerCreatedEvent" ,
255- "WorkerDeletedEvent" ,
256- "TaskCreatedEvent" ,
257- "TaskDecomposedEvent" ,
258- "TaskAssignedEvent" ,
259- "TaskStartedEvent" ,
260- "TaskCompletedEvent" ,
261- "TaskFailedEvent" ,
262- "AllTasksCompletedEvent" ,
265+ WorkerCreatedEvent ,
266+ WorkerDeletedEvent ,
267+ TaskCreatedEvent ,
268+ TaskDecomposedEvent ,
269+ TaskAssignedEvent ,
270+ TaskStartedEvent ,
271+ TaskUpdatedEvent ,
272+ TaskCompletedEvent ,
273+ TaskFailedEvent ,
274+ AllTasksCompletedEvent ,
263275 }
264276 for i , e in enumerate (events ):
265277 assert e in allowed_events , f"Unexpected event type at { i } : { e } "
@@ -355,7 +367,7 @@ def test_workforce_emits_expected_event_sequence():
355367 workforce .process_task (human_task )
356368
357369 # test that the event sequence is as expected
358- actual_events = [e .__class__ . __name__ for e in cb .events ]
370+ actual_events = [e .__class__ for e in cb .events ]
359371 assert_event_sequence (actual_events , min_worker_count = 3 )
360372
361373 # test that metrics callback methods work as expected
0 commit comments