diff --git a/camel/societies/workforce/prompts.py b/camel/societies/workforce/prompts.py index 7fff4acdd3..42dbf1fca3 100644 --- a/camel/societies/workforce/prompts.py +++ b/camel/societies/workforce/prompts.py @@ -375,6 +375,82 @@ "modified_task_content": "new content if replan, else null" }""" +VALIDATION_PROMPT = TextPrompt( + """You are validating the aggregated results from subtasks to ensure they meet the original task requirements. + +**ORIGINAL TASK INFORMATION:** +- Task ID: {task_id} +- Task Content: {task_content} +- Number of Subtasks: {num_subtasks} +- Subtask IDs: {subtask_ids} + +**AGGREGATED RESULT:** +{aggregated_result} + +**YOUR RESPONSIBILITIES:** + +1. **Deduplication**: Identify and remove EXACT duplicate content across subtask results + - A duplicate means the SAME item appears multiple times (e.g., same paper title, same product name, same entity) + - Do NOT consider items as duplicates just because they are similar, related, or in the same category + - Only mark as duplicate if they are truly IDENTICAL (same title/name/identifier) + - Examples: + * DUPLICATE: Two subtasks both return "Paper: Attention Is All You Need" + * NOT DUPLICATE: "Attention Is All You Need" and "BERT: Pre-training of Deep Bidirectional Transformers" (different papers, both about NLP) + * NOT DUPLICATE: "iPhone 14" and "iPhone 15" (different products, both phones) + - Count unique items vs actual duplicates + +2. **Requirement Validation**: Check if the deduplicated results meet the original task requirements + - Extract any numerical requirements (e.g., "5 papers", "3 examples") + - Verify the unique count EXACTLY matches the requirements (not approximately) + - Set requirements_met to FALSE if even one item is missing + - Identify how many items are missing if requirements not met + +3. **Quality Assessment**: Ensure the deduplicated result is coherent and complete + - Verify all unique items are valid and relevant + - Check for completeness and accuracy + +**RESPONSE REQUIREMENTS:** + +- **requirements_met**: MUST be false if unique_count does not EXACTLY match the required number +- **unique_count**: Total number of DISTINCT items after removing EXACT duplicates (items with different titles/names/identifiers are NOT duplicates) +- **duplicate_count**: Number of EXACT duplicate items removed (same title/name appearing multiple times) +- **missing_count**: If requirements not met, how many items are still needed (0 if met) +- **deduplicated_result**: The cleaned result with EXACT duplicates removed and content merged +- **reasoning**: Clear explanation of your validation decision (2-3 sentences). If marking items as duplicates, explicitly state which items are IDENTICAL. +- **additional_task_guidance**: If requirements not met, provide guidance for individual subtask refinement. Express guidance in SINGULAR form (what ONE subtask should do to find ONE unique item), list all items to avoid/exclude, AND suggest exploring diverse areas (different domains, time periods, methodologies, or approaches) to maximize chances of finding different unique items. The system will retry only missing_count subtasks, selecting from duplicate_subtask_ids. +- **duplicate_subtask_ids**: List ONLY the subtask IDs that should be RETRIED due to producing duplicate results. For each set of duplicates, keep ONE result and list the OTHER task IDs for retry. 
+    * Example: If task_1.3, task_1.5, and task_1.7 all returned "Paper A", keep one (e.g., task_1.3) and list the others: ["task_1.5", "task_1.7"]
+    * Example: If task_1.2 and task_1.4 both returned "Paper B", keep one and list the other: ["task_1.4"]
+    * Only include tasks that need to be retried because they produced duplicate content
+    * Do NOT include task IDs that returned unique, distinct results
+    * Use null if NO exact duplicates found (all results are distinct)
+
+**RESPONSE FORMAT:**
+{response_format}
+
+**CRITICAL**:
+- Return ONLY a valid JSON object
+- Be thorough in identifying EXACT duplicates (same title/name/identifier)
+- Do NOT mark similar or related items as duplicates
+- duplicate_subtask_ids should ONLY contain task IDs to retry (exclude one kept result per duplicate set)
+- If all results are distinct, set duplicate_subtask_ids to null
+- Provide specific guidance if additional tasks are needed
+- Ensure all required fields are included
+"""
+)
+
+VALIDATION_RESPONSE_FORMAT = """JSON format:
+{
+    "requirements_met": true|false,
+    "unique_count": number,
+    "duplicate_count": number,
+    "missing_count": number,
+    "deduplicated_result": "cleaned content",
+    "reasoning": "explanation (2-3 sentences)",
+    "additional_task_guidance": "guidance string or null",
+    "duplicate_subtask_ids": ["subtask_id1", "subtask_id2"] or null
+}"""
+
 # Strategy descriptions for dynamic prompt generation
 STRATEGY_DESCRIPTIONS = {
     "retry": """**retry** - Retry with the same worker and task content
diff --git a/camel/societies/workforce/structured_output_handler.py b/camel/societies/workforce/structured_output_handler.py
index d641a19a7e..b779261da4 100644
--- a/camel/societies/workforce/structured_output_handler.py
+++ b/camel/societies/workforce/structured_output_handler.py
@@ -22,6 +22,7 @@
     RecoveryStrategy,
     TaskAnalysisResult,
     TaskAssignResult,
+    ValidationResult,
     WorkerConf,
 )
 
@@ -391,6 +392,7 @@ def _fix_common_issues(
                 'decompose',
                 'create_worker',
                 'reassign',
+                'refine',
             ]
             if strategy not in valid_strategies:
                 # Try to match partial
@@ -399,6 +401,18 @@ def _fix_common_issues(
                         fixed_data['recovery_strategy'] = valid
                         break
 
+        elif schema_name == 'ValidationResult':
+            if 'deduplicated_result' in fixed_data:
+                if not isinstance(fixed_data['deduplicated_result'], str):
+                    try:
+                        fixed_data['deduplicated_result'] = json.dumps(
+                            fixed_data['deduplicated_result'], indent=2
+                        )
+                    except (TypeError, ValueError):
+                        fixed_data['deduplicated_result'] = str(
+                            fixed_data['deduplicated_result']
+                        )
+
         return fixed_data
 
     @staticmethod
@@ -428,6 +442,19 @@ def _create_default_instance(schema: Type[BaseModel]) -> BaseModel:
                 recovery_strategy=RecoveryStrategy.RETRY,
                 modified_task_content=None,
             )
+        elif schema_name == 'ValidationResult':
+            return ValidationResult(
+                requirements_met=False,
+                unique_count=0,
+                duplicate_count=0,
+                missing_count=0,
+                deduplicated_result="",
+                reasoning=(
+                    "Default validation result due to parsing error - "
+                    "failing safe"
+                ),
+                additional_task_guidance=None,
+            )
         else:
             # Try to create with empty dict and let defaults handle it
             return schema()
@@ -502,6 +529,21 @@ def create_fallback_response(
                 modified_task_content=None,
             )
 
+        elif schema_name == 'ValidationResult':
+            # Return fallback validation result - fail-safe approach
+            return ValidationResult(
+                requirements_met=False,
+                unique_count=0,
+                duplicate_count=0,
+                missing_count=0,
+                deduplicated_result="",
+                reasoning=(
+                    f"Fallback validation result (failing safe): "
+                    f"{error_message}"
+                ),
+
additional_task_guidance=None, + ) + else: # Generic fallback try: diff --git a/camel/societies/workforce/utils.py b/camel/societies/workforce/utils.py index 048296928d..c1c7d1df40 100644 --- a/camel/societies/workforce/utils.py +++ b/camel/societies/workforce/utils.py @@ -216,6 +216,7 @@ class RecoveryStrategy(str, Enum): DECOMPOSE = "decompose" CREATE_WORKER = "create_worker" REASSIGN = "reassign" + REFINE = "refine" def __str__(self): return self.value @@ -417,6 +418,57 @@ def quality_sufficient(self) -> bool: ) +class ValidationResult(BaseModel): + r"""Result of validating aggregated parallel task results. + + This model is used to validate results from parallel subtasks that were + "scattered" across multiple agents, checking for duplicates and ensuring + the final result meets the original requirements. + """ + + requirements_met: bool = Field( + description="Whether the aggregated results meet the original task " + "requirements (e.g., '5 unique papers found')" + ) + + unique_count: int = Field( + description="Number of unique items found after deduplication" + ) + + duplicate_count: int = Field( + default=0, + description="Number of duplicate items that were removed", + ) + + missing_count: int = Field( + default=0, + description="Number of items still needed to meet requirements " + "(e.g., if 5 required but only 3 found, missing_count=2)", + ) + + deduplicated_result: str = Field( + description="The cleaned, deduplicated result content" + ) + + reasoning: str = Field( + description="Explanation of the validation decision and any issues " + "found" + ) + + additional_task_guidance: Optional[str] = Field( + default=None, + description="If requirements not met, guidance for generating " + "additional targeted subtasks (e.g., 'Find 2 more papers, " + "excluding: [list]')", + ) + + duplicate_subtask_ids: Optional[List[str]] = Field( + default=None, + description="List of subtask IDs that produced duplicate results. " + "These subtasks should be retried with refinement guidance.", + ) + + class PipelineTaskBuilder: r"""Helper class for building pipeline tasks with dependencies.""" diff --git a/camel/societies/workforce/workforce.py b/camel/societies/workforce/workforce.py index 1b8f2d4337..a06589864c 100644 --- a/camel/societies/workforce/workforce.py +++ b/camel/societies/workforce/workforce.py @@ -60,6 +60,8 @@ TASK_AGENT_SYSTEM_MESSAGE, TASK_ANALYSIS_PROMPT, TASK_DECOMPOSE_PROMPT, + VALIDATION_PROMPT, + VALIDATION_RESPONSE_FORMAT, ) from camel.societies.workforce.role_playing_worker import RolePlayingWorker from camel.societies.workforce.single_agent_worker import ( @@ -76,6 +78,7 @@ TaskAnalysisResult, TaskAssignment, TaskAssignResult, + ValidationResult, WorkerConf, check_if_running, ) @@ -227,6 +230,14 @@ class Workforce(BaseNode): support native structured output. When disabled, the workforce uses the native response_format parameter. (default: :obj:`True`) + max_refinement_iterations (int, optional): Maximum number of + iterative refinement attempts for parallel task results. When + parallel subtasks produce duplicate content or don't meet + requirements, the workforce will automatically validate and + generate additional targeted subtasks to fill gaps. This + parameter limits the refinement loop to prevent excessive + iterations. Set to 0 to disable refinement validation. 
+ (default: :obj:`2`) callbacks (Optional[List[WorkforceCallback]], optional): A list of callback handlers to observe and record workforce lifecycle events and metrics (e.g., task creation/assignment/start/completion/ @@ -315,6 +326,7 @@ def __init__( share_memory: bool = False, use_structured_output_handler: bool = True, task_timeout_seconds: Optional[float] = None, + max_refinement_iterations: int = 2, mode: WorkforceMode = WorkforceMode.AUTO_DECOMPOSE, callbacks: Optional[List[WorkforceCallback]] = None, failure_handling_config: Optional[ @@ -333,6 +345,7 @@ def __init__( self.task_timeout_seconds = ( task_timeout_seconds or TASK_TIMEOUT_SECONDS ) + self.max_refinement_iterations = max_refinement_iterations self.mode = mode self._initial_mode = mode # Store initial mode for reset() # Initialize failure handling configuration (supports dict input) @@ -1603,6 +1616,147 @@ def _analyze_task( ) return TaskAnalysisResult(**fallback_values) + def _validate_aggregated_result( + self, task: Task, aggregated_result: str + ) -> ValidationResult: + r"""Validate aggregated results from parallel subtasks. + + This method uses the Task Planner Agent to: + 1. Deduplicate content across parallel subtask results + 2. Verify requirements are met (e.g., "5 unique papers") + 3. Provide guidance for additional tasks if needed + + Args: + task (Task): The parent task containing subtasks + aggregated_result (str): The concatenated results from all subtasks + + Returns: + ValidationResult: Validation result with deduplication and + requirement checking + """ + num_subtasks = len(task.subtasks) if task.subtasks else 0 + + subtask_ids = ( + [sub.id for sub in task.subtasks] if task.subtasks else [] + ) + subtask_ids_str = ", ".join(subtask_ids) if subtask_ids else "None" + + validation_prompt = str( + VALIDATION_PROMPT.format( + task_id=task.id, + task_content=task.content, + num_subtasks=num_subtasks, + subtask_ids=subtask_ids_str, + aggregated_result=aggregated_result, + response_format=VALIDATION_RESPONSE_FORMAT, + ) + ) + + examples = [ + { + "requirements_met": True, + "unique_count": 5, + "duplicate_count": 2, + "missing_count": 0, + "deduplicated_result": "Deduplicated content here...", + "reasoning": ( + "Found 5 unique papers. Task_1.3 and task_1.5 both " + "returned 'Attention Is All You Need' (keeping " + "task_1.3's result). Requirements fully met after " + "deduplication." + ), + "additional_task_guidance": None, + "duplicate_subtask_ids": ["task_1.5"], + }, + { + "requirements_met": False, + "unique_count": 3, + "duplicate_count": 4, + "missing_count": 2, + "deduplicated_result": "3 unique papers found...", + "reasoning": ( + "Only 3 unique papers found. Task_1.1 and task_1.2 " + "both returned 'Paper A' (kept task_1.1). Task_1.4, " + "task_1.5, task_1.6 all returned 'Paper B' (kept " + "task_1.4). Need 2 more unique papers." + ), + "additional_task_guidance": ( + "Find ONE unique research paper on the topic, avoiding: " + "Paper A, Paper B, Paper C. To ensure diversity across " + "parallel retries, consider exploring different " + "publication years, research methodologies, or " + "application domains." + ), + "duplicate_subtask_ids": [ + "task_1.2", + "task_1.5", + "task_1.6", + ], + }, + { + "requirements_met": True, + "unique_count": 5, + "duplicate_count": 0, + "missing_count": 0, + "deduplicated_result": "All 5 papers are distinct...", + "reasoning": ( + "Found 5 distinct papers with different titles. " + "All papers are unique, no exact duplicates found. " + "Requirements fully met." 
+                ),
+                "additional_task_guidance": None,
+                "duplicate_subtask_ids": None,
+            },
+        ]
+
+        try:
+            if self.use_structured_output_handler:
+                enhanced_prompt = (
+                    self.structured_handler.generate_structured_prompt(
+                        base_prompt=validation_prompt,
+                        schema=ValidationResult,
+                        examples=examples,
+                    )
+                )
+                response = self.task_agent.step(enhanced_prompt)
+
+                result = self.structured_handler.parse_structured_response(
+                    response.msg.content if response.msg else "",
+                    schema=ValidationResult,
+                    fallback_values=self.structured_handler.create_fallback_response(
+                        ValidationResult, "failed to parse structured response"
+                    ).model_dump(),
+                )
+
+                if isinstance(result, ValidationResult):
+                    return result
+                elif isinstance(result, dict):
+                    return ValidationResult.model_validate(result)
+                else:
+                    fallback = (
+                        self.structured_handler.create_fallback_response(
+                            ValidationResult,
+                            "failed to create ValidationResult instance",
+                        )
+                    )
+                    return ValidationResult.model_validate(fallback)
+            else:
+                response = self.task_agent.step(
+                    validation_prompt, response_format=ValidationResult
+                )
+                return response.msg.parsed
+
+        except Exception as e:
+            logger.warning(
+                f"Error during validation for task {task.id}: {e}, "
+                f"using fallback"
+            )
+            return ValidationResult.model_validate(
+                self.structured_handler.create_fallback_response(
+                    ValidationResult, str(e)
+                )
+            )
+
     async def _apply_recovery_strategy(
         self,
         task: Task,
@@ -1661,8 +1815,8 @@ async def _apply_recovery_strategy(
             # Modify the task content and retry
             if recovery_decision.modified_task_content:
                 task.content = recovery_decision.modified_task_content
-                logger.info(f"Task {task.id} content modified for replan")
 
+                logger.info(f"Task {task.id} content modified for replan")
             # Repost the modified task
             if task.id in self._assignees:
                 assignee_id = self._assignees[task.id]
@@ -1774,6 +1928,117 @@ async def _apply_recovery_strategy(
             # For decompose, we return early with special handling
             return True
 
+        elif strategy == RecoveryStrategy.REFINE:
+            logger.info(
+                f"Task {task.id} will be refined with additional "
+                f"targeted subtasks"
+            )
+
+            # Access ValidationResult object directly instead of dict
+            validation_result = None
+            if task.additional_info:
+                validation_result = task.additional_info.get(
+                    'validation_result'
+                )
+
+            # Extract values from ValidationResult, with safe defaults
+            if validation_result and isinstance(
+                validation_result, ValidationResult
+            ):
+                additional_guidance = (
+                    validation_result.additional_task_guidance
+                )
+                duplicate_ids = (
+                    validation_result.duplicate_subtask_ids or []
+                )
+
+                # Always bind the list so the loop below is safe even
+                # when no duplicate IDs were reported
+                subtasks_to_retry = [
+                    sub for sub in task.subtasks if sub.id in duplicate_ids
+                ]
+                if subtasks_to_retry:
+                    logger.info(
+                        f"{len(subtasks_to_retry)} duplicates to retry"
+                    )
+
+                refinement_iteration = (task.additional_info or {}).get(
+                    'refinement_iteration', 0
+                )
+
+                for subtask in subtasks_to_retry:
+                    subtask.result = None
+                    subtask.state = TaskState.OPEN
+                    subtask.failure_count = 0
+
+                    if not subtask.additional_info:
+                        subtask.additional_info = {}
+
+                    if 'base_content' not in subtask.additional_info:
+                        subtask.additional_info['base_content'] = (
+                            subtask.content
+                        )
+
+                    subtask.additional_info['refinement_subtask'] = True
+                    subtask.additional_info['refinement_iteration'] = (
+                        refinement_iteration + 1
+                    )
+
+                    if additional_guidance:
+                        subtask.additional_info['refinement_guidance'] = (
+                            additional_guidance
+                        )
+                        iteration_num = refinement_iteration + 1
+                        subtask.content = (
+                            f"{subtask.additional_info['base_content']}\n\n"
+                            f"IMPORTANT - REFINEMENT 
ITERATION " + f"{iteration_num}:\n" + f"{additional_guidance}\n\n" + f"Your previous result was a duplicate. You MUST " + f"find a completely DIFFERENT and UNIQUE item " + f"this time." + ) + + subtasks = subtasks_to_retry + + if subtasks: + if not task.additional_info: + task.additional_info = {} + task.additional_info['refinement_iteration'] = ( + refinement_iteration + 1 + ) + + task.state = TaskState.OPEN + + self._pending_tasks.extendleft(reversed(subtasks)) + await self._post_ready_tasks() + action_taken = ( + f"retrying {len(subtasks)} duplicate subtasks" + ) + + logger.info( + f"Task {task.id} retrying {len(subtasks)} " + f"duplicate subtasks: {[st.id for st in subtasks]}" + ) + print( + f"{Fore.CYAN}🔄 Retrying {len(subtasks)} duplicate " + f"subtasks: {', '.join([st.id for st in subtasks])}" + f"{Fore.RESET}" + ) + + if self.share_memory: + logger.info( + f"Syncing shared memory after task {task.id} " + f"refinement" + ) + self._sync_shared_memory() + + return True + else: + logger.warning( + f"No refinement subtasks generated for {task.id}" + ) + task.state = TaskState.DONE + action_taken = "marked complete with partial results" + elif strategy == RecoveryStrategy.CREATE_WORKER: assignee = await self._create_worker_node_for_task(task) await self._post_task(task, assignee.node_id) @@ -4715,14 +4980,95 @@ async def _handle_completed_task(self, task: Task) -> None: f"{completed_subtask.result}" ) - # Set parent task state and result - parent.state = TaskState.DONE - parent.result = ( + aggregated_result = ( "\n\n".join(successful_results) if successful_results else "All subtasks completed" ) + if len(parent.subtasks) > 1 and successful_results: + logger.info( + f"Validating aggregated results for parent task " + f"{parent.id} with {len(parent.subtasks)} subtasks" + ) + validation_result = self._validate_aggregated_result( + parent, aggregated_result + ) + + logger.info( + f"Validation result for {parent.id}: " + f"requirements_met=" + f"{validation_result.requirements_met}, " + f"unique_count={validation_result.unique_count}, " + f"duplicate_count=" + f"{validation_result.duplicate_count}, " + f"missing_count={validation_result.missing_count}" + ) + parent.result = validation_result.deduplicated_result + + if not validation_result.requirements_met: + if parent.additional_info is None: + parent.additional_info = {} + refinement_iteration = parent.additional_info.get( + 'refinement_iteration', 0 + ) + + max_refinement_iterations = getattr( + self, 'max_refinement_iterations', 2 + ) + + if refinement_iteration < max_refinement_iterations: + logger.info( + f"Requirements not met for {parent.id}. " + f"Triggering refinement " + f"(iteration {refinement_iteration + 1}/" + f"{max_refinement_iterations})" + ) + + # Store ValidationResult object directly to avoid + # redundant dictionary conversions + parent.additional_info['validation_result'] = ( + validation_result + ) + + parent.state = TaskState.FAILED + parent.failure_count += 1 + + refine_decision = TaskAnalysisResult( + reasoning=validation_result.reasoning, + recovery_strategy=RecoveryStrategy.REFINE, + modified_task_content=None, + issues=[ + ( + f"Missing " + f"{validation_result.missing_count} " + f"items to meet requirements" + ) + ], + ) + + await self._apply_recovery_strategy( + parent, refine_decision + ) + return + else: + logger.warning( + f"Max refinement iterations " + f"({max_refinement_iterations}) reached for " + f"{parent.id}. Accepting partial results." 
+                        )
+                        print(
+                            f"{Fore.YELLOW}⚠️ Task {parent.id} partially "
+                            f"complete: {validation_result.unique_count} "
+                            f"items found, "
+                            f"{validation_result.missing_count} missing"
+                            f"{Fore.RESET}"
+                        )
+            else:
+                parent.result = aggregated_result
+
+            parent.state = TaskState.DONE
+
         logger.debug(
             f"All subtasks of {parent.id} are done. "
             f"Marking parent as complete."
@@ -5404,6 +5750,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
             graceful_shutdown_timeout=self.graceful_shutdown_timeout,
             share_memory=self.share_memory,
             use_structured_output_handler=self.use_structured_output_handler,
+            max_refinement_iterations=self.max_refinement_iterations,
             task_timeout_seconds=self.task_timeout_seconds,
             mode=self.mode,
             failure_handling_config=self.failure_handling_config,
diff --git a/examples/workforce/workforce_validation_refinement_example.py b/examples/workforce/workforce_validation_refinement_example.py
new file mode 100644
index 0000000000..a6616b0061
--- /dev/null
+++ b/examples/workforce/workforce_validation_refinement_example.py
@@ -0,0 +1,90 @@
+# ========= Copyright 2023-2025 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2025 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import asyncio
+
+from camel.agents import ChatAgent
+from camel.models import ModelFactory
+from camel.societies import Workforce
+from camel.tasks import Task
+from camel.toolkits import ArxivToolkit
+from camel.types import ModelPlatformType, ModelType
+
+
+async def main():
+    # Create a model instance for the workforce agents
+    model = ModelFactory.create(
+        model_platform=ModelPlatformType.OPENAI, model_type=ModelType.GPT_4_1
+    )
+    tools = ArxivToolkit().get_tools()
+
+    # Create workforce agents
+    coordinator_agent = ChatAgent(model=model)
+    task_agent = ChatAgent(model=model)
+
+    # Create worker agents - these will work in parallel
+    researcher1 = ChatAgent(
+        "You are an AI researcher", model=model, tools=tools
+    )
+    researcher2 = ChatAgent(
+        "You are an ML researcher", model=model, tools=tools
+    )
+    researcher3 = ChatAgent(
+        "You are a CS engineering researcher", model=model, tools=tools
+    )
+
+    # Initialize Workforce with validation and refinement enabled
+    workforce = Workforce(
+        description="Research Team with Validation",
+        coordinator_agent=coordinator_agent,
+        task_agent=task_agent,
+        max_refinement_iterations=2,  # Allow up to 2 refinement iterations
+    )
+    workforce.add_single_agent_worker("AI researcher", researcher1)
+    workforce.add_single_agent_worker("ML researcher", researcher2)
+    workforce.add_single_agent_worker("CS researcher", researcher3)
+
+    # Create a task that will be decomposed into parallel subtasks
+    # The task explicitly requires 5 unique papers
+    task = Task(
+        content=(
+            """
+            Find 5 unique research papers on NLP systems.
+            For each paper, provide: title, authors, year,
+            and a brief description.
+ """ + ), + id="research_task_1", + ) + + try: + # Process the task with the workforce + result = await workforce.process_task_async(task) + + print("\n" + "=" * 80) + print("FINAL RESULT") + print("=" * 80) + print("\nTask State:", result.state.value) + print("\nResult:") + print(result.result) + + except Exception as e: + print(f"\n Error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + # Run the main example + asyncio.run(main()) diff --git a/test/workforce/test_workforce_validation_refinement.py b/test/workforce/test_workforce_validation_refinement.py new file mode 100644 index 0000000000..8c4557be5c --- /dev/null +++ b/test/workforce/test_workforce_validation_refinement.py @@ -0,0 +1,552 @@ +# ========= Copyright 2023-2025 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2025 @ CAMEL-AI.org. All Rights Reserved. ========= +"""Tests for workforce validation and refinement functionality. + +This module tests the validation refinement feature that allows workforce +to validate aggregated results from parallel subtasks and refine them +if requirements are not met. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from camel.agents import ChatAgent +from camel.models import ModelFactory +from camel.societies.workforce.prompts import ( + VALIDATION_PROMPT, + VALIDATION_RESPONSE_FORMAT, +) +from camel.societies.workforce.structured_output_handler import ( + StructuredOutputHandler, +) +from camel.societies.workforce.utils import ( + RecoveryStrategy, + TaskAnalysisResult, + ValidationResult, +) +from camel.societies.workforce.workforce import Workforce +from camel.tasks.task import Task, TaskState +from camel.types import ModelPlatformType, ModelType + + +class TestValidationResult: + """Tests for the ValidationResult model.""" + + def test_validation_result_basic_creation(self): + """Test basic creation of ValidationResult.""" + result = ValidationResult( + requirements_met=True, + unique_count=5, + duplicate_count=2, + missing_count=0, + deduplicated_result="5 unique papers found", + reasoning="All requirements met after deduplication", + ) + + assert result.requirements_met is True + assert result.unique_count == 5 + assert result.duplicate_count == 2 + assert result.missing_count == 0 + assert result.deduplicated_result == "5 unique papers found" + assert "requirements met" in result.reasoning.lower() + assert result.additional_task_guidance is None + assert result.duplicate_subtask_ids is None + + def test_validation_result_with_guidance(self): + """Test ValidationResult with additional guidance for refinement.""" + result = ValidationResult( + requirements_met=False, + unique_count=3, + duplicate_count=4, + missing_count=2, + deduplicated_result="3 unique papers found", + reasoning="Only 3 unique papers found, need 2 more", + additional_task_guidance=( + "Find ONE unique research paper, excluding: Paper A, Paper B" + ), + 
duplicate_subtask_ids=["task_1.2", "task_1.5"], + ) + + assert result.requirements_met is False + assert result.unique_count == 3 + assert result.missing_count == 2 + assert result.additional_task_guidance is not None + assert "excluding" in result.additional_task_guidance + assert result.duplicate_subtask_ids == ["task_1.2", "task_1.5"] + + def test_validation_result_default_values(self): + """Test ValidationResult with default values.""" + result = ValidationResult( + requirements_met=True, + unique_count=5, + deduplicated_result="Content here", + reasoning="Validation complete", + ) + + assert result.duplicate_count == 0 + assert result.missing_count == 0 + assert result.additional_task_guidance is None + assert result.duplicate_subtask_ids is None + + def test_validation_result_empty_duplicate_ids(self): + """Test ValidationResult with empty duplicate IDs list.""" + result = ValidationResult( + requirements_met=True, + unique_count=5, + duplicate_count=0, + missing_count=0, + deduplicated_result="All unique content", + reasoning="No duplicates found", + duplicate_subtask_ids=[], + ) + + assert result.duplicate_subtask_ids == [] + + +class TestRecoveryStrategyRefine: + """Tests for the REFINE recovery strategy.""" + + def test_refine_strategy_exists(self): + """Test that REFINE is a valid recovery strategy.""" + assert hasattr(RecoveryStrategy, 'REFINE') + assert RecoveryStrategy.REFINE.value == "refine" + + def test_refine_strategy_string_conversion(self): + """Test string conversion of REFINE strategy.""" + assert str(RecoveryStrategy.REFINE) == "refine" + + def test_refine_strategy_repr(self): + """Test repr of REFINE strategy.""" + assert repr(RecoveryStrategy.REFINE) == "RecoveryStrategy.REFINE" + + def test_all_recovery_strategies(self): + """Test that all expected recovery strategies exist.""" + expected_strategies = [ + 'RETRY', + 'REPLAN', + 'DECOMPOSE', + 'CREATE_WORKER', + 'REASSIGN', + 'REFINE', + ] + for strategy in expected_strategies: + assert hasattr(RecoveryStrategy, strategy) + + +class TestValidationPrompt: + """Tests for the VALIDATION_PROMPT template.""" + + def test_validation_prompt_exists(self): + """Test that VALIDATION_PROMPT is defined.""" + assert VALIDATION_PROMPT is not None + + def test_validation_prompt_format_placeholders(self): + """Test that VALIDATION_PROMPT has required format placeholders.""" + prompt_str = str(VALIDATION_PROMPT) + assert "{task_id}" in prompt_str + assert "{task_content}" in prompt_str + assert "{num_subtasks}" in prompt_str + assert "{subtask_ids}" in prompt_str + assert "{aggregated_result}" in prompt_str + assert "{response_format}" in prompt_str + + def test_validation_prompt_can_format(self): + """Test that VALIDATION_PROMPT can be formatted.""" + formatted = VALIDATION_PROMPT.format( + task_id="test_task_1", + task_content="Find 5 papers", + num_subtasks=5, + subtask_ids="task_1.1, task_1.2, task_1.3", + aggregated_result="Paper 1, Paper 2, Paper 3", + response_format=VALIDATION_RESPONSE_FORMAT, + ) + + assert "test_task_1" in formatted + assert "Find 5 papers" in formatted + assert "task_1.1" in formatted + + def test_validation_response_format_exists(self): + """Test that VALIDATION_RESPONSE_FORMAT is defined.""" + assert VALIDATION_RESPONSE_FORMAT is not None + assert "requirements_met" in VALIDATION_RESPONSE_FORMAT + assert "unique_count" in VALIDATION_RESPONSE_FORMAT + assert "duplicate_count" in VALIDATION_RESPONSE_FORMAT + assert "deduplicated_result" in VALIDATION_RESPONSE_FORMAT + + +class 
TestStructuredOutputHandlerValidation: + """Tests for StructuredOutputHandler with ValidationResult.""" + + def test_create_default_validation_result(self): + """Test creating default ValidationResult instance.""" + handler = StructuredOutputHandler() + default = handler._create_default_instance(ValidationResult) + + assert isinstance(default, ValidationResult) + assert default.requirements_met is False + assert default.unique_count == 0 + assert default.deduplicated_result == "" + assert "parsing error" in default.reasoning.lower() + + def test_create_fallback_validation_result(self): + """Test creating fallback ValidationResult with error message.""" + handler = StructuredOutputHandler() + fallback = handler.create_fallback_response( + ValidationResult, "Test error message" + ) + + assert isinstance(fallback, ValidationResult) + assert fallback.requirements_met is False + assert "Test error message" in fallback.reasoning + + def test_fix_common_issues_validation_result(self): + """Test fixing common issues in ValidationResult data.""" + data = { + "requirements_met": True, + "unique_count": 5, + "deduplicated_result": { + "items": ["a", "b"] + }, # dict instead of str + "reasoning": "Test", + } + + fixed = StructuredOutputHandler._fix_common_issues( + data, ValidationResult + ) + + assert isinstance(fixed["deduplicated_result"], str) + + def test_parse_validation_result_from_json(self): + """Test parsing ValidationResult from JSON response.""" + handler = StructuredOutputHandler() + json_response = '''```json +{ + "requirements_met": true, + "unique_count": 5, + "duplicate_count": 0, + "missing_count": 0, + "deduplicated_result": "5 papers found", + "reasoning": "All requirements met", + "additional_task_guidance": null, + "duplicate_subtask_ids": null +} +```''' + + result = handler.parse_structured_response( + json_response, ValidationResult + ) + + assert isinstance(result, ValidationResult) + assert result.requirements_met is True + assert result.unique_count == 5 + + def test_parse_validation_result_with_fallback(self): + """Test parsing ValidationResult with fallback on invalid input.""" + handler = StructuredOutputHandler() + invalid_response = "This is not valid JSON" + + result = handler.parse_structured_response( + invalid_response, + ValidationResult, + fallback_values={ + "requirements_met": False, + "unique_count": 0, + "deduplicated_result": "fallback", + "reasoning": "fallback reason", + }, + ) + + assert isinstance(result, ValidationResult) + assert result.requirements_met is False + + +class TestTaskAnalysisResultWithRefine: + """Tests for TaskAnalysisResult with REFINE strategy.""" + + def test_task_analysis_result_with_refine_strategy(self): + """Test creating TaskAnalysisResult with REFINE strategy.""" + result = TaskAnalysisResult( + reasoning="Need to refine parallel task results", + recovery_strategy=RecoveryStrategy.REFINE, + modified_task_content=None, + issues=["Missing 2 items to meet requirements"], + ) + + assert result.recovery_strategy == RecoveryStrategy.REFINE + assert "refine" in result.reasoning.lower() + assert len(result.issues) == 1 + + +@pytest.fixture +def mock_model(): + """Create a mock model for testing.""" + return ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=ModelType.STUB, + ) + + +@pytest.fixture +def workforce_with_refinement(mock_model): + """Create a workforce with refinement enabled.""" + coordinator_agent = ChatAgent( + "You are a helpful coordinator.", model=mock_model + ) + task_agent = ChatAgent("You 
are a helpful task planner.", model=mock_model) + + workforce = Workforce( + description="Test Workforce with Refinement", + coordinator_agent=coordinator_agent, + task_agent=task_agent, + max_refinement_iterations=2, + ) + return workforce + + +class TestWorkforceRefinementInit: + """Tests for Workforce initialization with refinement parameters.""" + + def test_workforce_max_refinement_iterations_default(self, mock_model): + """Test default max_refinement_iterations value.""" + coordinator_agent = ChatAgent( + "You are a coordinator.", model=mock_model + ) + task_agent = ChatAgent("You are a planner.", model=mock_model) + + workforce = Workforce( + description="Test Workforce", + coordinator_agent=coordinator_agent, + task_agent=task_agent, + ) + + assert workforce.max_refinement_iterations == 2 + + def test_workforce_max_refinement_iterations_custom(self, mock_model): + """Test custom max_refinement_iterations value.""" + coordinator_agent = ChatAgent( + "You are a coordinator.", model=mock_model + ) + task_agent = ChatAgent("You are a planner.", model=mock_model) + + workforce = Workforce( + description="Test Workforce", + coordinator_agent=coordinator_agent, + task_agent=task_agent, + max_refinement_iterations=5, + ) + + assert workforce.max_refinement_iterations == 5 + + def test_workforce_max_refinement_iterations_disabled(self, mock_model): + """Test disabling refinement with max_refinement_iterations=0.""" + coordinator_agent = ChatAgent( + "You are a coordinator.", model=mock_model + ) + task_agent = ChatAgent("You are a planner.", model=mock_model) + + workforce = Workforce( + description="Test Workforce", + coordinator_agent=coordinator_agent, + task_agent=task_agent, + max_refinement_iterations=0, + ) + + assert workforce.max_refinement_iterations == 0 + + +class TestValidateAggregatedResult: + """Tests for the _validate_aggregated_result method.""" + + @pytest.mark.asyncio + async def test_validate_aggregated_result_success( + self, workforce_with_refinement + ): + """Test successful validation of aggregated results.""" + workforce = workforce_with_refinement + + # Create parent task with subtasks + parent_task = Task(content="Find 5 papers", id="parent_1") + subtask1 = Task(content="Find paper 1", id="subtask_1.1") + subtask2 = Task(content="Find paper 2", id="subtask_1.2") + parent_task.subtasks = [subtask1, subtask2] + + # Mock the task_agent response + mock_response = MagicMock() + mock_response.msg = MagicMock() + mock_response.msg.content = '''```json +{ + "requirements_met": true, + "unique_count": 5, + "duplicate_count": 0, + "missing_count": 0, + "deduplicated_result": "Paper 1, Paper 2, Paper 3, Paper 4, Paper 5", + "reasoning": "Found 5 unique papers", + "additional_task_guidance": null, + "duplicate_subtask_ids": null +} +```''' + + with patch.object( + workforce.task_agent, 'step', return_value=mock_response + ): + result = workforce._validate_aggregated_result( + parent_task, "Paper 1, Paper 2, Paper 3, Paper 4, Paper 5" + ) + + assert isinstance(result, ValidationResult) + assert result.requirements_met is True + assert result.unique_count == 5 + + @pytest.mark.asyncio + async def test_validate_aggregated_result_with_duplicates( + self, workforce_with_refinement + ): + """Test validation with duplicate detection.""" + workforce = workforce_with_refinement + + parent_task = Task(content="Find 5 papers", id="parent_1") + subtask1 = Task(content="Find paper 1", id="subtask_1.1") + subtask2 = Task(content="Find paper 2", id="subtask_1.2") + subtask3 = 
Task(content="Find paper 3", id="subtask_1.3") + parent_task.subtasks = [subtask1, subtask2, subtask3] + + mock_response = MagicMock() + mock_response.msg = MagicMock() + mock_response.msg.content = '''```json +{ + "requirements_met": false, + "unique_count": 3, + "duplicate_count": 2, + "missing_count": 2, + "deduplicated_result": "Paper 1, Paper 2, Paper 3", + "reasoning": "Found duplicates, need 2 more papers", + "additional_task_guidance": "Find unique papers excluding Paper 1, 2, 3", + "duplicate_subtask_ids": ["subtask_1.2", "subtask_1.3"] +} +```''' + + with patch.object( + workforce.task_agent, 'step', return_value=mock_response + ): + result = workforce._validate_aggregated_result( + parent_task, "Paper 1, Paper 1, Paper 2, Paper 2, Paper 3" + ) + + assert isinstance(result, ValidationResult) + assert result.requirements_met is False + assert result.duplicate_count == 2 + assert result.duplicate_subtask_ids is not None + + @pytest.mark.asyncio + async def test_validate_aggregated_result_fallback_on_error( + self, workforce_with_refinement + ): + """Test fallback behavior when validation fails.""" + workforce = workforce_with_refinement + + parent_task = Task(content="Find 5 papers", id="parent_1") + parent_task.subtasks = [] + + # Mock an error during validation + with patch.object( + workforce.task_agent, 'step', side_effect=Exception("Test error") + ): + result = workforce._validate_aggregated_result( + parent_task, "Some content" + ) + + # Should return fallback ValidationResult + assert isinstance(result, ValidationResult) + assert result.requirements_met is False + + +class TestApplyRecoveryStrategyRefine: + """Tests for applying REFINE recovery strategy.""" + + @pytest.mark.asyncio + async def test_apply_recovery_strategy_refine( + self, workforce_with_refinement + ): + """Test applying REFINE recovery strategy.""" + workforce = workforce_with_refinement + workforce._running = True + workforce._channel = MagicMock() + workforce._assignees = {} + + # Create parent task with subtasks + parent_task = Task(content="Find 5 papers", id="parent_1") + subtask1 = Task(content="Find paper 1", id="subtask_1.1") + subtask1.result = "Paper A" + subtask1.state = TaskState.DONE + subtask2 = Task(content="Find paper 2", id="subtask_1.2") + subtask2.result = "Paper A" # Duplicate + subtask2.state = TaskState.DONE + parent_task.subtasks = [subtask1, subtask2] + + # Create validation result + validation_result = ValidationResult( + requirements_met=False, + unique_count=1, + duplicate_count=1, + missing_count=1, + deduplicated_result="Paper A", + reasoning="Found duplicate", + additional_task_guidance="Find different paper", + duplicate_subtask_ids=["subtask_1.2"], + ) + + parent_task.additional_info = { + 'validation_result': validation_result, + 'refinement_iteration': 0, + } + + refine_decision = TaskAnalysisResult( + reasoning="Need to refine", + recovery_strategy=RecoveryStrategy.REFINE, + ) + + # Mock _post_ready_tasks to avoid actual posting + workforce._post_ready_tasks = AsyncMock() + + result = await workforce._apply_recovery_strategy( + parent_task, refine_decision + ) + + assert result is True + # Check that the duplicate subtask was reset for retry + assert subtask2.state == TaskState.OPEN + assert subtask2.result is None + + +class TestWorkforceCloneWithRefinement: + """Tests for workforce cloning with refinement parameter.""" + + def test_clone_preserves_refinement_iterations(self, mock_model): + """Test that cloning preserves max_refinement_iterations.""" + coordinator_agent = 
ChatAgent( + "You are a coordinator.", model=mock_model + ) + task_agent = ChatAgent("You are a planner.", model=mock_model) + + original = Workforce( + description="Original Workforce", + coordinator_agent=coordinator_agent, + task_agent=task_agent, + max_refinement_iterations=3, + ) + + cloned = original.clone() + + assert cloned.max_refinement_iterations == 3
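+
+
+class TestFixCommonIssuesRefine:
+    """Hedged sketch: checks that the newly added 'refine' string passes
+    through StructuredOutputHandler._fix_common_issues unchanged.
+
+    This assumes the strategy-normalization branch of _fix_common_issues
+    is keyed to the TaskAnalysisResult schema, as the diff above suggests;
+    it is a minimal illustration rather than an exhaustive test.
+    """
+
+    def test_fix_common_issues_keeps_refine(self):
+        data = {
+            "reasoning": "Parallel results contain duplicates",
+            "recovery_strategy": "refine",
+        }
+
+        fixed = StructuredOutputHandler._fix_common_issues(
+            data, TaskAnalysisResult
+        )
+
+        # 'refine' is in valid_strategies after this change, so it should
+        # survive intact rather than be partially matched to another value
+        assert fixed["recovery_strategy"] == "refine"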