From 801eddfa3588b3d8e4cef045e63bb6cf5feaa691 Mon Sep 17 00:00:00 2001 From: Sun Tao <2605127667@qq.com> Date: Tue, 10 Feb 2026 17:54:43 +0800 Subject: [PATCH] update --- .../workforce/role_playing_worker.py | 12 +- .../workforce/single_agent_worker.py | 4 + camel/societies/workforce/worker.py | 41 ++++++- camel/societies/workforce/workforce.py | 84 ++++++++++++++ .../workforce/workforce_callbacks_example.py | 10 ++ test/workforce/test_workforce_single_agent.py | 108 +++++++++++++++++- 6 files changed, 250 insertions(+), 9 deletions(-) diff --git a/camel/societies/workforce/role_playing_worker.py b/camel/societies/workforce/role_playing_worker.py index 54c26054ae..00454078ca 100644 --- a/camel/societies/workforce/role_playing_worker.py +++ b/camel/societies/workforce/role_playing_worker.py @@ -13,7 +13,7 @@ # ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= from __future__ import annotations -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional from colorama import Fore @@ -31,6 +31,9 @@ from camel.societies.workforce.worker import Worker from camel.tasks.task import Task, TaskState, is_task_result_insufficient +if TYPE_CHECKING: + from camel.responses import ChatAgentResponse + class RolePlayingWorker(Worker): r"""A worker node that contains a role playing. @@ -98,7 +101,12 @@ def __init__( self.user_agent_kwargs = user_agent_kwargs async def _process_task( - self, task: Task, dependencies: List[Task] + self, + task: Task, + dependencies: List[Task], + stream_callback: Optional[ + Callable[["ChatAgentResponse"], Optional[Awaitable[None]]] + ] = None, ) -> TaskState: r"""Processes a task leveraging its dependencies through role-playing. diff --git a/camel/societies/workforce/single_agent_worker.py b/camel/societies/workforce/single_agent_worker.py index 16cd2ef273..8f9b996266 100644 --- a/camel/societies/workforce/single_agent_worker.py +++ b/camel/societies/workforce/single_agent_worker.py @@ -436,6 +436,10 @@ async def _process_task( if isinstance(response, AsyncStreamingChatAgentResponse): task_result = None async for chunk in response: + if stream_callback: + maybe = stream_callback(chunk) + if asyncio.iscoroutine(maybe): + await maybe if chunk.msg and chunk.msg.parsed: task_result = chunk.msg.parsed response_content = chunk.msg.content diff --git a/camel/societies/workforce/worker.py b/camel/societies/workforce/worker.py index 2ea1c6d136..927db7a91c 100644 --- a/camel/societies/workforce/worker.py +++ b/camel/societies/workforce/worker.py @@ -16,7 +16,7 @@ import asyncio import logging from abc import ABC, abstractmethod -from typing import List, Optional, Set +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Set from colorama import Fore @@ -25,6 +25,9 @@ from camel.societies.workforce.utils import check_if_running from camel.tasks.task import Task, TaskState +if TYPE_CHECKING: + from camel.responses import ChatAgentResponse + logger = logging.getLogger(__name__) @@ -46,13 +49,24 @@ def __init__( super().__init__(description, node_id=node_id) self._active_task_ids: Set[str] = set() self._running_tasks: Set[asyncio.Task] = set() + self._stream_callback: Optional[ + Callable[ + ["ChatAgentResponse", str, str], + Optional[Awaitable[None]], + ] + ] = None def __repr__(self): return f"Worker node {self.node_id} ({self.description})" @abstractmethod async def _process_task( - self, task: Task, dependencies: List[Task] + self, + task: Task, + dependencies: List[Task], + stream_callback: Optional[ + Callable[["ChatAgentResponse"], Optional[Awaitable[None]]] + ] = None, ) -> TaskState: r"""Processes a task based on its dependencies. @@ -80,6 +94,18 @@ def _get_dep_tasks_info(dependencies: List[Task]) -> str: def set_channel(self, channel: TaskChannel): self._channel = channel + def set_stream_callback( + self, + stream_callback: Optional[ + Callable[ + ["ChatAgentResponse", str, str], + Optional[Awaitable[None]], + ] + ], + ) -> None: + r"""Set streaming callback for worker token/chunk updates.""" + self._stream_callback = stream_callback + async def _process_single_task(self, task: Task) -> None: r"""Process a single task and handle its completion/failure.""" try: @@ -89,8 +115,17 @@ async def _process_single_task(self, task: Task) -> None: f"{Fore.RESET}" ) + async def _on_chunk(chunk: "ChatAgentResponse") -> None: + if self._stream_callback is None: + return + maybe = self._stream_callback(chunk, self.node_id, task.id) + if asyncio.iscoroutine(maybe): + await maybe + # Process the task - task_state = await self._process_task(task, task.dependencies) + task_state = await self._process_task( + task, task.dependencies, stream_callback=_on_chunk + ) # Update the result and status of the task task.set_state(task_state) diff --git a/camel/societies/workforce/workforce.py b/camel/societies/workforce/workforce.py index 4f4a69abbb..f7961836d4 100644 --- a/camel/societies/workforce/workforce.py +++ b/camel/societies/workforce/workforce.py @@ -24,6 +24,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Coroutine, Deque, @@ -247,6 +248,10 @@ class Workforce(BaseNode): is added automatically. If at least one provided callback implements :class:`WorkforceMetrics`, no default logger is added. (default: :obj:`None`) + stream_callback (Optional[Callable], optional): Real-time callback for + worker streaming output chunks. Called as + ``(worker_id, task_id, text, mode)``. + (default: :obj:`None`) mode (WorkforceMode, optional): The execution mode for task processing. AUTO_DECOMPOSE mode uses intelligent recovery strategies (decompose, replan, etc.) when tasks fail. @@ -327,6 +332,9 @@ def __init__( task_timeout_seconds: Optional[float] = None, mode: WorkforceMode = WorkforceMode.AUTO_DECOMPOSE, callbacks: Optional[List[WorkforceCallback]] = None, + stream_callback: Optional[ + Callable[[str, str, str, str], Optional[Awaitable[None]]] + ] = None, failure_handling_config: Optional[ Union[FailureHandlingConfig, Dict[str, Any]] ] = None, @@ -384,6 +392,9 @@ def __init__( self.snapshot_interval: float = 30.0 # Shared memory UUID tracking to prevent re-sharing duplicates self._shared_memory_uuids: Set[str] = set() + # Optional user callback for worker streaming text chunks. + self._user_stream_callback = stream_callback + self._stream_progress: Dict[Tuple[str, str], str] = {} self._initialize_callbacks(callbacks) # Set up coordinator agent with default system message @@ -539,6 +550,7 @@ def __init__( # Shared context utility for workflow management (created lazily) self._shared_context_utility: Optional["ContextUtility"] = None + self._sync_child_stream_callbacks() # ------------------------------------------------------------------ # Helper for propagating pause control to externally supplied agents @@ -584,6 +596,69 @@ def _initialize_callbacks( for child in self._children: self._notify_worker_created(child) + def _sync_child_stream_callbacks(self) -> None: + r"""Propagate stream callback settings to all child nodes.""" + for child in self._children: + if isinstance(child, Worker): + child.set_stream_callback(self._on_worker_stream_chunk) + elif isinstance(child, Workforce): + child.set_stream_callback(self._user_stream_callback) + + def set_stream_callback( + self, + stream_callback: Optional[ + Callable[[str, str, str, str], Optional[Awaitable[None]]] + ], + ) -> Workforce: + r"""Set callback for real-time worker streaming output chunks. + + Callback arguments: + worker_id (str): ID of the worker emitting this chunk. + task_id (str): ID of the task currently processed. + text (str): Incremental text content for this chunk. + mode (str): Chunk mode reported by model ("delta"/"accumulate"). + """ + self._user_stream_callback = stream_callback + self._stream_progress.clear() + self._sync_child_stream_callbacks() + return self + + async def _on_worker_stream_chunk( + self, chunk: "ChatAgentResponse", worker_id: str, task_id: str + ) -> None: + r"""Normalize worker stream chunks and dispatch to user callback.""" + if self._user_stream_callback is None: + return + + if chunk.msg is None or chunk.msg.content is None: + return + + mode = "accumulate" + if chunk.info: + mode = chunk.info.get("stream_accumulate_mode", mode) + if "mode" in chunk.info: + mode = chunk.info["mode"] + + content = chunk.msg.content + key = (worker_id, task_id) + + if mode == "accumulate": + previous = self._stream_progress.get(key, "") + if content.startswith(previous): + text = content[len(previous) :] + else: + text = content + self._stream_progress[key] = content + else: + text = content + + if not text: + return + + maybe = self._user_stream_callback(worker_id, task_id, text, mode) + if asyncio.iscoroutine(maybe): + await maybe + def _notify_worker_created( self, worker_node: BaseNode, @@ -1388,6 +1463,11 @@ def _cleanup_task_tracking(self, task_id: str) -> None: if task_id in self._assignees: del self._assignees[task_id] + # Clean streaming progress for completed/failed task chunks. + stale_keys = [k for k in self._stream_progress if k[1] == task_id] + for key in stale_keys: + del self._stream_progress[key] + def _decompose_task( self, task: Task, @@ -2911,6 +2991,7 @@ def add_single_agent_worker( enable_workflow_memory=enable_workflow_memory, ) self._children.append(worker_node) + self._sync_child_stream_callbacks() # If we have a channel set up, set it for the new worker if hasattr(self, '_channel') and self._channel is not None: @@ -2988,6 +3069,7 @@ def add_role_playing_worker( use_structured_output_handler=self.use_structured_output_handler, ) self._children.append(worker_node) + self._sync_child_stream_callbacks() # If we have a channel set up, set it for the new worker if hasattr(self, '_channel') and self._channel is not None: @@ -3024,6 +3106,7 @@ def add_workforce(self, workforce: Workforce) -> Workforce: # control of worker agents only. workforce._pause_event = self._pause_event self._children.append(workforce) + self._sync_child_stream_callbacks() # If we have a channel set up, set it for the new workforce if hasattr(self, '_channel') and self._channel is not None: @@ -4120,6 +4203,7 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker: ) self._children.append(new_node) + self._sync_child_stream_callbacks() self._notify_worker_created( new_node, diff --git a/examples/workforce/workforce_callbacks_example.py b/examples/workforce/workforce_callbacks_example.py index 210dba2f68..f602cd5edf 100644 --- a/examples/workforce/workforce_callbacks_example.py +++ b/examples/workforce/workforce_callbacks_example.py @@ -151,9 +151,19 @@ def main() -> None: model=model, ) + # Real-time streaming callback from worker execution. + # You can filter by worker_id/task_id to only consume selected streams. + def on_worker_stream( + worker_id: str, task_id: str, text: str, mode: str + ) -> None: + if not text.strip(): + return + print(f"[Stream][{worker_id}][{task_id}][{mode}] {text}", end="") + workforce = Workforce( "Workforce Callbacks Demo", callbacks=[WorkforceLogger('demo-logger'), PrintCallback()], + stream_callback=on_worker_stream, use_structured_output_handler=True, failure_handling_config=FailureHandlingConfig(enabled_strategies=[]), default_model=model, diff --git a/test/workforce/test_workforce_single_agent.py b/test/workforce/test_workforce_single_agent.py index af27f11576..04b77efc11 100644 --- a/test/workforce/test_workforce_single_agent.py +++ b/test/workforce/test_workforce_single_agent.py @@ -12,20 +12,18 @@ # limitations under the License. # ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= import time -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional +from typing import Awaitable, Callable, List, Optional from unittest.mock import patch import pytest from camel.agents.chat_agent import ChatAgent from camel.messages.base import BaseMessage +from camel.responses import ChatAgentResponse from camel.societies.workforce import Workforce from camel.societies.workforce.single_agent_worker import SingleAgentWorker from camel.tasks.task import Task, TaskState -if TYPE_CHECKING: - from camel.responses import ChatAgentResponse - class AlwaysFailingWorker(SingleAgentWorker): """A worker that always fails tasks for testing purposes.""" @@ -52,6 +50,48 @@ async def _process_task( return TaskState.FAILED +class StreamingProbeWorker(SingleAgentWorker): + """A worker that emits one artificial streaming chunk.""" + + def __init__(self, description: str): + sys_msg = BaseMessage.make_assistant_message( + role_name="Streaming Probe", + content="Emit one stream chunk.", + ) + agent = ChatAgent(sys_msg) + super().__init__(description, agent) + + async def _process_task( + self, + task: Task, + dependencies: List[Task], + stream_callback: Optional[ + Callable[["ChatAgentResponse"], Optional[Awaitable[None]]] + ] = None, + ) -> TaskState: + chunk = ChatAgentResponse( + msgs=[ + BaseMessage.make_assistant_message( + role_name="Assistant", content="partial" + ) + ], + terminated=False, + info={"stream_accumulate_mode": "delta"}, + ) + if stream_callback: + maybe = stream_callback(chunk) + if maybe is not None: + await maybe + task.state = TaskState.DONE + task.result = "ok" + return TaskState.DONE + + +class _DummyChannel: + async def return_task(self, task_id: str) -> None: + self.last_task_id = task_id + + @pytest.mark.asyncio async def test_graceful_shutdown_immediate_timeout(): """Test that 0 timeout causes immediate shutdown.""" @@ -143,3 +183,63 @@ async def test_get_dep_tasks_info(mock_process_task, mock_decompose): # Verify the mocks were called mock_decompose.assert_called_once_with(agent) mock_process_task.assert_called_once_with(human_task, mock_subtasks) + + +@pytest.mark.asyncio +async def test_worker_forwards_stream_chunks_to_registered_callback(): + worker = StreamingProbeWorker("probe") + worker.set_channel(_DummyChannel()) + received = [] + + async def stream_handler(chunk, worker_id, task_id): + received.append((worker_id, task_id, chunk.msg.content)) + + worker.set_stream_callback(stream_handler) + task = Task(content="test", id="task-stream") + + await worker._process_single_task(task) + + assert len(received) == 1 + worker_id, task_id, content = received[0] + assert worker_id + assert task_id == "task-stream" + assert content == "partial" + + +@pytest.mark.asyncio +async def test_workforce_stream_callback_accumulate_mode_emits_delta(): + chunks = [] + + def on_stream(worker_id: str, task_id: str, text: str, mode: str): + chunks.append((worker_id, task_id, text, mode)) + + workforce = Workforce("stream test", stream_callback=on_stream) + + first_chunk = ChatAgentResponse( + msgs=[ + BaseMessage.make_assistant_message( + role_name="Assistant", + content="Hello", + ) + ], + terminated=False, + info={"stream_accumulate_mode": "accumulate"}, + ) + second_chunk = ChatAgentResponse( + msgs=[ + BaseMessage.make_assistant_message( + role_name="Assistant", + content="Hello world", + ) + ], + terminated=False, + info={"stream_accumulate_mode": "accumulate"}, + ) + + await workforce._on_worker_stream_chunk(first_chunk, "worker-1", "task-1") + await workforce._on_worker_stream_chunk(second_chunk, "worker-1", "task-1") + + assert chunks == [ + ("worker-1", "task-1", "Hello", "accumulate"), + ("worker-1", "task-1", " world", "accumulate"), + ]