Skip to content

Commit 173e29d

Browse files
committed
fix(workforce): remove unnecessary deepcopy and improve event handling
1 parent 5de82ab commit 173e29d

File tree

2 files changed

+45
-34
lines changed

2 files changed

+45
-34
lines changed

camel/societies/workforce/workforce.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import asyncio
1717
import concurrent.futures
18-
import copy
1918
import json
2019
import os
2120
import time
@@ -1675,7 +1674,7 @@ async def _apply_recovery_strategy(
16751674
elif strategy == RecoveryStrategy.REPLAN:
16761675
# Modify the task content and retry
16771676
if recovery_decision.modified_task_content:
1678-
old_content = copy.deepcopy(task.content)
1677+
old_content = task.content
16791678
new_content = recovery_decision.modified_task_content
16801679

16811680
task.content = new_content
@@ -2096,7 +2095,7 @@ def modify_task_content(self, task_id: str, new_content: str) -> bool:
20962095

20972096
for task in self._pending_tasks:
20982097
if task.id == task_id:
2099-
old_content = copy.deepcopy(task.content)
2098+
old_content = task.content
21002099
task.content = new_content
21012100
logger.info(f"Task {task_id} content modified.")
21022101

test/workforce/test_workforce_callbacks.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_workforce_callback_registration_and_metrics_handling():
186186
Workforce("CB Test - Invalid", callbacks=[object()])
187187

188188

189-
def assert_event_sequence(events: list[str], min_worker_count: int):
189+
def assert_event_sequence(events: list[WorkforceEvent], min_worker_count: int):
190190
"""
191191
Validate that the given event sequence follows the expected logical order.
192192
This version is flexible to handle:
@@ -198,7 +198,7 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
198198

199199
# 1. Expect at least min_worker_count WorkerCreatedEvent events first
200200
initial_worker_count = 0
201-
while idx < n and events[idx] == "WorkerCreatedEvent":
201+
while idx < n and events[idx] == WorkerCreatedEvent:
202202
initial_worker_count += 1
203203
idx += 1
204204
assert initial_worker_count >= min_worker_count, (
@@ -207,8 +207,8 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
207207
)
208208

209209
# 2. Expect one main TaskCreatedEvent
210-
assert idx < n and events[idx] == "TaskCreatedEvent", (
211-
f"Event {idx} should be TaskCreatedEvent, got "
210+
assert idx < n and events[idx] == TaskCreatedEvent, (
211+
f"Event {idx} should be {TaskCreatedEvent.__name__}, got "
212212
f"{events[idx] if idx < n else 'END'}"
213213
)
214214
idx += 1
@@ -217,49 +217,61 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
217217
# (depends on coordinator behavior)
218218
# If the coordinator can't parse stub responses, it may skip
219219
# decomposition
220-
has_decomposition = idx < n and events[idx] == "TaskDecomposedEvent"
220+
has_decomposition = idx < n and events[idx] == TaskDecomposedEvent
221221
if has_decomposition:
222222
idx += 1
223223

224224
# 4. Count all event types in the remaining events
225225
all_events = events[idx:]
226-
task_assigned_count = all_events.count("TaskAssignedEvent")
227-
task_started_count = all_events.count("TaskStartedEvent")
228-
task_completed_count = all_events.count("TaskCompletedEvent")
229-
all_tasks_completed_count = all_events.count("AllTasksCompletedEvent")
226+
task_assigned_count = sum(
227+
isinstance(e, TaskAssignedEvent) for e in all_events
228+
)
229+
task_started_count = sum(
230+
isinstance(e, TaskStartedEvent) for e in all_events
231+
)
232+
task_completed_count = sum(
233+
isinstance(e, TaskCompletedEvent) for e in all_events
234+
)
235+
all_tasks_completed_count = sum(
236+
isinstance(e, AllTasksCompletedEvent) for e in all_events
237+
)
230238

231239
# 5. Validate basic invariants
232240
# At minimum, the main task should be assigned and processed
233-
assert (
234-
task_assigned_count >= 1
235-
), f"Expected at least 1 TaskAssignedEvent, got {task_assigned_count}"
236-
assert (
237-
task_started_count >= 1
238-
), f"Expected at least 1 TaskStartedEvent, got {task_started_count}"
239-
assert (
240-
task_completed_count >= 1
241-
), f"Expected at least 1 TaskCompletedEvent, got {task_completed_count}"
241+
assert task_assigned_count >= 1, (
242+
f"Expected at least 1 {TaskAssignedEvent.__name__}, "
243+
f"got {task_assigned_count}"
244+
)
245+
assert task_started_count >= 1, (
246+
f"Expected at least 1 {TaskStartedEvent.__name__}, "
247+
f"got {task_started_count}"
248+
)
249+
assert task_completed_count >= 1, (
250+
f"Expected at least 1 {TaskCompletedEvent.__name__}, "
251+
f"got {task_completed_count}"
252+
)
242253

243254
# 6. Expect exactly one AllTasksCompletedEvent at the end
244255
assert all_tasks_completed_count == 1, (
245-
f"Expected exactly 1 AllTasksCompletedEvent, got "
256+
f"Expected exactly 1 {AllTasksCompletedEvent.__name__}, got "
246257
f"{all_tasks_completed_count}"
247258
)
248259
assert (
249-
events[-1] == "AllTasksCompletedEvent"
250-
), "Last event should be AllTasksCompletedEvent"
260+
events[-1] == AllTasksCompletedEvent
261+
), f"Last event should be {AllTasksCompletedEvent.__name__}"
251262

252263
# 7. All events should be of expected types
253264
allowed_events = {
254-
"WorkerCreatedEvent",
255-
"WorkerDeletedEvent",
256-
"TaskCreatedEvent",
257-
"TaskDecomposedEvent",
258-
"TaskAssignedEvent",
259-
"TaskStartedEvent",
260-
"TaskCompletedEvent",
261-
"TaskFailedEvent",
262-
"AllTasksCompletedEvent",
265+
WorkerCreatedEvent,
266+
WorkerDeletedEvent,
267+
TaskCreatedEvent,
268+
TaskDecomposedEvent,
269+
TaskAssignedEvent,
270+
TaskStartedEvent,
271+
TaskUpdatedEvent,
272+
TaskCompletedEvent,
273+
TaskFailedEvent,
274+
AllTasksCompletedEvent,
263275
}
264276
for i, e in enumerate(events):
265277
assert e in allowed_events, f"Unexpected event type at {i}: {e}"
@@ -355,7 +367,7 @@ def test_workforce_emits_expected_event_sequence():
355367
workforce.process_task(human_task)
356368

357369
# test that the event sequence is as expected
358-
actual_events = [e.__class__.__name__ for e in cb.events]
370+
actual_events = [e.__class__ for e in cb.events]
359371
assert_event_sequence(actual_events, min_worker_count=3)
360372

361373
# test that metrics callback methods work as expected

0 commit comments

Comments
 (0)