|
46 | 46 | from .utils.common_utils import ( |
47 | 47 | generate_uuid, |
48 | 48 | get_format_time, |
| 49 | + get_timestamp, |
49 | 50 | msgpack_preprocess, |
50 | 51 | print_tree, |
51 | 52 | to_json, |
@@ -100,6 +101,8 @@ class MAS(BaseModel): |
100 | 101 | routers: list = Field(default_factory=list) |
101 | 102 | middlewares: list = Field(default_factory=list) |
102 | 103 |
|
| 104 | + stream_dict: dict[str, list] = Field(default_factory=dict) |
| 105 | + |
103 | 106 | def __init__(self, **kwargs): |
104 | 107 | """Construct a new :class:`MAS`. |
105 | 108 |
|
@@ -284,9 +287,14 @@ async def init_db(self): |
284 | 287 | "mappings": { |
285 | 288 | "properties": { |
286 | 289 | "message_id": {"type": "keyword"}, |
| 290 | + "group_id": {"type": "keyword"}, |
287 | 291 | "trace_id": {"type": "keyword"}, |
| 292 | + "node_id": {"type": "keyword"}, |
| 293 | + "node_name": {"type": "keyword"}, |
288 | 294 | "message": {"type": "text"}, |
289 | 295 | "message_type": {"type": "keyword"}, |
| 296 | + "message_event": {"type": "keyword"}, |
| 297 | + "message_timestamp": {"type": "long"}, |
290 | 298 | "create_time": { |
291 | 299 | "format": "yyyy-MM-dd HH:mm:ss.SSSSSSSSS", |
292 | 300 | "type": "date", |
@@ -566,7 +574,9 @@ async def call(self, callee, arguments, **kwargs): |
566 | 574 | oxy_response = await oxy.execute(oxy_request) |
567 | 575 | return oxy_response.output |
568 | 576 |
|
569 | | - async def send_message(self, sse_message: SSEMessage, redis_key: str): |
| 577 | + async def send_message( |
| 578 | + self, sse_message: SSEMessage, redis_key: str, group_id: str = "" |
| 579 | + ): |
570 | 580 | """Push *message* onto a capped Redis list. |
571 | 581 |
|
572 | 582 | The data is MsgPack‑encoded before being stored. At most **10** items |
@@ -598,18 +608,78 @@ async def send_message(self, sse_message: SSEMessage, redis_key: str): |
598 | 608 | parts = redis_key.split(":") |
599 | 609 | current_trace_id = parts[-1] if len(parts) >= 3 else "" |
600 | 610 |
|
601 | | - # Insert into Elasticsearch |
602 | | - await self.es_client.index( |
603 | | - Config.get_app_name() + "_message", |
604 | | - doc_id=sse_message.id, |
605 | | - body={ |
606 | | - "message_id": sse_message.id, |
607 | | - "trace_id": current_trace_id, |
608 | | - "message": to_json(message), |
609 | | - "message_type": message_type, |
610 | | - "create_time": get_format_time(), |
611 | | - }, |
612 | | - ) |
| 611 | + # 考虑 message 是 str 的情况 |
| 612 | + node_id = "" |
| 613 | + node_name = "" |
| 614 | + message_timestamp = get_timestamp() |
| 615 | + if isinstance(message, dict): |
| 616 | + message_timestamp = message.get("timestamp", get_timestamp()) |
| 617 | + if isinstance(message.get("content"), dict): |
| 618 | + node_id = message.get("content", {}).get("node_id", "") |
| 619 | + node_name = message.get("content", {}).get("agent", "") |
| 620 | + |
| 621 | + if message_type in ["stream", "stream_end"]: |
| 622 | + # 排队 |
| 623 | + if message_type == "stream": |
| 624 | + delta = message.get("content", {}).get("delta", "") |
| 625 | + if node_id not in self.stream_dict: |
| 626 | + self.stream_dict[node_id] = [] |
| 627 | + self.stream_dict[node_id].append(delta) |
| 628 | + if message_type == "stream_end" or ( |
| 629 | + self.stream_dict[node_id] |
| 630 | + and len(self.stream_dict[node_id]) |
| 631 | + % Config.get_message_stream_batch_size() |
| 632 | + == 0 |
| 633 | + ): |
| 634 | + message_id = generate_uuid() |
| 635 | + merged_type = "merged_stream" |
| 636 | + save_message_task = asyncio.create_task( |
| 637 | + self.es_client.index( |
| 638 | + Config.get_app_name() + "_message", |
| 639 | + doc_id=message_id, |
| 640 | + body={ |
| 641 | + "message_id": message_id, |
| 642 | + "group_id": group_id, |
| 643 | + "trace_id": current_trace_id, |
| 644 | + "node_id": node_id, |
| 645 | + "node_name": node_name, |
| 646 | + "message": to_json( |
| 647 | + { |
| 648 | + "type": merged_type, |
| 649 | + "content": "".join(self.stream_dict[node_id]), |
| 650 | + } |
| 651 | + ), |
| 652 | + "message_type": merged_type, |
| 653 | + "message_event": sse_message.event, |
| 654 | + "message_timestamp": message_timestamp, |
| 655 | + "create_time": get_format_time(), |
| 656 | + }, |
| 657 | + ) |
| 658 | + ) |
| 659 | + save_message_task.add_done_callback(self.background_tasks.discard) |
| 660 | + self.background_tasks.add(save_message_task) |
| 661 | + self.stream_dict[node_id].clear() |
| 662 | + else: |
| 663 | + save_message_task = asyncio.create_task( |
| 664 | + self.es_client.index( |
| 665 | + Config.get_app_name() + "_message", |
| 666 | + doc_id=sse_message.id, |
| 667 | + body={ |
| 668 | + "message_id": sse_message.id, |
| 669 | + "group_id": group_id, |
| 670 | + "trace_id": current_trace_id, |
| 671 | + "node_id": node_id, |
| 672 | + "node_name": node_name, |
| 673 | + "message": to_json(message), |
| 674 | + "message_type": message_type, |
| 675 | + "message_event": sse_message.event, |
| 676 | + "message_timestamp": message_timestamp, |
| 677 | + "create_time": get_format_time(), |
| 678 | + }, |
| 679 | + ) |
| 680 | + ) |
| 681 | + save_message_task.add_done_callback(self.background_tasks.discard) |
| 682 | + self.background_tasks.add(save_message_task) |
613 | 683 | if message_is_send: |
614 | 684 | bytes_msg = msgpack.packb(msgpack_preprocess(sse_message.to_sse())) |
615 | 685 | await self.redis_client.lpush(redis_key, bytes_msg) |
@@ -745,7 +815,9 @@ async def chat_with_agent( |
745 | 815 |
|
746 | 816 | if send_msg_key: |
747 | 817 | await self.send_message( |
748 | | - SSEMessage(event="close", data="done"), send_msg_key |
| 818 | + SSEMessage(event="close", data="done"), |
| 819 | + send_msg_key, |
| 820 | + group_id=oxy_request.group_id, |
749 | 821 | ) |
750 | 822 | return oxy_response |
751 | 823 | except Exception: |
|
0 commit comments