Skip to content

Commit e7f5d9f

Browse files
coolbeevipfengju0213Wendong-Fan
authored
feat(workforce): add TaskUpdatedEvent to track task modifications (#3671)
Co-authored-by: Tao Sun <[email protected]> Co-authored-by: Sun Tao <[email protected]> Co-authored-by: Wendong-Fan <[email protected]> Co-authored-by: Wendong-Fan <[email protected]>
1 parent 1cae995 commit e7f5d9f

File tree

6 files changed

+141
-33
lines changed

6 files changed

+141
-33
lines changed

camel/societies/workforce/events.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class WorkforceEventBase(BaseModel):
2727
"task_created",
2828
"task_assigned",
2929
"task_started",
30+
"task_updated",
3031
"task_completed",
3132
"task_failed",
3233
"worker_created",
@@ -100,6 +101,17 @@ class TaskStartedEvent(WorkforceEventBase):
100101
worker_id: str
101102

102103

104+
class TaskUpdatedEvent(WorkforceEventBase):
105+
event_type: Literal["task_updated"] = "task_updated"
106+
task_id: str
107+
worker_id: Optional[str] = None
108+
update_type: Literal["replan", "reassign", "manual"]
109+
old_value: Optional[str] = None
110+
new_value: Optional[str] = None
111+
parent_task_id: Optional[str] = None
112+
metadata: Optional[Dict[str, Any]] = None
113+
114+
103115
class TaskCompletedEvent(WorkforceEventBase):
104116
event_type: Literal["task_completed"] = "task_completed"
105117
task_id: str
@@ -136,6 +148,7 @@ class QueueStatusEvent(WorkforceEventBase):
136148
TaskCreatedEvent,
137149
TaskAssignedEvent,
138150
TaskStartedEvent,
151+
TaskUpdatedEvent,
139152
TaskCompletedEvent,
140153
TaskFailedEvent,
141154
WorkerCreatedEvent,

camel/societies/workforce/workforce.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
TaskDecomposedEvent,
103103
TaskFailedEvent,
104104
TaskStartedEvent,
105+
TaskUpdatedEvent,
105106
WorkerCreatedEvent,
106107
)
107108

