diff --git a/camel/societies/workforce/single_agent_worker.py b/camel/societies/workforce/single_agent_worker.py index 3026e832d0..05c3dce807 100644 --- a/camel/societies/workforce/single_agent_worker.py +++ b/camel/societies/workforce/single_agent_worker.py @@ -24,6 +24,7 @@ from camel.agents import ChatAgent from camel.agents.chat_agent import AsyncStreamingChatAgentResponse from camel.logger import get_logger +from camel.messages.base import BaseMessage from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT from camel.societies.workforce.structured_output_handler import ( StructuredOutputHandler, @@ -34,6 +35,7 @@ WorkflowMemoryManager, ) from camel.tasks.task import Task, TaskState, is_task_result_insufficient +from camel.types import OpenAIBackendRole from camel.utils.context_utils import ContextUtility logger = get_logger(__name__) @@ -217,6 +219,11 @@ class SingleAgentWorker(Worker): conversations from all task executions are accumulated for potential workflow saving. Set to True if you plan to call save_workflow_memories(). (default: :obj:`False`) + enable_breakpoint_resume (bool, optional): Whether to retain the agent + instance on task failure for reuse on retry. When enabled, the + worker keeps the failed agent (with its conversation history + intact) and reuses it directly on the next attempt instead of + creating a fresh agent. (default: :obj:`True`) """ def __init__( @@ -230,6 +237,7 @@ def __init__( use_structured_output_handler: bool = True, context_utility: Optional[ContextUtility] = None, enable_workflow_memory: bool = False, + enable_breakpoint_resume: bool = True, ) -> None: node_id = worker.agent_id super().__init__( @@ -245,6 +253,8 @@ def __init__( self.worker = worker self.use_agent_pool = use_agent_pool self.enable_workflow_memory = enable_workflow_memory + self.enable_breakpoint_resume = enable_breakpoint_resume + self._failed_task_agents: Dict[str, ChatAgent] = {} self._shared_context_utility = context_utility self._context_utility: Optional[ContextUtility] = ( None # Will be initialized when needed @@ -271,6 +281,17 @@ def __init__( auto_scale=auto_scale_pool, ) + def _build_retry_message(self, task: Task) -> str: + r"""Build a concise retry message describing the last failure.""" + failure_reason = (task.result or "Unknown error").strip() + if len(failure_reason) > 500: + failure_reason = f"{failure_reason[:500]}..." + return ( + f"Retry attempt {task.failure_count + 1}. " + f"Previous attempt failed with: {failure_reason}. " + "Continue from the previous context and address the failure." + ) + def reset(self) -> Any: r"""Resets the worker to its initial state.""" super().reset() @@ -354,11 +375,32 @@ async def _process_task( TaskState: `TaskState.DONE` if processed successfully, otherwise `TaskState.FAILED`. """ - # Get agent efficiently (from pool or by cloning) - worker_agent = await self._get_worker_agent() + # Reuse the failed agent if breakpoint resume is enabled, + # otherwise get a fresh agent from pool or by cloning. + reusing_agent = False + if ( + self.enable_breakpoint_resume + and task.failure_count > 0 + and task.id in self._failed_task_agents + ): + worker_agent = self._failed_task_agents.pop(task.id) + reusing_agent = True + else: + worker_agent = await self._get_worker_agent() response_content = "" try: + if reusing_agent: + retry_message = self._build_retry_message(task) + worker_agent.update_memory( + BaseMessage.make_user_message( + role_name="user", + content=retry_message, + meta_dict={"type": "retry_context"}, + ), + OpenAIBackendRole.USER, + ) + dependency_tasks_info = self._get_dep_tasks_info(dependencies) prompt = str( PROCESS_TASK_PROMPT.format( @@ -489,81 +531,98 @@ async def _process_task( ) # Store error information in task result task.result = f"{type(e).__name__}: {e!s}" + if self.enable_breakpoint_resume: + self._failed_task_agents[task.id] = worker_agent + else: + await self._return_worker_agent(worker_agent) return TaskState.FAILED - finally: - # Return agent to pool or let it be garbage collected - await self._return_worker_agent(worker_agent) - - # Populate additional_info with worker attempt details - if task.additional_info is None: - task.additional_info = {} - - # Create worker attempt details with descriptive keys - worker_attempt_details = { - "agent_id": getattr( - worker_agent, "agent_id", worker_agent.role_name - ), - "original_worker_id": getattr( - self.worker, "agent_id", self.worker.role_name - ), - "timestamp": str(datetime.datetime.now()), - "description": f"Attempt by " - f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} " - f"(from pool/clone of " - f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) " - f"to process task: {task.content}", - "response_content": response_content[:50], - "tool_calls": str( - final_response.info.get("tool_calls") - if isinstance(response, AsyncStreamingChatAgentResponse) - else response.info.get("tool_calls") - )[:50], - "total_tokens": total_tokens, - } - # Store the worker attempt in additional_info - if "worker_attempts" not in task.additional_info: - task.additional_info["worker_attempts"] = [] - task.additional_info["worker_attempts"].append(worker_attempt_details) + try: + # Populate additional_info with worker attempt details + if task.additional_info is None: + task.additional_info = {} + + # Create worker attempt details with descriptive keys + worker_attempt_details = { + "agent_id": getattr( + worker_agent, "agent_id", worker_agent.role_name + ), + "original_worker_id": getattr( + self.worker, "agent_id", self.worker.role_name + ), + "timestamp": str(datetime.datetime.now()), + "description": f"Attempt by " + f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} " + f"(from pool/clone of " + f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) " + f"to process task: {task.content}", + "response_content": response_content[:50], + "tool_calls": str( + final_response.info.get("tool_calls") + if isinstance(response, AsyncStreamingChatAgentResponse) + else response.info.get("tool_calls") + )[:50], + "total_tokens": total_tokens, + } + + # Store the worker attempt in additional_info + if "worker_attempts" not in task.additional_info: + task.additional_info["worker_attempts"] = [] + task.additional_info["worker_attempts"].append( + worker_attempt_details + ) - # Store the actual token usage for this specific task - task.additional_info["token_usage"] = {"total_tokens": total_tokens} + # Store the actual token usage for this specific task + task.additional_info["token_usage"] = { + "total_tokens": total_tokens + } - print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}") - logger.info(f"Response from {self}:") + print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}") + logger.info(f"Response from {self}:") - if not self.use_structured_output_handler: - # Handle native structured output parsing - if task_result is None: - logger.error( - "Error in worker step execution: Invalid task result" - ) - task_result = TaskResult( - content="Failed to generate valid task result.", - failed=True, - ) + if not self.use_structured_output_handler: + # Handle native structured output parsing + if task_result is None: + logger.error( + "Error in worker step execution: Invalid task result" + ) + task_result = TaskResult( + content="Failed to generate valid task result.", + failed=True, + ) - color = Fore.RED if task_result.failed else Fore.GREEN # type: ignore[union-attr] - print( - f"\n{color}{task_result.content}{Fore.RESET}\n======", # type: ignore[union-attr] - ) - if task_result.failed: # type: ignore[union-attr] - logger.error(f"{task_result.content}") # type: ignore[union-attr] - else: - logger.info(f"{task_result.content}") # type: ignore[union-attr] + color = Fore.RED if task_result.failed else Fore.GREEN # type: ignore[union-attr] + print( + f"\n{color}{task_result.content}{Fore.RESET}\n======", # type: ignore[union-attr] + ) + if task_result.failed: # type: ignore[union-attr] + logger.error(f"{task_result.content}") # type: ignore[union-attr] + else: + logger.info(f"{task_result.content}") # type: ignore[union-attr] - task.result = task_result.content # type: ignore[union-attr] + task.result = task_result.content # type: ignore[union-attr] - if task_result.failed: # type: ignore[union-attr] - return TaskState.FAILED + if task_result.failed: # type: ignore[union-attr] + if self.enable_breakpoint_resume: + self._failed_task_agents[task.id] = worker_agent + return TaskState.FAILED + return TaskState.FAILED - if is_task_result_insufficient(task): - logger.warning( - f"Task {task.id}: Content validation failed - " - f"task marked as failed" - ) - return TaskState.FAILED - return TaskState.DONE + if is_task_result_insufficient(task): + logger.warning( + f"Task {task.id}: Content validation failed - " + f"task marked as failed" + ) + if self.enable_breakpoint_resume: + self._failed_task_agents[task.id] = worker_agent + return TaskState.FAILED + return TaskState.FAILED + return TaskState.DONE + finally: + # Only return the agent to the pool if it wasn't retained + # for breakpoint resume + if task.id not in self._failed_task_agents: + await self._return_worker_agent(worker_agent) async def _listen_to_channel(self): r"""Override to start cleanup task when pool is enabled.""" diff --git a/examples/workforce/breakpoint_resume_example.py b/examples/workforce/breakpoint_resume_example.py new file mode 100644 index 0000000000..983a373ed4 --- /dev/null +++ b/examples/workforce/breakpoint_resume_example.py @@ -0,0 +1,226 @@ +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= + +""" +Breakpoint Resume Example +========================= + +This example demonstrates the breakpoint resume feature of SingleAgentWorker. +When a task fails, the worker retains the agent instance (with its conversation +history intact) and reuses it directly on the next retry attempt, instead of +creating a fresh agent. + +This example: +- First attempt: Mocked to FAIL (simulating partial work done) +- Second attempt: Real API call with the same agent, history preserved +""" + +import asyncio +from unittest.mock import MagicMock + +from camel.agents.chat_agent import ChatAgent +from camel.messages.base import BaseMessage +from camel.models import ModelFactory +from camel.societies.workforce.single_agent_worker import SingleAgentWorker +from camel.societies.workforce.utils import TaskResult +from camel.tasks.task import Task +from camel.types import ModelPlatformType, ModelType, OpenAIBackendRole + + +async def main(): + print("=" * 70) + print("Breakpoint Resume Demo") + print("=" * 70) + print( + "\nThis demo shows how the agent instance is retained across retries." + ) + print("- First attempt: Mocked to FAIL with partial work saved") + print("- Second attempt: Same agent reused with history intact\n") + + # Create system message + sys_msg = BaseMessage.make_assistant_message( + role_name="Research Assistant", + content="You are a research assistant that helps with tasks.", + ) + + # Create agent with a real model for the second attempt + agent = ChatAgent( + system_message=sys_msg, + model=ModelFactory.create( + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, + ), + ) + + call_count = 0 + + # Save the real astep for the second call + real_astep = agent.astep + + async def astep_toggle(prompt, *args, **kwargs): + nonlocal call_count + call_count += 1 + + if call_count == 1: + # First call: simulate partial work then fail + agent.update_memory( + BaseMessage.make_user_message( + role_name="user", + content=( + "[Progress] Started researching quantum " + "computing..." + ), + ), + OpenAIBackendRole.USER, + ) + agent.update_memory( + BaseMessage.make_assistant_message( + role_name="assistant", + content=( + "[Progress] Found: Feynman proposed quantum " + "computing in 1981. Shor's algorithm was " + "discovered in 1994..." + ), + ), + OpenAIBackendRole.ASSISTANT, + ) + + print( + " -> Agent did partial work (added to conversation history)" + ) + print(" -> Returning FAILED result") + + response = MagicMock() + response.msg = MagicMock( + parsed=TaskResult( + content=( + "Incomplete: Connection timeout while fetching " + "sources" + ), + failed=True, + ), + content=( + '{"content":"Incomplete: Connection timeout",' + '"failed":true}' + ), + ) + response.info = {} + return response + else: + # Second call: use real API + return await real_astep(prompt, *args, **kwargs) + + agent.astep = astep_toggle + + # Create worker + worker = SingleAgentWorker( + "Research assistant", + agent, + use_agent_pool=False, + enable_breakpoint_resume=True, + use_structured_output_handler=False, + ) + + # Override to provide our test agent + first_call = True + + async def get_agent(): + nonlocal first_call + if first_call: + first_call = False + return agent + raise AssertionError("Should reuse retained agent") + + async def return_agent(a): + pass + + worker._get_worker_agent = get_agent + worker._return_worker_agent = return_agent + + # Create task + task = Task( + content=( + "Research quantum computing history and list 3 key milestones." + ), + id="task-1", + ) + + # ========== FIRST ATTEMPT (MOCKED FAIL) ========== + print("-" * 70) + print("[FIRST ATTEMPT - Mocked Failure]") + print("-" * 70) + + state = await worker._process_task(task, []) + + print(f"\nResult: {state}") + print(f"Task result: {task.result}") + + # Show retained agent info + if task.id in worker._failed_task_agents: + retained = worker._failed_task_agents[task.id] + records = retained.memory.retrieve() + print(f"\nAgent retained with {len(records)} memory records:") + for record in records: + role = record.memory_record.role_at_backend + content = record.memory_record.message.content[:70] + print(f" [{role}] {content}...") + + # ========== SECOND ATTEMPT (REAL API) ========== + print("\n" + "-" * 70) + print("[SECOND ATTEMPT - Real API Call]") + print("-" * 70) + + # Note: In real usage with Workforce, failure_count is auto-incremented. + # We set it manually here because we're calling _process_task() directly. + task.failure_count = 1 + + print("\nCalling real API with the same agent (history preserved)...") + print("The agent will receive:") + print(" 1. Its existing conversation history (partial work)") + print(" 2. A retry message explaining the failure\n") + + state = await worker._process_task(task, []) + + print(f"\nResult: {state}") + print(f"Task result:\n{task.result}") + + # ========== VERIFY ========== + print("\n" + "-" * 70) + print("[VERIFICATION]") + print("-" * 70) + + # Check agent's memory for retained content + memory_contents = [ + r.memory_record.message.content for r in agent.memory.retrieve() + ] + + has_previous = any( + "Feynman" in c or "Progress" in c for c in memory_contents + ) + has_retry = any("Retry attempt" in c for c in memory_contents) + + print(f" Previous context preserved: {'YES' if has_previous else 'NO'}") + print(f" Retry message injected: {'YES' if has_retry else 'NO'}") + print( + " Agent retained after success: " + f"{'NO' if task.id not in worker._failed_task_agents else 'YES'}" + ) + + print("\n" + "=" * 70) + print("Demo completed!") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test/workforce/test_breakpoint_resume.py b/test/workforce/test_breakpoint_resume.py new file mode 100644 index 0000000000..f312f13875 --- /dev/null +++ b/test/workforce/test_breakpoint_resume.py @@ -0,0 +1,284 @@ +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. ========= +from unittest.mock import MagicMock + +import pytest + +from camel.agents.chat_agent import ChatAgent +from camel.messages.base import BaseMessage +from camel.societies.workforce.single_agent_worker import SingleAgentWorker +from camel.societies.workforce.utils import TaskResult +from camel.tasks.task import Task, TaskState +from camel.types import OpenAIBackendRole + + +@pytest.mark.asyncio +async def test_breakpoint_resume_reuses_agent_instance(): + """Verify the worker retains and reuses the same agent on retry.""" + sys_msg = BaseMessage.make_assistant_message( + role_name="tester", + content="You are a test agent.", + ) + + agent = ChatAgent(sys_msg) + call_count = 0 + + async def astep_toggle(*_args, **_kwargs): + nonlocal call_count + call_count += 1 + # Add some context during the first call + if call_count == 1: + agent.update_memory( + BaseMessage.make_user_message( + role_name="user", content="first attempt context" + ), + OpenAIBackendRole.USER, + ) + response = MagicMock() + if call_count == 1: + response.msg = MagicMock( + parsed=TaskResult(content="failed", failed=True), + content='{"content":"failed","failed":true}', + ) + else: + response.msg = MagicMock( + parsed=TaskResult(content="ok", failed=False), + content='{"content":"ok","failed":false}', + ) + response.info = {} + return response + + agent.astep = astep_toggle + + worker = SingleAgentWorker( + "stub", + agent, + use_agent_pool=False, + enable_breakpoint_resume=True, + use_structured_output_handler=False, + ) + + # Override to provide our test agent initially + first_call = True + + async def get_agent(): + nonlocal first_call + if first_call: + first_call = False + return agent + # Should not be called on retry when breakpoint resume is enabled + raise AssertionError( + "Should reuse retained agent, not request a new one" + ) + + async def return_agent(a): + pass + + worker._get_worker_agent = get_agent + worker._return_worker_agent = return_agent + + task = Task(content="do something", id="task1") + + # First attempt fails + state_first = await worker._process_task(task, []) + assert state_first == TaskState.FAILED + + # Agent should be retained in _failed_task_agents + assert "task1" in worker._failed_task_agents + + # Second attempt succeeds (reuses the same agent) + task.failure_count = 1 + state_second = await worker._process_task(task, []) + assert state_second == TaskState.DONE + + # Agent should be cleaned up after success + assert "task1" not in worker._failed_task_agents + + # Verify first attempt context is still in the agent's memory + restored_contents = [ + record.memory_record.message.content + for record in agent.memory.retrieve() + ] + assert "first attempt context" in restored_contents + + # Verify retry context message was injected + retry_messages = [ + record + for record in agent.memory.retrieve() + if (record.memory_record.message.meta_dict or {}).get("type") + == "retry_context" + ] + assert len(retry_messages) == 1 + + +@pytest.mark.asyncio +async def test_retry_context_not_duplicated_on_multiple_retries(): + """Verify retry_context messages don't accumulate across retries.""" + sys_msg = BaseMessage.make_assistant_message( + role_name="tester", + content="You are a test agent.", + ) + + agent = ChatAgent(sys_msg) + call_count = 0 + + async def astep_toggle(*_args, **_kwargs): + nonlocal call_count + call_count += 1 + response = MagicMock() + if call_count <= 2: + response.msg = MagicMock( + parsed=TaskResult(content="failed", failed=True), + content='{"content":"failed","failed":true}', + ) + else: + response.msg = MagicMock( + parsed=TaskResult(content="ok", failed=False), + content='{"content":"ok","failed":false}', + ) + response.info = {} + return response + + agent.astep = astep_toggle + + worker = SingleAgentWorker( + "stub", + agent, + use_agent_pool=False, + enable_breakpoint_resume=True, + use_structured_output_handler=False, + ) + + first_call = True + + async def get_agent(): + nonlocal first_call + if first_call: + first_call = False + return agent + raise AssertionError( + "Should reuse retained agent, not request a new one" + ) + + async def return_agent(a): + pass + + worker._get_worker_agent = get_agent + worker._return_worker_agent = return_agent + + task = Task(content="do something", id="task1") + + # First attempt fails + await worker._process_task(task, []) + assert "task1" in worker._failed_task_agents + + # Second attempt fails + task.failure_count = 1 + await worker._process_task(task, []) + assert "task1" in worker._failed_task_agents + + # Third attempt succeeds + task.failure_count = 2 + state = await worker._process_task(task, []) + assert state == TaskState.DONE + + # Check that retry_context messages exist (one per retry) + retry_messages = [ + record + for record in agent.memory.retrieve() + if (record.memory_record.message.meta_dict or {}).get("type") + == "retry_context" + ] + # Two retries = two retry context messages + assert len(retry_messages) == 2 + + +@pytest.mark.asyncio +async def test_breakpoint_resume_disabled(): + """Verify agent is not retained when feature is disabled.""" + sys_msg = BaseMessage.make_assistant_message( + role_name="tester", + content="You are a test agent.", + ) + + agent_first = ChatAgent(sys_msg) + agent_second = ChatAgent(sys_msg) + + async def astep_fail(*_args, **_kwargs): + agent_first.update_memory( + BaseMessage.make_user_message( + role_name="user", content="first attempt context" + ), + OpenAIBackendRole.USER, + ) + response = MagicMock() + response.msg = MagicMock( + parsed=TaskResult(content="failed", failed=True), + content='{"content":"failed","failed":true}', + ) + response.info = {} + return response + + async def astep_success(*_args, **_kwargs): + response = MagicMock() + response.msg = MagicMock( + parsed=TaskResult(content="ok", failed=False), + content='{"content":"ok","failed":false}', + ) + response.info = {} + return response + + agent_first.astep = astep_fail + agent_second.astep = astep_success + + # Create worker with breakpoint resume DISABLED + worker = SingleAgentWorker( + "stub", + agent_first, + use_agent_pool=False, + enable_breakpoint_resume=False, + use_structured_output_handler=False, + ) + # Override to use our test agents + agents = [agent_first, agent_second] + + async def get_agent(): + return agents.pop(0) + + async def return_agent(agent): + pass + + worker._get_worker_agent = get_agent + worker._return_worker_agent = return_agent + + task = Task(content="do something", id="task1") + + # First attempt fails + state_first = await worker._process_task(task, []) + assert state_first == TaskState.FAILED + + # No agent should be retained + assert "task1" not in worker._failed_task_agents + + # Second attempt succeeds (uses a fresh agent) + task.failure_count = 1 + state_second = await worker._process_task(task, []) + assert state_second == TaskState.DONE + + # Verify "first attempt context" was NOT in agent_second + restored_contents = [ + record.memory_record.message.content + for record in agent_second.memory.retrieve() + ] + assert "first attempt context" not in restored_contents