Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 180 additions & 67 deletions camel/societies/workforce/single_agent_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import datetime
import time
from collections import deque
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict

from colorama import Fore

from camel.agents import ChatAgent
from camel.agents.chat_agent import AsyncStreamingChatAgentResponse
from camel.logger import get_logger
from camel.memories.records import MemoryRecord
from camel.messages.base import BaseMessage
from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
from camel.societies.workforce.structured_output_handler import (
StructuredOutputHandler,
Expand All @@ -34,11 +36,25 @@
WorkflowMemoryManager,
)
from camel.tasks.task import Task, TaskState, is_task_result_insufficient
from camel.types import OpenAIBackendRole
from camel.utils.context_utils import ContextUtility

logger = get_logger(__name__)


class ExecutionContext(TypedDict, total=False):
r"""Type definition for task execution context used in breakpoint resume.

Attributes:
conversation_history: Serialized conversation history from previous
execution attempts.
retry_context: Message describing the failure reason for retry.
"""

conversation_history: List[Dict[str, Any]]
retry_context: str


class AgentPool:
r"""A pool of agent instances for efficient reuse.

Expand Down Expand Up @@ -217,6 +233,11 @@ class SingleAgentWorker(Worker):
conversations from all task executions are accumulated for
potential workflow saving. Set to True if you plan to call
save_workflow_memories(). (default: :obj:`False`)
enable_breakpoint_resume (bool, optional): Whether to preserve and
restore conversation history for retry attempts. When enabled,
failed tasks save their conversation history to
``task.additional_info`` and reuse it on retry. (default:
:obj:`True`)
"""