@@ -1673,8 +1674,30 @@ async def _apply_recovery_strategy(
16731674
elif strategy == RecoveryStrategy.REPLAN:
16741675
# Modify the task content and retry
16751676
if recovery_decision.modified_task_content:
1676-
task.content = recovery_decision.modified_task_content
1677-
logger.info(f"Task {task.id} content modified for replan")
1677+
old_content = task.content
1678+
new_content = recovery_decision.modified_task_content
1679+
1680+
task.content = new_content
1681+
logger.info(
1682+
f"Task {task.id} content modified for replan, "
1683+
f"new_content: {new_content}"
1684+
)
1685+
1686+
task_updated_event = TaskUpdatedEvent(
1687+
task_id=task.id,
1688+
parent_task_id=task.parent.id if task.parent else None,
1689+
worker_id=task.assigned_worker_id,
1690+
update_type="replan",
1691+
old_value=old_content,
1692+
new_value=new_content,
1693+
metadata={
1694+
"quality_score": recovery_decision.quality_score,
1695+
"reasoning": recovery_decision.reasoning,
1696+
"issues": recovery_decision.issues,
1697+
},
1698+
)
1699+
for cb in self._callbacks:
1700+
cb.log_task_updated(task_updated_event)
16781701

16791702
# Repost the modified task to the same worker
16801703
await self._post_task(task, original_assignee)
@@ -1717,6 +1740,22 @@ async def _apply_recovery_strategy(
17171740
f"{new_worker}"
17181741
)
17191742

1743+
task_updated_event = TaskUpdatedEvent(
1744+
task_id=task.id,
1745+
parent_task_id=task.parent.id if task.parent else None,
1746+
worker_id=task.assigned_worker_id,
1747+
update_type="reassign",
1748+
old_value=old_worker,
1749+
new_value=new_worker,
1750+
metadata={
1751+
"quality_score": recovery_decision.quality_score,
1752+
"reasoning": recovery_decision.reasoning,
1753+
"issues": recovery_decision.issues,
1754+
},
1755+
)
1756+
for cb in self._callbacks:
1757+
cb.log_task_updated(task_updated_event)
1758+
17201759
elif strategy == RecoveryStrategy.DECOMPOSE:
17211760
# Decompose the task into subtasks
17221761
reason = (
@@ -2056,8 +2095,21 @@ def modify_task_content(self, task_id: str, new_content: str) -> bool:
20562095

20572096
for task in self._pending_tasks:
20582097
if task.id == task_id:
2098+
old_content = task.content
20592099
task.content = new_content
20602100
logger.info(f"Task {task_id} content modified.")
2101+
2102+
task_updated_event = TaskUpdatedEvent(
2103+
task_id=task.id,
2104+
parent_task_id=task.parent.id if task.parent else None,
2105+
worker_id=task.assigned_worker_id,
2106+
update_type="manual",
2107+
old_value=old_content,
2108+
new_value=new_content,
2109+
)
2110+
for cb in self._callbacks:
2111+
cb.log_task_updated(task_updated_event)
2112+
20612113
return True
20622114
logger.warning(f"Task {task_id} not found in pending tasks.")
20632115
return False

camel/societies/workforce/workforce_callback.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TaskDecomposedEvent,
2828
TaskFailedEvent,
2929
TaskStartedEvent,
30+
TaskUpdatedEvent,
3031
WorkerCreatedEvent,
3132
WorkerDeletedEvent,
3233
)
@@ -82,6 +83,10 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
8283
def log_task_started(self, event: TaskStartedEvent) -> None:
8384
pass
8485

86+
@abstractmethod
87+
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
88+
pass
89+
8590
@abstractmethod
8691
def log_task_completed(self, event: TaskCompletedEvent) -> None:
8792
pass

camel/societies/workforce/workforce_logger.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
TaskDecomposedEvent,
2727
TaskFailedEvent,
2828
TaskStartedEvent,
29+
TaskUpdatedEvent,
2930
WorkerCreatedEvent,
3031
WorkerDeletedEvent,
3132
)
@@ -157,6 +158,21 @@ def log_task_started(
157158
if event.task_id in self._task_hierarchy:
158159
self._task_hierarchy[event.task_id]['status'] = 'processing'
159160

161+
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
162+
r"""Logs updates made to a task."""
163+
self._log_event(
164+
event_type=event.event_type,
165+
task_id=event.task_id,
166+
worker_id=event.worker_id,
167+
update_type=event.update_type,
168+
old_value=event.old_value,
169+
new_value=event.new_value,
170+
parent_task_id=event.parent_task_id,
171+
metadata=event.metadata or {},
172+
)
173+
if event.task_id in self._task_hierarchy:
174+
self._task_hierarchy[event.task_id]['status'] = 'updated'
175+
160176
def log_task_completed(self, event: TaskCompletedEvent) -> None:
161177
r"""Logs the successful completion of a task."""
162178
self._log_event(

examples/workforce/workforce_callbacks_example.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TaskDecomposedEvent,
3535
TaskFailedEvent,
3636
TaskStartedEvent,
37+
TaskUpdatedEvent,
3738
WorkerCreatedEvent,
3839
WorkerDeletedEvent,
3940
)
@@ -78,6 +79,12 @@ def log_task_started(self, event: TaskStartedEvent) -> None:
7879
f"worker={event.worker_id}"
7980
)
8081

