Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions camel/societies/workforce/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class WorkforceEventBase(BaseModel):
"task_created",
"task_assigned",
"task_started",
"task_updated",
"task_completed",
"task_failed",
"worker_created",
Expand Down Expand Up @@ -100,6 +101,17 @@ class TaskStartedEvent(WorkforceEventBase):
worker_id: str


class TaskUpdatedEvent(WorkforceEventBase):
event_type: Literal["task_updated"] = "task_updated"
task_id: str
worker_id: Optional[str] = None
update_type: Literal["replan", "reassign", "manual"]
old_value: Optional[str] = None
new_value: Optional[str] = None
parent_task_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None


class TaskCompletedEvent(WorkforceEventBase):
event_type: Literal["task_completed"] = "task_completed"
task_id: str
Expand Down Expand Up @@ -136,6 +148,7 @@ class QueueStatusEvent(WorkforceEventBase):
TaskCreatedEvent,
TaskAssignedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
TaskCompletedEvent,
TaskFailedEvent,
WorkerCreatedEvent,
Expand Down
56 changes: 54 additions & 2 deletions camel/societies/workforce/workforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
TaskDecomposedEvent,
TaskFailedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
WorkerCreatedEvent,
)

Expand Down Expand Up @@ -1673,8 +1674,30 @@ async def _apply_recovery_strategy(
elif strategy == RecoveryStrategy.REPLAN:
# Modify the task content and retry
if recovery_decision.modified_task_content:
task.content = recovery_decision.modified_task_content
logger.info(f"Task {task.id} content modified for replan")
old_content = task.content
new_content = recovery_decision.modified_task_content

task.content = new_content
logger.info(
f"Task {task.id} content modified for replan, "
f"new_content: {new_content}"
)

task_updated_event = TaskUpdatedEvent(
task_id=task.id,
parent_task_id=task.parent.id if task.parent else None,
worker_id=task.assigned_worker_id,
update_type="replan",
old_value=old_content,
new_value=new_content,
metadata={
"quality_score": recovery_decision.quality_score,
"reasoning": recovery_decision.reasoning,
"issues": recovery_decision.issues,
},
)
for cb in self._callbacks:
cb.log_task_updated(task_updated_event)

# Repost the modified task to the same worker
await self._post_task(task, original_assignee)
Expand Down Expand Up @@ -1717,6 +1740,22 @@ async def _apply_recovery_strategy(
f"{new_worker}"
)

task_updated_event = TaskUpdatedEvent(
task_id=task.id,
parent_task_id=task.parent.id if task.parent else None,
worker_id=task.assigned_worker_id,
update_type="reassign",
old_value=old_worker,
new_value=new_worker,
metadata={
"quality_score": recovery_decision.quality_score,
"reasoning": recovery_decision.reasoning,
"issues": recovery_decision.issues,
},
)
for cb in self._callbacks:
cb.log_task_updated(task_updated_event)

elif strategy == RecoveryStrategy.DECOMPOSE:
# Decompose the task into subtasks
reason = (
Expand Down Expand Up @@ -2056,8 +2095,21 @@ def modify_task_content(self, task_id: str, new_content: str) -> bool:

for task in self._pending_tasks:
if task.id == task_id:
old_content = task.content
task.content = new_content
logger.info(f"Task {task_id} content modified.")

task_updated_event = TaskUpdatedEvent(
task_id=task.id,
parent_task_id=task.parent.id if task.parent else None,
worker_id=task.assigned_worker_id,
update_type="manual",
old_value=old_content,
new_value=new_content,
)
for cb in self._callbacks:
cb.log_task_updated(task_updated_event)

return True
logger.warning(f"Task {task_id} not found in pending tasks.")
return False
Expand Down
5 changes: 5 additions & 0 deletions camel/societies/workforce/workforce_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TaskDecomposedEvent,
TaskFailedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
WorkerCreatedEvent,
WorkerDeletedEvent,
)
Expand Down Expand Up @@ -82,6 +83,10 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
def log_task_started(self, event: TaskStartedEvent) -> None:
pass

@abstractmethod
def log_task_updated(self, event: TaskUpdatedEvent) -> None:
pass

@abstractmethod
def log_task_completed(self, event: TaskCompletedEvent) -> None:
pass
Expand Down
16 changes: 16 additions & 0 deletions camel/societies/workforce/workforce_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TaskDecomposedEvent,
TaskFailedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
WorkerCreatedEvent,
WorkerDeletedEvent,
)
Expand Down Expand Up @@ -157,6 +158,21 @@ def log_task_started(
if event.task_id in self._task_hierarchy:
self._task_hierarchy[event.task_id]['status'] = 'processing'

def log_task_updated(self, event: TaskUpdatedEvent) -> None:
r"""Logs updates made to a task."""
self._log_event(
event_type=event.event_type,
task_id=event.task_id,
worker_id=event.worker_id,
update_type=event.update_type,
old_value=event.old_value,
new_value=event.new_value,
parent_task_id=event.parent_task_id,
metadata=event.metadata or {},
)
if event.task_id in self._task_hierarchy:
self._task_hierarchy[event.task_id]['status'] = 'updated'

def log_task_completed(self, event: TaskCompletedEvent) -> None:
r"""Logs the successful completion of a task."""
self._log_event(
Expand Down
7 changes: 7 additions & 0 deletions examples/workforce/workforce_callbacks_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TaskDecomposedEvent,
TaskFailedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
WorkerCreatedEvent,
WorkerDeletedEvent,
)
Expand Down Expand Up @@ -78,6 +79,12 @@ def log_task_started(self, event: TaskStartedEvent) -> None:
f"worker={event.worker_id}"
)