def __init__(
Expand All @@ -230,6 +251,7 @@ def __init__(
use_structured_output_handler: bool = True,
context_utility: Optional[ContextUtility] = None,
enable_workflow_memory: bool = False,
enable_breakpoint_resume: bool = True,
) -> None:
node_id = worker.agent_id
super().__init__(
Expand All @@ -245,6 +267,7 @@ def __init__(
self.worker = worker
self.use_agent_pool = use_agent_pool
self.enable_workflow_memory = enable_workflow_memory
self.enable_breakpoint_resume = enable_breakpoint_resume
self._shared_context_utility = context_utility
self._context_utility: Optional[ContextUtility] = (
None # Will be initialized when needed
Expand All @@ -271,6 +294,68 @@ def __init__(
auto_scale=auto_scale_pool,
)

def _ensure_execution_context(self, task: Task) -> ExecutionContext:
r"""Ensure task.execution_context is a mutable dict."""
if task.execution_context is None or not isinstance(
task.execution_context, dict
):
task.execution_context = {}
return task.execution_context # type: ignore[return-value]

def _serialize_conversation_history(
self, agent: ChatAgent
) -> List[Dict[str, Any]]:
r"""Serialize agent's conversation history for storage.

Filters out retry_context messages to avoid duplication on subsequent
retries.
"""
records = agent.memory.retrieve()
result = []
for record in records:
meta = record.memory_record.message.meta_dict or {}
if meta.get("type") == "retry_context":
continue
result.append(record.memory_record.to_dict())
return result

def _restore_conversation_history(
self, agent: ChatAgent, history: List[Dict[str, Any]]
) -> None:
r"""Restore conversation history to agent, skipping system messages."""
for record_dict in history:
record = MemoryRecord.from_dict(record_dict)

if record.role_at_backend == OpenAIBackendRole.SYSTEM:
continue

restored_record = MemoryRecord(
message=record.message,
role_at_backend=record.role_at_backend,
timestamp=record.timestamp,
agent_id=agent.agent_id,
extra_info=record.extra_info,
)
agent.memory.write_record(restored_record)

def _save_breakpoint_history(self, task: Task, agent: ChatAgent) -> None:
r"""Save conversation history into task.execution_context."""
execution_context = self._ensure_execution_context(task)
execution_context["conversation_history"] = (
self._serialize_conversation_history(agent)
)

def _build_retry_message(self, task: Task) -> str:
r"""Build a concise retry message describing the last failure."""
failure_reason = (task.result or "Unknown error").strip()
if len(failure_reason) > 500:
failure_reason = f"{failure_reason[:500]}..."
return (
f"Retry attempt {task.failure_count + 1}. "
f"Previous attempt failed with: {failure_reason}. "
"Continue from the previous context and address the failure."
)

def reset(self) -> Any:
r"""Resets the worker to its initial state."""
super().reset()
Expand Down Expand Up @@ -359,6 +444,23 @@ async def _process_task(
response_content = ""

try:
if self.enable_breakpoint_resume and task.failure_count > 0:
execution_context = self._ensure_execution_context(task)
history = execution_context.get("conversation_history")
if history:
self._restore_conversation_history(worker_agent, history)

retry_message = self._build_retry_message(task)
execution_context["retry_context"] = retry_message
worker_agent.update_memory(
BaseMessage.make_user_message(
role_name="user",
content=retry_message,
meta_dict={"type": "retry_context"},
),
OpenAIBackendRole.USER,
)

dependency_tasks_info = self._get_dep_tasks_info(dependencies)
prompt = str(
PROCESS_TASK_PROMPT.format(
Expand Down Expand Up @@ -489,81 +591,92 @@ async def _process_task(
)
# Store error information in task result
task.result = f"{type(e).__name__}: {e!s}"
return TaskState.FAILED
finally:
# Return agent to pool or let it be garbage collected
if self.enable_breakpoint_resume:
self._save_breakpoint_history(task, worker_agent)
await self._return_worker_agent(worker_agent)
return TaskState.FAILED

# Populate additional_info with worker attempt details
if task.additional_info is None:
task.additional_info = {}

# Create worker attempt details with descriptive keys
worker_attempt_details = {
"agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"original_worker_id": getattr(
self.worker, "agent_id", self.worker.role_name
),
"timestamp": str(datetime.datetime.now()),
"description": f"Attempt by "
f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
f"(from pool/clone of "
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
f"to process task: {task.content}",
"response_content": response_content[:50],
"tool_calls": str(
final_response.info.get("tool_calls")
if isinstance(response, AsyncStreamingChatAgentResponse)
else response.info.get("tool_calls")
)[:50],
"total_tokens": total_tokens,
}

# Store the worker attempt in additional_info
if "worker_attempts" not in task.additional_info:
task.additional_info["worker_attempts"] = []
task.additional_info["worker_attempts"].append(worker_attempt_details)
try:
# Populate additional_info with worker attempt details
if task.additional_info is None:
task.additional_info = {}

# Create worker attempt details with descriptive keys
worker_attempt_details = {
"agent_id": getattr(
worker_agent, "agent_id", worker_agent.role_name
),
"original_worker_id": getattr(
self.worker, "agent_id", self.worker.role_name
),
"timestamp": str(datetime.datetime.now()),
"description": f"Attempt by "
f"{getattr(worker_agent, 'agent_id', worker_agent.role_name)} "
f"(from pool/clone of "
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
f"to process task: {task.content}",
"response_content": response_content[:50],
"tool_calls": str(
final_response.info.get("tool_calls")
if isinstance(response, AsyncStreamingChatAgentResponse)
else response.info.get("tool_calls")
)[:50],
"total_tokens": total_tokens,
}

# Store the worker attempt in additional_info
if "worker_attempts" not in task.additional_info:
task.additional_info["worker_attempts"] = []
task.additional_info["worker_attempts"].append(
worker_attempt_details
)

# Store the actual token usage for this specific task
task.additional_info["token_usage"] = {"total_tokens": total_tokens}
# Store the actual token usage for this specific task
task.additional_info["token_usage"] = {
"total_tokens": total_tokens
}

print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}")
logger.info(f"Response from {self}:")
print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}")
logger.info(f"Response from {self}:")

if not self.use_structured_output_handler:
# Handle native structured output parsing
if task_result is None:
logger.error(
"Error in worker step execution: Invalid task result"
)
task_result = TaskResult(
content="Failed to generate valid task result.",
failed=True,
)
if not self.use_structured_output_handler:
# Handle native structured output parsing
if task_result is None:
logger.error(
"Error in worker step execution: Invalid task result"
)
task_result = TaskResult(
content="Failed to generate valid task result.",
failed=True,
)

color = Fore.RED if task_result.failed else Fore.GREEN # type: ignore[union-attr]
print(
f"\n{color}{task_result.content}{Fore.RESET}\n======", # type: ignore[union-attr]
)
if task_result.failed: # type: ignore[union-attr]
logger.error(f"{task_result.content}") # type: ignore[union-attr]
else:
logger.info(f"{task_result.content}") # type: ignore[union-attr]
color = Fore.RED if task_result.failed else Fore.GREEN # type: ignore[union-attr]
print(
f"\n{color}{task_result.content}{Fore.RESET}\n======", # type: ignore[union-attr]
)
if task_result.failed: # type: ignore[union-attr]
logger.error(f"{task_result.content}") # type: ignore[union-attr]
else:
logger.info(f"{task_result.content}") # type: ignore[union-attr]

task.result = task_result.content # type: ignore[union-attr]
task.result = task_result.content # type: ignore[union-attr]

if task_result.failed: # type: ignore[union-attr]
return TaskState.FAILED
if task_result.failed: # type: ignore[union-attr]
if self.enable_breakpoint_resume:
self._save_breakpoint_history(task, worker_agent)
return TaskState.FAILED

if is_task_result_insufficient(task):
logger.warning(
f"Task {task.id}: Content validation failed - "
f"task marked as failed"
)
return TaskState.FAILED
return TaskState.DONE
if is_task_result_insufficient(task):
logger.warning(
f"Task {task.id}: Content validation failed - "
f"task marked as failed"
)
if self.enable_breakpoint_resume:
self._save_breakpoint_history(task, worker_agent)
return TaskState.FAILED
return TaskState.DONE
finally:
await self._return_worker_agent(worker_agent)

async def _listen_to_channel(self):
r"""Override to start cleanup task when pool is enabled."""
Expand Down
6 changes: 6 additions & 0 deletions camel/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ class Task(BaseModel):
(default: :obj:`[]`)
additional_info (Optional[Dict[str, Any]]): Additional information for
the task. (default: :obj:`None`)
execution_context (Optional[Dict[str, Any]]): Internal execution state
used by workers for features like breakpoint resume. This stores
conversation history and retry context to allow failed tasks to
resume from where they left off. (default: :obj:`None`)
image_list (Optional[List[Union[Image.Image, str]]]): Optional list
of PIL Image objects or image URLs (strings) associated with the
task. (default: :obj:`None`)
Expand Down Expand Up @@ -274,6 +278,8 @@ class Task(BaseModel):

additional_info: Optional[Dict[str, Any]] = None

execution_context: Optional[Dict[str, Any]] = None

image_list: Optional[List[Union[Image.Image, str]]] = None

image_detail: Literal["auto", "low", "high"] = "auto"
Expand Down
Loading
Loading