Skip to content

Commit 6cbeee0

Browse files
authored
[Feat] Support more running mode in workforce (#3157)
1 parent c7d4423 commit 6cbeee0

File tree

5 files changed

+2255
-66
lines changed

5 files changed

+2255
-66
lines changed

camel/societies/workforce/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
from .role_playing_worker import RolePlayingWorker
1616
from .single_agent_worker import SingleAgentWorker
17+
from .utils import PipelineTaskBuilder
1718
from .workflow_memory_manager import WorkflowSelectionMethod
18-
from .workforce import Workforce
19+
from .workforce import Workforce, WorkforceMode
1920

2021
__all__ = [
2122
"Workforce",
23+
"WorkforceMode",
24+
"PipelineTaskBuilder",
2225
"SingleAgentWorker",
2326
"RolePlayingWorker",
2427
"WorkflowSelectionMethod",

camel/societies/workforce/utils.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,320 @@ def quality_sufficient(self) -> bool:
221221
)
222222

223223

224+
class PipelineTaskBuilder:
225+
r"""Helper class for building pipeline tasks with dependencies."""
226+
227+
def __init__(self):
228+
"""Initialize an empty pipeline task builder."""
229+
from camel.tasks import Task
230+
231+
self._TaskClass = Task
232+
self.task_list = []
233+
self.task_counter = 0
234+
self._task_registry = {} # task_id -> Task mapping for fast lookup
235+
self._last_task_id = (
236+
None # Track the last added task for chain inference
237+
)
238+
# Track the last added parallel tasks for sync
239+
self._last_parallel_tasks: List[str] = []
240+
241+
def add(
242+
self,
243+
content: str,
244+
task_id: Optional[str] = None,
245+
dependencies: Optional[List[str]] = None,
246+
additional_info: Optional[dict] = None,
247+
auto_depend: bool = True,
248+
) -> 'PipelineTaskBuilder':
249+
"""Add a task to the pipeline with support for chaining.
250+
251+
Args:
252+
content (str): The content/description of the task.
253+
task_id (str, optional): Unique identifier for the task. If None,
254+
a unique ID will be generated. (default: :obj:`None`)
255+
dependencies (List[str], optional): List of task IDs that this
256+
task depends on. If None and auto_depend=True, will depend on
257+
the last added task. (default: :obj:`None`)
258+
additional_info (dict, optional): Additional information
259+
for the task. (default: :obj:`None`)
260+
auto_depend (bool, optional): If True and dependencies is None,
261+
automatically depend on the last added task.
262+
(default: :obj:`True`)
263+
264+
Returns:
265+
PipelineTaskBuilder: Self for method chaining.
266+
267+
Raises:
268+
ValueError: If task_id already exists or if any dependency is
269+
not found.
270+
271+
Example:
272+
>>> builder.add("Step 1").add("Step 2").add("Step 3")
273+
# Step 2 depends on Step 1, Step 3 depends on Step 2
274+
"""
275+
# Generate or validate task_id
276+
task_id = task_id or f"pipeline_task_{self.task_counter}"
277+
278+
# Check ID uniqueness
279+
if task_id in self._task_registry:
280+
raise ValueError(f"Task ID '{task_id}' already exists")
281+
282+
# Auto-infer dependencies if not specified
283+
if (
284+
dependencies is None
285+
and auto_depend
286+
and self._last_task_id is not None
287+
):
288+
dependencies = [self._last_task_id]
289+
290+
# Validate dependencies exist
291+
dep_tasks = []
292+
if dependencies:
293+
missing_deps = [
294+
dep for dep in dependencies if dep not in self._task_registry
295+
]
296+
if missing_deps:
297+
raise ValueError(f"Dependencies not found: {missing_deps}")
298+
dep_tasks = [self._task_registry[dep] for dep in dependencies]
299+
300+
# Create task
301+
task = self._TaskClass(
302+
content=content,
303+
id=task_id,
304+
dependencies=dep_tasks,
305+
additional_info=additional_info,
306+
)
307+
308+
self.task_list.append(task)
309+
self._task_registry[task_id] = task
310+
self._last_task_id = task_id # Update last task for chaining
311+
self.task_counter += 1
312+
return self
313+
314+
def add_parallel_tasks(
315+
self,
316+
task_contents: List[str],
317+
dependencies: Optional[List[str]] = None,
318+
task_id_prefix: str = "parallel",
319+
auto_depend: bool = True,
320+
) -> 'PipelineTaskBuilder':
321+
"""Add multiple parallel tasks that can execute simultaneously.
322+
323+
Args:
324+
task_contents (List[str]): List of task content strings.
325+
dependencies (List[str], optional): Common dependencies for all
326+
parallel tasks. If None and auto_depend=True, will depend on
327+
the last added task. (default: :obj:`None`)
328+
task_id_prefix (str, optional): Prefix for generated task IDs.
329+
(default: :obj:`"parallel"`)
330+
auto_depend (bool, optional): If True and dependencies is None,
331+
automatically depend on the last added task.
332+
(default: :obj:`True`)
333+
334+
Returns:
335+
PipelineTaskBuilder: Self for method chaining.
336+
337+
Raises:
338+
ValueError: If any task_id already exists or if any dependency
339+
is not found.
340+
341+
Example:
342+
>>> builder.add("Collect Data").add_parallel_tasks([
343+
... "Technical Analysis", "Fundamental Analysis"
344+
... ]).add_sync_task("Generate Report")
345+
"""
346+
if not task_contents:
347+
raise ValueError("task_contents cannot be empty")
348+
349+
# Auto-infer dependencies if not specified
350+
if (
351+
dependencies is None
352+
and auto_depend
353+
and self._last_task_id is not None
354+
):
355+
dependencies = [self._last_task_id]
356+
357+
parallel_task_ids = []
358+
base_counter = (
359+
self.task_counter
360+
) # Save current counter for consistent naming
361+
362+
for i, content in enumerate(task_contents):
363+
task_id = f"{task_id_prefix}_{i}_{base_counter}"
364+
# Use auto_depend=False since we're manually managing dependencies
365+
self.add(content, task_id, dependencies, auto_depend=False)
366+
parallel_task_ids.append(task_id)
367+
368+
# Set the last task to None since we have multiple parallel endings
369+
# The next task will need to explicitly specify dependencies
370+
self._last_task_id = None
371+
# Store parallel task IDs for potential sync operations
372+
self._last_parallel_tasks = parallel_task_ids
373+
374+
return self
375+
376+
def add_sync_task(
377+
self,
378+
content: str,
379+
wait_for: Optional[List[str]] = None,
380+
task_id: Optional[str] = None,
381+
) -> 'PipelineTaskBuilder':
382+
"""Add a synchronization task that waits for multiple tasks.
383+
384+
Args:
385+
content (str): Content of the synchronization task.
386+
wait_for (List[str], optional): List of task IDs to wait for.
387+
If None, will automatically wait for the last parallel tasks.
388+
(default: :obj:`None`)
389+
task_id (str, optional): ID for the sync task. If None, a unique
390+
ID will be generated. (default: :obj:`None`)
391+
392+
Returns:
393+
PipelineTaskBuilder: Self for method chaining.
394+
395+
Raises:
396+
ValueError: If task_id already exists or if any dependency is
397+
not found.
398+
399+
Example:
400+
>>> builder.add_parallel_tasks(
401+
... ["Task A", "Task B"]
402+
... ).add_sync_task("Merge Results")
403+
# Automatically waits for both parallel tasks
404+
"""
405+
# Auto-infer wait_for from last parallel tasks
406+
if wait_for is None:
407+
if self._last_parallel_tasks:
408+
wait_for = self._last_parallel_tasks
409+
# Clear the parallel tasks after using them
410+
self._last_parallel_tasks = []
411+
else:
412+
raise ValueError(
413+
"wait_for cannot be empty for sync task and no "
414+
"parallel tasks found"
415+
)
416+
417+
if not wait_for:
418+
raise ValueError("wait_for cannot be empty for sync task")
419+
420+
return self.add(
421+
content, task_id, dependencies=wait_for, auto_depend=False
422+
)
423+
424+
def build(self) -> List:
425+
"""Build and return the complete task list with dependencies.
426+
427+
Returns:
428+
List[Task]: List of tasks with proper dependency relationships.
429+
430+
Raises:
431+
ValueError: If there are circular dependencies or other
432+
validation errors.
433+
"""
434+
if not self.task_list:
435+
raise ValueError("No tasks defined in pipeline")
436+
437+
# Validate no circular dependencies
438+
self._validate_dependencies()
439+
440+
return self.task_list.copy()
441+
442+
def clear(self) -> None:
443+
"""Clear all tasks from the builder."""
444+
self.task_list.clear()
445+
self._task_registry.clear()
446+
self.task_counter = 0
447+
self._last_task_id = None
448+
self._last_parallel_tasks = []
449+
450+
def fork(self, task_contents: List[str]) -> 'PipelineTaskBuilder':
451+
"""Create parallel branches from the current task (alias for
452+
add_parallel_tasks).
453+
454+
Args:
455+
task_contents (List[str]): List of task content strings for
456+
parallel execution.
457+
458+
Returns:
459+
PipelineTaskBuilder: Self for method chaining.
460+
461+
Example:
462+
>>> builder.add("Collect Data").fork([
463+
... "Technical Analysis", "Fundamental Analysis"
464+
... ]).join("Generate Report")
465+
"""
466+
return self.add_parallel_tasks(task_contents)
467+
468+
def join(
469+
self, content: str, task_id: Optional[str] = None
470+
) -> 'PipelineTaskBuilder':
471+
"""Join parallel branches with a synchronization task (alias for
472+
add_sync_task).
473+
474+
Args:
475+
content (str): Content of the join/sync task.
476+
task_id (str, optional): ID for the sync task.
477+
478+
Returns:
479+
PipelineTaskBuilder: Self for method chaining.
480+
481+
Example:
482+
>>> builder.fork(["Task A", "Task B"]).join("Merge Results")
483+
"""
484+
return self.add_sync_task(content, task_id=task_id)
485+
486+
def _validate_dependencies(self) -> None:
487+
"""Validate that there are no circular dependencies.
488+
489+
Raises:
490+
ValueError: If circular dependencies are detected.
491+
"""
492+
# Use DFS to detect cycles
493+
visited = set()
494+
rec_stack = set()
495+
496+
def has_cycle(task_id: str) -> bool:
497+
visited.add(task_id)
498+
rec_stack.add(task_id)
499+
500+
task = self._task_registry[task_id]
501+
for dep in task.dependencies:
502+
if dep.id not in visited:
503+
if has_cycle(dep.id):
504+
return True
505+
elif dep.id in rec_stack:
506+
return True
507+
508+
rec_stack.remove(task_id)
509+
return False
510+
511+
for task_id in self._task_registry:
512+
if task_id not in visited:
513+
if has_cycle(task_id):
514+
raise ValueError(
515+
f"Circular dependency detected involving task: "
516+
f"{task_id}"
517+
)
518+
519+
def get_task_info(self) -> dict:
520+
"""Get information about all tasks in the pipeline.
521+
522+
Returns:
523+
dict: Dictionary containing task count and task details.
524+
"""
525+
return {
526+
"task_count": len(self.task_list),
527+
"tasks": [
528+
{
529+
"id": task.id,
530+
"content": task.content,
531+
"dependencies": [dep.id for dep in task.dependencies],
532+
}
533+
for task in self.task_list
534+
],
535+
}
536+
537+
224538
def check_if_running(
225539
running: bool,
226540
max_retries: int = 3,

0 commit comments

Comments
 (0)