def log_task_updated(self, event: TaskUpdatedEvent) -> None:
print(
f"[PrintCallback] task_updated: task={event.task_id}, "
f"worker={event.worker_id}"
)

def log_task_completed(self, event: TaskCompletedEvent) -> None:
print(
f"[PrintCallback] task_completed: task={event.task_id}, "
Expand Down
77 changes: 46 additions & 31 deletions test/workforce/test_workforce_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TaskDecomposedEvent,
TaskFailedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
WorkerCreatedEvent,
WorkerDeletedEvent,
WorkforceEvent,
Expand Down Expand Up @@ -61,6 +62,9 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
def log_task_started(self, event: TaskStartedEvent) -> None:
self.events.append(event)

def log_task_updated(self, event: TaskUpdatedEvent) -> None:
self.events.append(event)

def log_task_completed(self, event: TaskCompletedEvent) -> None:
self.events.append(event)

Expand Down Expand Up @@ -121,6 +125,9 @@ def log_task_assigned(self, event: TaskAssignedEvent) -> None:
def log_task_started(self, event: TaskStartedEvent) -> None:
self.events.append(event)

def log_task_updated(self, event: TaskUpdatedEvent) -> None:
self.events.append(event)

def log_task_completed(self, event: TaskCompletedEvent) -> None:
self.events.append(event)

Expand Down Expand Up @@ -179,7 +186,9 @@ def test_workforce_callback_registration_and_metrics_handling():
Workforce("CB Test - Invalid", callbacks=[object()])


def assert_event_sequence(events: list[str], min_worker_count: int):
def assert_event_sequence(
events: list[type[WorkforceEvent]], min_worker_count: int
):
"""
Validate that the given event sequence follows the expected logical order.
This version is flexible to handle:
Expand All @@ -191,7 +200,7 @@ def assert_event_sequence(events: list[str], min_worker_count: int):

# 1. Expect at least min_worker_count WorkerCreatedEvent events first
initial_worker_count = 0
while idx < n and events[idx] == "WorkerCreatedEvent":
while idx < n and events[idx] == WorkerCreatedEvent:
initial_worker_count += 1
idx += 1
assert initial_worker_count >= min_worker_count, (
Expand All @@ -200,8 +209,8 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
)

# 2. Expect one main TaskCreatedEvent
assert idx < n and events[idx] == "TaskCreatedEvent", (
f"Event {idx} should be TaskCreatedEvent, got "
assert idx < n and events[idx] == TaskCreatedEvent, (
f"Event {idx} should be {TaskCreatedEvent.__name__}, got "
f"{events[idx] if idx < n else 'END'}"
)
idx += 1
Expand All @@ -210,49 +219,55 @@ def assert_event_sequence(events: list[str], min_worker_count: int):
# (depends on coordinator behavior)
# If the coordinator can't parse stub responses, it may skip
# decomposition
has_decomposition = idx < n and events[idx] == "TaskDecomposedEvent"
has_decomposition = idx < n and events[idx] == TaskDecomposedEvent
if has_decomposition:
idx += 1

# 4. Count all event types in the remaining events
all_events = events[idx:]
task_assigned_count = all_events.count("TaskAssignedEvent")
task_started_count = all_events.count("TaskStartedEvent")
task_completed_count = all_events.count("TaskCompletedEvent")
all_tasks_completed_count = all_events.count("AllTasksCompletedEvent")
task_assigned_count = sum(e is TaskAssignedEvent for e in all_events)
task_started_count = sum(e is TaskStartedEvent for e in all_events)
task_completed_count = sum(e is TaskCompletedEvent for e in all_events)
all_tasks_completed_count = sum(
e is AllTasksCompletedEvent for e in all_events
)

# 5. Validate basic invariants
# At minimum, the main task should be assigned and processed
assert (
task_assigned_count >= 1
), f"Expected at least 1 TaskAssignedEvent, got {task_assigned_count}"
assert (
task_started_count >= 1
), f"Expected at least 1 TaskStartedEvent, got {task_started_count}"
assert (
task_completed_count >= 1
), f"Expected at least 1 TaskCompletedEvent, got {task_completed_count}"
assert task_assigned_count >= 1, (
f"Expected at least 1 {TaskAssignedEvent.__name__}, "
f"got {task_assigned_count}"
)
assert task_started_count >= 1, (
f"Expected at least 1 {TaskStartedEvent.__name__}, "
f"got {task_started_count}"
)
assert task_completed_count >= 1, (
f"Expected at least 1 {TaskCompletedEvent.__name__}, "
f"got {task_completed_count}"
)

# 6. Expect exactly one AllTasksCompletedEvent at the end
assert all_tasks_completed_count == 1, (
f"Expected exactly 1 AllTasksCompletedEvent, got "
f"Expected exactly 1 {AllTasksCompletedEvent.__name__}, got "
f"{all_tasks_completed_count}"
)
assert (
events[-1] == "AllTasksCompletedEvent"
), "Last event should be AllTasksCompletedEvent"
events[-1] == AllTasksCompletedEvent
), f"Last event should be {AllTasksCompletedEvent.__name__}"

# 7. All events should be of expected types
allowed_events = {
"WorkerCreatedEvent",
"WorkerDeletedEvent",
"TaskCreatedEvent",
"TaskDecomposedEvent",
"TaskAssignedEvent",
"TaskStartedEvent",
"TaskCompletedEvent",
"TaskFailedEvent",
"AllTasksCompletedEvent",
WorkerCreatedEvent,
WorkerDeletedEvent,
TaskCreatedEvent,
TaskDecomposedEvent,
TaskAssignedEvent,
TaskStartedEvent,
TaskUpdatedEvent,
TaskCompletedEvent,
TaskFailedEvent,
AllTasksCompletedEvent,
}
for i, e in enumerate(events):
assert e in allowed_events, f"Unexpected event type at {i}: {e}"
Expand Down Expand Up @@ -348,7 +363,7 @@ def test_workforce_emits_expected_event_sequence():
workforce.process_task(human_task)

# test that the event sequence is as expected
actual_events = [e.__class__.__name__ for e in cb.events]
actual_events = [e.__class__ for e in cb.events]
assert_event_sequence(actual_events, min_worker_count=3)

# test that metrics callback methods work as expected
Expand Down