Skip to content

Commit e47f40e

Browse files
committed
refactor(workforce): Refactor WorkforceCallback and all related callbacks to async interface
1 parent b8cc4cc commit e47f40e

22 files changed

+502
-476
lines changed

camel/benchmarks/browsecomp.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14-
14+
import asyncio
1515
import base64
1616
import hashlib
1717
import json
@@ -585,15 +585,17 @@ def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
585585
input_message = QUERY_TEMPLATE.format(question=problem)
586586

587587
if isinstance(pipeline_template, (ChatAgent)):
588-
pipeline = pipeline_template.clone() # type: ignore[assignment]
588+
chat_pipeline = pipeline_template.clone()
589589

590-
response_text = pipeline.step(
590+
response_text = chat_pipeline.step(
591591
input_message, response_format=QueryResponse
592592
)
593593
elif isinstance(pipeline_template, Workforce):
594-
pipeline = pipeline_template.clone() # type: ignore[assignment]
594+
workforce_pipeline = asyncio.run(pipeline_template.clone())
595595
task = Task(content=input_message, id="0")
596-
task = pipeline.process_task(task) # type: ignore[attr-defined]
596+
task = asyncio.run(
597+
workforce_pipeline.process_task_async(task)
598+
) # type: ignore[attr-defined]
597599
if task_json_formatter:
598600
formatter_in_process = task_json_formatter.clone()
599601
else:
@@ -607,16 +609,16 @@ def process_benchmark_row(row: Dict[str, Any]) -> Dict[str, Any]:
607609

608610
elif isinstance(pipeline_template, RolePlaying):
609611
# RolePlaying is different.
610-
pipeline = pipeline_template.clone( # type: ignore[assignment]
612+
rp_pipeline = pipeline_template.clone(
611613
task_prompt=input_message
612614
)
613615

614616
n = 0
615-
input_msg = pipeline.init_chat() # type: ignore[attr-defined]
617+
input_msg = rp_pipeline.init_chat()
616618
chat_history = []
617619
while n < chat_turn_limit:
618620
n += 1
619-
assistant_response, user_response = pipeline.step(
621+
assistant_response, user_response = rp_pipeline.step(
620622
input_msg
621623
)
622624
if assistant_response.terminated: # type: ignore[union-attr]

camel/societies/workforce/workforce.py

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ def __init__(
340340
self.snapshot_interval: float = 30.0
341341
# Shared memory UUID tracking to prevent re-sharing duplicates
342342
self._shared_memory_uuids: Set[str] = set()
343+
# Defer initial worker-created callbacks until an event loop is
344+
# available in async context.
345+
self._pending_worker_created: Deque[BaseNode] = deque(self._children)
343346
self._initialize_callbacks(callbacks)
344347

345348
# Set up coordinator agent with default system message
@@ -533,10 +536,7 @@ def _initialize_callbacks(
533536
"WorkforceLogger addition."
534537
)
535538

536-
for child in self._children:
537-
self._notify_worker_created(child)
538-
539-
def _notify_worker_created(
539+
async def _notify_worker_created(
540540
self,
541541
worker_node: BaseNode,
542542
*,
@@ -552,7 +552,19 @@ def _notify_worker_created(
552552
metadata=metadata,
553553
)
554554
for cb in self._callbacks:
555-
cb.log_worker_created(event)
555+
await cb.log_worker_created(event)
556+
557+
async def _flush_initial_worker_created_callbacks(self) -> None:
558+
r"""Flush pending worker-created callbacks that were queued during
559+
initialization before an event loop was available."""
560+
if not self._pending_worker_created:
561+
return
562+
563+
pending = list(self._pending_worker_created)
564+
self._pending_worker_created.clear()
565+
566+
for child in pending:
567+
await self._notify_worker_created(child)
556568

557569
def _get_or_create_shared_context_utility(
558570
self,
@@ -1669,7 +1681,7 @@ async def _apply_recovery_strategy(
16691681
subtask_ids=[st.id for st in subtasks],
16701682
)
16711683
for cb in self._callbacks:
1672-
cb.log_task_decomposed(task_decomposed_event)
1684+
await cb.log_task_decomposed(task_decomposed_event)
16731685
for subtask in subtasks:
16741686
task_created_event = TaskCreatedEvent(
16751687
task_id=subtask.id,
@@ -1679,7 +1691,7 @@ async def _apply_recovery_strategy(
16791691
metadata=subtask.additional_info,
16801692
)
16811693
for cb in self._callbacks:
1682-
cb.log_task_created(task_created_event)
1694+
await cb.log_task_created(task_created_event)
16831695

16841696
# Insert subtasks at the head of the queue
16851697
self._pending_tasks.extendleft(reversed(subtasks))
@@ -2288,7 +2300,7 @@ async def handle_decompose_append_task(
22882300
return [task]
22892301

22902302
if reset and self._state != WorkforceState.RUNNING:
2291-
self.reset()
2303+
await self.reset()
22922304
logger.info("Workforce reset before handling task.")
22932305

22942306
# Focus on the new task
@@ -2302,7 +2314,7 @@ async def handle_decompose_append_task(
23022314
metadata=task.additional_info,
23032315
)
23042316
for cb in self._callbacks:
2305-
cb.log_task_created(task_created_event)
2317+
await cb.log_task_created(task_created_event)
23062318

23072319
# The agent tends to be overconfident on the whole task, so we
23082320
# decompose the task into subtasks first
@@ -2323,7 +2335,7 @@ async def handle_decompose_append_task(
23232335
subtask_ids=[st.id for st in subtasks],
23242336
)
23252337
for cb in self._callbacks:
2326-
cb.log_task_decomposed(task_decomposed_event)
2338+
await cb.log_task_decomposed(task_decomposed_event)
23272339
for subtask in subtasks:
23282340
task_created_event = TaskCreatedEvent(
23292341
task_id=subtask.id,
@@ -2333,7 +2345,7 @@ async def handle_decompose_append_task(
23332345
metadata=subtask.additional_info,
23342346
)
23352347
for cb in self._callbacks:
2336-
cb.log_task_created(task_created_event)
2348+
await cb.log_task_created(task_created_event)
23372349

23382350
if subtasks:
23392351
# _pending_tasks will contain both undecomposed
@@ -2361,6 +2373,9 @@ async def process_task_async(
23612373
Returns:
23622374
Task: The updated task.
23632375
"""
2376+
# Emit worker-created callbacks lazily once an event loop is present.
2377+
await self._flush_initial_worker_created_callbacks()
2378+
23642379
# Delegate to intervention pipeline when requested to keep
23652380
# backward-compat.
23662381
if interactive:
@@ -2709,7 +2724,7 @@ def _start_child_node_when_paused(
27092724
# Close the coroutine to prevent RuntimeWarning
27102725
start_coroutine.close()
27112726

2712-
def add_single_agent_worker(
2727+
async def add_single_agent_worker(
27132728
self,
27142729
description: str,
27152730
worker: ChatAgent,
@@ -2765,13 +2780,13 @@ def add_single_agent_worker(
27652780
# If workforce is paused, start the worker's listening task
27662781
self._start_child_node_when_paused(worker_node.start())
27672782

2768-
self._notify_worker_created(
2783+
await self._notify_worker_created(
27692784
worker_node,
27702785
worker_type='SingleAgentWorker',
27712786
)
27722787
return self
27732788

2774-
def add_role_playing_worker(
2789+
async def add_role_playing_worker(
27752790
self,
27762791
description: str,
27772792
assistant_role_name: str,
@@ -2842,7 +2857,7 @@ def add_role_playing_worker(
28422857
# If workforce is paused, start the worker's listening task
28432858
self._start_child_node_when_paused(worker_node.start())
28442859

2845-
self._notify_worker_created(
2860+
await self._notify_worker_created(
28462861
worker_node,
28472862
worker_type='RolePlayingWorker',
28482863
)
@@ -2884,7 +2899,7 @@ async def _async_reset(self) -> None:
28842899
self._pause_event.set()
28852900

28862901
@check_if_running(False)
2887-
def reset(self) -> None:
2902+
async def reset(self) -> None:
28882903
r"""Reset the workforce and all the child nodes under it. Can only
28892904
be called when the workforce is not running.
28902905
"""
@@ -2919,9 +2934,7 @@ def reset(self) -> None:
29192934
if self._loop and not self._loop.is_closed():
29202935
# If we have a loop, use it to set the event safely
29212936
try:
2922-
asyncio.run_coroutine_threadsafe(
2923-
self._async_reset(), self._loop
2924-
).result()
2937+
await self._async_reset()
29252938
except RuntimeError as e:
29262939
logger.warning(f"Failed to reset via existing loop: {e}")
29272940
# Fallback to direct event manipulation
@@ -2932,7 +2945,7 @@ def reset(self) -> None:
29322945

29332946
for cb in self._callbacks:
29342947
if isinstance(cb, WorkforceMetrics):
2935-
cb.reset_task_data()
2948+
await cb.reset_task_data()
29362949

29372950
def save_workflow_memories(
29382951
self,
@@ -3800,7 +3813,7 @@ async def _post_task(self, task: Task, assignee_id: str) -> None:
38003813
task_id=task.id, worker_id=assignee_id
38013814
)
38023815
for cb in self._callbacks:
3803-
cb.log_task_started(task_started_event)
3816+
await cb.log_task_started(task_started_event)
38043817

38053818
try:
38063819
await self._channel.post_task(task, self.node_id, assignee_id)
@@ -3947,7 +3960,7 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
39473960

39483961
self._children.append(new_node)
39493962

3950-
self._notify_worker_created(
3963+
await self._notify_worker_created(
39513964
new_node,
39523965
worker_type='SingleAgentWorker',
39533966
role=new_node_conf.role,
@@ -4085,7 +4098,7 @@ async def _post_ready_tasks(self) -> None:
40854098
for cb in self._callbacks:
40864099
# queue_time_seconds can be derived by logger if task
40874100
# creation time is logged
4088-
cb.log_task_assigned(task_assigned_event)
4101+
await cb.log_task_assigned(task_assigned_event)
40894102

40904103
# Step 2: Iterate through all pending tasks and post those that are
40914104
# ready
@@ -4243,7 +4256,7 @@ async def _post_ready_tasks(self) -> None:
42434256
},
42444257
)
42454258
for cb in self._callbacks:
4246-
cb.log_task_failed(task_failed_event)
4259+
await cb.log_task_failed(task_failed_event)
42474260

42484261
self._completed_tasks.append(task)
42494262
self._cleanup_task_tracking(task.id)
@@ -4306,7 +4319,7 @@ async def _handle_failed_task(self, task: Task) -> bool:
43064319
},
43074320
)
43084321
for cb in self._callbacks:
4309-
cb.log_task_failed(task_failed_event)
4322+
await cb.log_task_failed(task_failed_event)
43104323

43114324
# Check for immediate halt conditions after max retries.
43124325
if task.failure_count >= MAX_TASK_RETRIES:
@@ -4493,7 +4506,7 @@ async def _handle_completed_task(self, task: Task) -> None:
44934506
metadata={'current_state': task.state.value},
44944507
)
44954508
for cb in self._callbacks:
4496-
cb.log_task_completed(task_completed_event)
4509+
await cb.log_task_completed(task_completed_event)
44974510

44984511
# Find and remove the completed task from pending tasks
44994512
tasks_list = list(self._pending_tasks)
@@ -4609,7 +4622,7 @@ async def _graceful_shutdown(self, failed_task: Task) -> None:
46094622
# Wait for the full timeout period
46104623
await asyncio.sleep(self.graceful_shutdown_timeout)
46114624

4612-
def get_workforce_log_tree(self) -> str:
4625+
async def get_workforce_log_tree(self) -> str:
46134626
r"""Returns an ASCII tree representation of the task hierarchy and
46144627
worker status.
46154628
"""
@@ -4619,19 +4632,19 @@ def get_workforce_log_tree(self) -> str:
46194632
if len(metrics_cb) == 0:
46204633
return "Metrics Callback not initialized."
46214634
else:
4622-
return metrics_cb[0].get_ascii_tree_representation()
4635+
return await metrics_cb[0].get_ascii_tree_representation()
46234636

4624-
def get_workforce_kpis(self) -> Dict[str, Any]:
4637+
async def get_workforce_kpis(self) -> Dict[str, Any]:
46254638
r"""Returns a dictionary of key performance indicators."""
46264639
metrics_cb: List[WorkforceMetrics] = [
46274640
cb for cb in self._callbacks if isinstance(cb, WorkforceMetrics)
46284641
]
46294642
if len(metrics_cb) == 0:
46304643
return {"error": "Metrics Callback not initialized."}
46314644
else:
4632-
return metrics_cb[0].get_kpis()
4645+
return await metrics_cb[0].get_kpis()
46334646

4634-
def dump_workforce_logs(self, file_path: str) -> None:
4647+
async def dump_workforce_logs(self, file_path: str) -> None:
46354648
r"""Dumps all collected logs to a JSON file.
46364649
46374650
Args:
@@ -4643,7 +4656,7 @@ def dump_workforce_logs(self, file_path: str) -> None:
46434656
if len(metrics_cb) == 0:
46444657
print("Logger not initialized. Cannot dump logs.")
46454658
return
4646-
metrics_cb[0].dump_to_json(file_path)
4659+
await metrics_cb[0].dump_to_json(file_path)
46474660
# Use logger.info or print, consistent with existing style
46484661
logger.info(f"Workforce logs dumped to {file_path}")
46494662

@@ -5119,7 +5132,7 @@ async def _listen_to_channel(self) -> None:
51195132
logger.info("All tasks completed.")
51205133
all_tasks_completed_event = AllTasksCompletedEvent()
51215134
for cb in self._callbacks:
5122-
cb.log_all_tasks_completed(all_tasks_completed_event)
5135+
await cb.log_all_tasks_completed(all_tasks_completed_event)
51235136

51245137
# shut down the whole workforce tree
51255138
self.stop()
@@ -5187,7 +5200,7 @@ def stop(self) -> None:
51875200
f"(event-loop not yet started)."
51885201
)
51895202

5190-
def clone(self, with_memory: bool = False) -> 'Workforce':
5203+
async def clone(self, with_memory: bool = False) -> 'Workforce':
51915204
r"""Creates a new instance of Workforce with the same configuration.
51925205
51935206
Args:
@@ -5219,13 +5232,13 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
52195232
for child in self._children:
52205233
if isinstance(child, SingleAgentWorker):
52215234
cloned_worker = child.worker.clone(with_memory)
5222-
new_instance.add_single_agent_worker(
5235+
await new_instance.add_single_agent_worker(
52235236
child.description,
52245237
cloned_worker,
52255238
pool_max_size=10,
52265239
)
52275240
elif isinstance(child, RolePlayingWorker):
5228-
new_instance.add_role_playing_worker(
5241+
await new_instance.add_role_playing_worker(
52295242
child.description,
52305243
child.assistant_role_name,
52315244
child.user_role_name,
@@ -5235,7 +5248,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
52355248
child.chat_turn_limit,
52365249
)
52375250
elif isinstance(child, Workforce):
5238-
new_instance.add_workforce(child.clone(with_memory))
5251+
new_instance.add_workforce(await child.clone(with_memory))
52395252
else:
52405253
logger.warning(f"{type(child)} is not being cloned.")
52415254
continue
@@ -5470,7 +5483,7 @@ def get_children_info():
54705483
return children_info
54715484

54725485
# Add single agent worker
5473-
def add_single_agent_worker(
5486+
async def add_single_agent_worker(
54745487
description,
54755488
system_message=None,
54765489
role_name="Assistant",
@@ -5534,7 +5547,9 @@ def add_single_agent_worker(
55345547
"message": str(e),
55355548
}
55365549

5537-
workforce_instance.add_single_agent_worker(description, agent)
5550+
await workforce_instance.add_single_agent_worker(
5551+
description, agent
5552+
)
55385553

55395554
return {
55405555
"status": "success",
@@ -5545,7 +5560,7 @@ def add_single_agent_worker(
55455560
return {"status": "error", "message": str(e)}
55465561

55475562
# Add role playing worker
5548-
def add_role_playing_worker(
5563+
async def add_role_playing_worker(
55495564
description,
55505565
assistant_role_name,
55515566
user_role_name,
@@ -5602,7 +5617,7 @@ def add_role_playing_worker(
56025617
"message": "Cannot add workers while workforce is running", # noqa: E501
56035618
}
56045619

5605-
workforce_instance.add_role_playing_worker(
5620+
await workforce_instance.add_role_playing_worker(
56065621
description=description,
56075622
assistant_role_name=assistant_role_name,
56085623
user_role_name=user_role_name,

0 commit comments

Comments
 (0)