82+
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
83+
print(
84+
f"[PrintCallback] task_updated: task={event.task_id}, "
85+
f"worker={event.worker_id}"
86+
)
87+
8188
def log_task_completed(self, event: TaskCompletedEvent) -> None:
8289
print(
8390
f"[PrintCallback] task_completed: task={event.task_id}, "

test/workforce/test_workforce_callbacks.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TaskDecomposedEvent,
2828
TaskFailedEvent,
2929
TaskStartedEvent,
30+
TaskUpdatedEvent,
3031
WorkerCreatedEvent,
3132
WorkerDeletedEvent,
3233
WorkforceEvent,
@@ -61,6 +62,9 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
6162
def log_task_started(self, event: TaskStartedEvent) -> None:
6263
self.events.append(event)
6364

65+
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
66+
self.events.append(event)
67+
6468
def log_task_completed(self, event: TaskCompletedEvent) -> None:
6569
self.events.append(event)
6670

@@ -121,6 +125,9 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
121125
def log_task_started(self, event: TaskStartedEvent) -> None:
122126
self.events.append(event)
123127

128+
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
129+
self.events.append(event)
130+
124131
def log_task_completed(self, event: TaskCompletedEvent) -> None:
125132
self.events.append(event)
126133

@@ -179,7 +186,9 @@ def test_workforce_callback_registration_and_metrics_handling():
179186
Workforce("CB Test - Invalid", callbacks=[object()])
180187

181188

182-
def assert_event_sequence(events: list[str], min_worker_count: int):
189+
def assert_event_sequence(
190+
events: list[type[WorkforceEvent]], min_worker_count: int
191+
):
183192
"""
184193
Validate that the given event sequence follows the expected logical order.
185194
This version is flexible to handle:
@@ -191,7 +200,7 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
191200

192201
# 1. Expect at least min_worker_count WorkerCreatedEvent events first
193202
initial_worker_count = 0
194-
while idx < n and events[idx] == "WorkerCreatedEvent":
203+
while idx < n and events[idx] == WorkerCreatedEvent:
195204
initial_worker_count += 1
196205
idx += 1
197206
assert initial_worker_count >= min_worker_count, (
@@ -200,8 +209,8 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
200209
)
201210

202211
# 2. Expect one main TaskCreatedEvent
203-
assert idx < n and events[idx] == "TaskCreatedEvent", (
204-
f"Event {idx} should be TaskCreatedEvent, got "
212+
assert idx < n and events[idx] == TaskCreatedEvent, (
213+
f"Event {idx} should be {TaskCreatedEvent.__name__}, got "
205214
f"{events[idx] if idx < n else 'END'}"
206215
)
207216
idx += 1
@@ -210,49 +219,55 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
210219
# (depends on coordinator behavior)
211220
# If the coordinator can't parse stub responses, it may skip
212221
# decomposition
213-
has_decomposition = idx < n and events[idx] == "TaskDecomposedEvent"
222+
has_decomposition = idx < n and events[idx] == TaskDecomposedEvent
214223
if has_decomposition:
215224
idx += 1
216225

217226
# 4. Count all event types in the remaining events
218227
all_events = events[idx:]
219-
task_assigned_count = all_events.count("TaskAssignedEvent")
220-
task_started_count = all_events.count("TaskStartedEvent")
221-
task_completed_count = all_events.count("TaskCompletedEvent")
222-
all_tasks_completed_count = all_events.count("AllTasksCompletedEvent")
228+
task_assigned_count = sum(e is TaskAssignedEvent for e in all_events)
229+
task_started_count = sum(e is TaskStartedEvent for e in all_events)
230+
task_completed_count = sum(e is TaskCompletedEvent for e in all_events)
231+
all_tasks_completed_count = sum(
232+
e is AllTasksCompletedEvent for e in all_events
233+
)
223234

224235
# 5. Validate basic invariants
225236
# At minimum, the main task should be assigned and processed
226-
assert (
227-
task_assigned_count >= 1
228-
), f"Expected at least 1 TaskAssignedEvent, got {task_assigned_count}"
229-
assert (
230-
task_started_count >= 1
231-
), f"Expected at least 1 TaskStartedEvent, got {task_started_count}"
232-
assert (
233-
task_completed_count >= 1
234-
), f"Expected at least 1 TaskCompletedEvent, got {task_completed_count}"
237+
assert task_assigned_count >= 1, (
238+
f"Expected at least 1 {TaskAssignedEvent.__name__}, "
239+
f"got {task_assigned_count}"
240+
)
241+
assert task_started_count >= 1, (
242+
f"Expected at least 1 {TaskStartedEvent.__name__}, "
243+
f"got {task_started_count}"
244+
)
245+
assert task_completed_count >= 1, (
246+
f"Expected at least 1 {TaskCompletedEvent.__name__}, "
247+
f"got {task_completed_count}"
248+
)
235249

236250
# 6. Expect exactly one AllTasksCompletedEvent at the end
237251
assert all_tasks_completed_count == 1, (
238-
f"Expected exactly 1 AllTasksCompletedEvent, got "
252+
f"Expected exactly 1 {AllTasksCompletedEvent.__name__}, got "
239253
f"{all_tasks_completed_count}"
240254
)
241255
assert (
242-
events[-1] == "AllTasksCompletedEvent"
243-
), "Last event should be AllTasksCompletedEvent"
256+
events[-1] == AllTasksCompletedEvent
257+
), f"Last event should be {AllTasksCompletedEvent.__name__}"
244258

245259
# 7. All events should be of expected types
246260
allowed_events = {
247-
"WorkerCreatedEvent",
248-
"WorkerDeletedEvent",
249-
"TaskCreatedEvent",
250-
"TaskDecomposedEvent",
251-
"TaskAssignedEvent",
252-
"TaskStartedEvent",
253-
"TaskCompletedEvent",
254-
"TaskFailedEvent",
255-
"AllTasksCompletedEvent",
261+
WorkerCreatedEvent,
262+
WorkerDeletedEvent,
263+
TaskCreatedEvent,
264+
TaskDecomposedEvent,
265+
TaskAssignedEvent,
266+
TaskStartedEvent,
267+
TaskUpdatedEvent,
268+
TaskCompletedEvent,
269+
TaskFailedEvent,
270+
AllTasksCompletedEvent,
256271
}
257272
for i, e in enumerate(events):
258273
assert e in allowed_events, f"Unexpected event type at {i}: {e}"
@@ -348,7 +363,7 @@ def test_workforce_emits_expected_event_sequence():
348363
workforce.process_task(human_task)
349364

350365
# test that the event sequence is as expected
351-
actual_events = [e.__class__.__name__ for e in cb.events]
366+
actual_events = [e.__class__ for e in cb.events]
352367
assert_event_sequence(actual_events, min_worker_count=3)
353368

354369
# test that metrics callback methods work as expected

0 commit comments

Comments
 (0)