12 changes: 10 additions & 2 deletions camel/societies/workforce/role_playing_worker.py
@@ -13,7 +13,7 @@
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional

from colorama import Fore

@@ -31,6 +31,9 @@
from camel.societies.workforce.worker import Worker
from camel.tasks.task import Task, TaskState, is_task_result_insufficient

if TYPE_CHECKING:
from camel.responses import ChatAgentResponse


class RolePlayingWorker(Worker):
r"""A worker node that contains a role playing.
@@ -98,7 +101,12 @@ def __init__(
self.user_agent_kwargs = user_agent_kwargs

async def _process_task(
self, task: Task, dependencies: List[Task]
self,
task: Task,
dependencies: List[Task],
stream_callback: Optional[
Callable[["ChatAgentResponse"], Optional[Awaitable[None]]]
] = None,
) -> TaskState:
r"""Processes a task leveraging its dependencies through role-playing.

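RolePlayingWorker._process_task now accepts an optional per-chunk callback. A minimal sketch of a handler matching the Callable[["ChatAgentResponse"], Optional[Awaitable[None]]] type declared above (the print formatting is illustrative, not part of this PR):

from typing import Awaitable, Optional

from camel.responses import ChatAgentResponse


def print_chunk(chunk: ChatAgentResponse) -> Optional[Awaitable[None]]:
    # Synchronous variant: inspect the partial response and return None,
    # so the caller's asyncio.iscoroutine() check simply skips the await.
    if chunk.msg is not None and chunk.msg.content:
        print(chunk.msg.content, end="", flush=True)
    return None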
4 changes: 4 additions & 0 deletions camel/societies/workforce/single_agent_worker.py
@@ -436,6 +436,10 @@ async def _process_task(
if isinstance(response, AsyncStreamingChatAgentResponse):
task_result = None
async for chunk in response:
if stream_callback:
maybe = stream_callback(chunk)
if asyncio.iscoroutine(maybe):
await maybe
if chunk.msg and chunk.msg.parsed:
task_result = chunk.msg.parsed
response_content = chunk.msg.content
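The asyncio.iscoroutine check above lets a single dispatch point serve both plain and async callbacks. A standalone sketch of the same pattern, with illustrative names:

import asyncio
from typing import Awaitable, Callable, Optional


async def dispatch_chunk(
    callback: Optional[Callable[[str], Optional[Awaitable[None]]]],
    text: str,
) -> None:
    # Invoke the callback if one is set; await only when it returned a coroutine.
    if callback is None:
        return
    maybe = callback(text)
    if asyncio.iscoroutine(maybe):
        await maybe


def sync_handler(text: str) -> None:
    print(text, end="")


async def async_handler(text: str) -> None:
    await asyncio.sleep(0)  # e.g. hand the chunk to a queue or websocket
    print(text, end="")

# Both handlers work with the same dispatcher:
#   await dispatch_chunk(sync_handler, "hello")
#   await dispatch_chunk(async_handler, "hello")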
41 changes: 38 additions & 3 deletions camel/societies/workforce/worker.py
@@ -16,7 +16,7 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Set
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Set

from colorama import Fore

@@ -25,6 +25,9 @@
from camel.societies.workforce.utils import check_if_running
from camel.tasks.task import Task, TaskState

if TYPE_CHECKING:
from camel.responses import ChatAgentResponse

logger = logging.getLogger(__name__)


@@ -46,13 +49,24 @@ def __init__(
super().__init__(description, node_id=node_id)
self._active_task_ids: Set[str] = set()
self._running_tasks: Set[asyncio.Task] = set()
self._stream_callback: Optional[
Callable[
["ChatAgentResponse", str, str],
Optional[Awaitable[None]],
]
] = None

def __repr__(self):
return f"Worker node {self.node_id} ({self.description})"

@abstractmethod
async def _process_task(
self, task: Task, dependencies: List[Task]
self,
task: Task,
dependencies: List[Task],
stream_callback: Optional[
Callable[["ChatAgentResponse"], Optional[Awaitable[None]]]
] = None,
) -> TaskState:
r"""Processes a task based on its dependencies.

@@ -80,6 +94,18 @@ def _get_dep_tasks_info(dependencies: List[Task]) -> str:
def set_channel(self, channel: TaskChannel):
self._channel = channel

def set_stream_callback(
self,
stream_callback: Optional[
Callable[
["ChatAgentResponse", str, str],
Optional[Awaitable[None]],
]
],
) -> None:
r"""Set streaming callback for worker token/chunk updates."""
self._stream_callback = stream_callback

async def _process_single_task(self, task: Task) -> None:
r"""Process a single task and handle its completion/failure."""
try:
@@ -89,8 +115,17 @@ async def _process_single_task(self, task: Task) -> None:
f"{Fore.RESET}"
)

async def _on_chunk(chunk: "ChatAgentResponse") -> None:
if self._stream_callback is None:
return
maybe = self._stream_callback(chunk, self.node_id, task.id)
if asyncio.iscoroutine(maybe):
await maybe

# Process the task
task_state = await self._process_task(task, task.dependencies)
task_state = await self._process_task(
task, task.dependencies, stream_callback=_on_chunk
)

# Update the result and status of the task
task.set_state(task_state)
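The worker-level callback receives the chunk together with the worker and task IDs, and the _on_chunk closure above binds those IDs before forwarding to _process_task, so one handler can demultiplex streams from several workers. A hedged sketch of registering such a handler directly on a worker (normally the parent Workforce does this wiring; worker stands for any already-constructed Worker subclass):

from collections import defaultdict
from typing import DefaultDict, List, Tuple

from camel.responses import ChatAgentResponse

stream_buffers: DefaultDict[Tuple[str, str], List[str]] = defaultdict(list)


def collect_chunk(chunk: ChatAgentResponse, worker_id: str, task_id: str) -> None:
    # Keep each (worker, task) stream in its own buffer for later inspection.
    if chunk.msg is not None and chunk.msg.content:
        stream_buffers[(worker_id, task_id)].append(chunk.msg.content)


# worker.set_stream_callback(collect_chunk)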
84 changes: 84 additions & 0 deletions camel/societies/workforce/workforce.py
@@ -24,6 +24,7 @@
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Deque,
@@ -247,6 +248,10 @@ class Workforce(BaseNode):
is added automatically. If at least one provided callback
implements :class:`WorkforceMetrics`, no default logger is added.
(default: :obj:`None`)
stream_callback (Optional[Callable], optional): Real-time callback for
worker streaming output chunks. Called as
``(worker_id, task_id, text, mode)``.
(default: :obj:`None`)
mode (WorkforceMode, optional): The execution mode for task
processing. AUTO_DECOMPOSE mode uses intelligent recovery
strategies (decompose, replan, etc.) when tasks fail.
@@ -327,6 +332,9 @@ def __init__(
task_timeout_seconds: Optional[float] = None,
mode: WorkforceMode = WorkforceMode.AUTO_DECOMPOSE,
callbacks: Optional[List[WorkforceCallback]] = None,
stream_callback: Optional[
Callable[[str, str, str, str], Optional[Awaitable[None]]]
] = None,
failure_handling_config: Optional[
Union[FailureHandlingConfig, Dict[str, Any]]
] = None,
@@ -384,6 +392,9 @@ def __init__(
self.snapshot_interval: float = 30.0
# Shared memory UUID tracking to prevent re-sharing duplicates
self._shared_memory_uuids: Set[str] = set()
# Optional user callback for worker streaming text chunks.
self._user_stream_callback = stream_callback
self._stream_progress: Dict[Tuple[str, str], str] = {}
self._initialize_callbacks(callbacks)

# Set up coordinator agent with default system message
@@ -539,6 +550,7 @@ def __init__(

# Shared context utility for workflow management (created lazily)
self._shared_context_utility: Optional["ContextUtility"] = None
self._sync_child_stream_callbacks()

# ------------------------------------------------------------------
# Helper for propagating pause control to externally supplied agents
@@ -584,6 +596,69 @@ def _initialize_callbacks(
for child in self._children:
self._notify_worker_created(child)

def _sync_child_stream_callbacks(self) -> None:
r"""Propagate stream callback settings to all child nodes."""
for child in self._children:
if isinstance(child, Worker):
child.set_stream_callback(self._on_worker_stream_chunk)
elif isinstance(child, Workforce):
child.set_stream_callback(self._user_stream_callback)

def set_stream_callback(
self,
stream_callback: Optional[
Callable[[str, str, str, str], Optional[Awaitable[None]]]
],
) -> Workforce:
r"""Set callback for real-time worker streaming output chunks.

Callback arguments:
worker_id (str): ID of the worker emitting this chunk.
task_id (str): ID of the task currently processed.
text (str): Incremental text content for this chunk.
mode (str): Chunk mode reported by model ("delta"/"accumulate").
"""
self._user_stream_callback = stream_callback
self._stream_progress.clear()
self._sync_child_stream_callbacks()
return self

async def _on_worker_stream_chunk(
self, chunk: "ChatAgentResponse", worker_id: str, task_id: str
) -> None:
r"""Normalize worker stream chunks and dispatch to user callback."""
if self._user_stream_callback is None:
return

if chunk.msg is None or chunk.msg.content is None:
return

mode = "accumulate"
if chunk.info:
mode = chunk.info.get("stream_accumulate_mode", mode)
if "mode" in chunk.info:
mode = chunk.info["mode"]

content = chunk.msg.content
key = (worker_id, task_id)

if mode == "accumulate":
previous = self._stream_progress.get(key, "")
if content.startswith(previous):
text = content[len(previous) :]
else:
text = content
self._stream_progress[key] = content
else:
text = content

if not text:
return

maybe = self._user_stream_callback(worker_id, task_id, text, mode)
if asyncio.iscoroutine(maybe):
await maybe

def _notify_worker_created(
self,
worker_node: BaseNode,
@@ -1388,6 +1463,11 @@ def _cleanup_task_tracking(self, task_id: str) -> None:
if task_id in self._assignees:
del self._assignees[task_id]

# Remove streaming progress entries for the completed/failed task.
stale_keys = [k for k in self._stream_progress if k[1] == task_id]
for key in stale_keys:
del self._stream_progress[key]

def _decompose_task(
self,
task: Task,
@@ -2911,6 +2991,7 @@ def add_single_agent_worker(
enable_workflow_memory=enable_workflow_memory,
)
self._children.append(worker_node)
self._sync_child_stream_callbacks()

# If we have a channel set up, set it for the new worker
if hasattr(self, '_channel') and self._channel is not None:
@@ -2988,6 +3069,7 @@ def add_role_playing_worker(
use_structured_output_handler=self.use_structured_output_handler,
)
self._children.append(worker_node)
self._sync_child_stream_callbacks()

# If we have a channel set up, set it for the new worker
if hasattr(self, '_channel') and self._channel is not None:
@@ -3024,6 +3106,7 @@ def add_workforce(self, workforce: Workforce) -> Workforce:
# control of worker agents only.
workforce._pause_event = self._pause_event
self._children.append(workforce)
self._sync_child_stream_callbacks()

# If we have a channel set up, set it for the new workforce
if hasattr(self, '_channel') and self._channel is not None:
@@ -4120,6 +4203,7 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
)

self._children.append(new_node)
self._sync_child_stream_callbacks()

self._notify_worker_created(
new_node,
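In "accumulate" mode each chunk carries the full text produced so far, so _on_worker_stream_chunk diffs it against the previously seen content and forwards only the new suffix, while "delta" chunks pass through unchanged. A standalone sketch of that normalization, assuming the same per-(worker, task) bookkeeping as _stream_progress:

from typing import Dict, Tuple

_progress: Dict[Tuple[str, str], str] = {}


def normalize_chunk(worker_id: str, task_id: str, content: str, mode: str) -> str:
    # Return only the incremental text for this (worker, task) stream.
    if mode != "accumulate":
        return content  # delta chunks are already incremental
    key = (worker_id, task_id)
    previous = _progress.get(key, "")
    _progress[key] = content
    if content.startswith(previous):
        return content[len(previous):]
    return content  # stream restarted or diverged; forward it whole


# normalize_chunk("w1", "t1", "Hel", "accumulate")   -> "Hel"
# normalize_chunk("w1", "t1", "Hello", "accumulate") -> "lo"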
10 changes: 10 additions & 0 deletions examples/workforce/workforce_callbacks_example.py
@@ -151,9 +151,19 @@ def main() -> None:
model=model,
)

# Real-time streaming callback from worker execution.
# You can filter by worker_id/task_id to only consume selected streams.
def on_worker_stream(
worker_id: str, task_id: str, text: str, mode: str
) -> None:
if not text.strip():
return
print(f"[Stream][{worker_id}][{task_id}][{mode}] {text}", end="")

workforce = Workforce(
"Workforce Callbacks Demo",
callbacks=[WorkforceLogger('demo-logger'), PrintCallback()],
stream_callback=on_worker_stream,
use_structured_output_handler=True,
failure_handling_config=FailureHandlingConfig(enabled_strategies=[]),
default_model=model,
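Because the workforce awaits the callback's result whenever it returns a coroutine, an async handler also works; this is useful when chunks must be forwarded to a websocket or a consumer task rather than printed. A hedged async variant of the example above (the asyncio.Queue is illustrative):

import asyncio

chunk_queue: asyncio.Queue = asyncio.Queue()


async def on_worker_stream_async(
    worker_id: str, task_id: str, text: str, mode: str
) -> None:
    # Hand each chunk to a consumer task instead of printing it inline.
    await chunk_queue.put((worker_id, task_id, text, mode))


# workforce.set_stream_callback(on_worker_stream_async)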