|
| 1 | +# Copyright (c) Alibaba, Inc. and its affiliates. |
| 2 | +from typing import TYPE_CHECKING, List, Optional, Tuple, Union |
| 3 | + |
| 4 | +import json |
| 5 | + |
| 6 | +from .hermes import HermesAgentTemplate |
| 7 | + |
| 8 | +if TYPE_CHECKING: |
| 9 | + from swift.llm.template import Prompt |
| 10 | + |
| 11 | + |
class YoutuAgentTemplate(HermesAgentTemplate):
    """Agent template that renders tool use for Youtu-LLM models.

    Wire format:
    - A tool call is written as
      <tool_call>{"name": "function-name", "arguments": {...}}</tool_call>
    - A tool result is fed back as <tool_response>...</tool_response>
    """

    def _get_tool_responses(self, tool_messages):
        """Render tool results as newline-separated <tool_response> elements.

        Args:
            tool_messages: Iterable of message mappings; only the 'content'
                entry of each message is read.

        Returns:
            One <tool_response> element per message, joined with newlines
            (an empty string when there are no messages).
        """
        contents = (message['content'] for message in tool_messages)
        return '\n'.join(f'<tool_response>{tool_content}</tool_response>' for tool_content in contents)

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Build the prompt pieces that carry tool results back to the model.

        When the assistant text contains both ``self.keyword.action`` and
        ``self.keyword.action_input``, formatting is delegated to the parent
        class. Otherwise the tool results are substituted for ``{{QUERY}}``
        in the prompt pieces, preceded by a copy of the chat separator.

        Args:
            assistant_content: The assistant text that preceded the tool results.
            tool_messages: Tool result messages, each exposing a 'content' entry.

        Returns:
            A tuple of the unchanged ``assistant_content`` and the list of
            prompt pieces.
        """
        if self.keyword.action in assistant_content and self.keyword.action_input in assistant_content:
            return super()._format_tool_responses(assistant_content, tool_messages)

        # For Youtu-LLM, tool responses are placed in user message
        if not hasattr(self, 'template_meta'):
            # No template metadata on this instance: use the built-in defaults.
            user_prompt = ['<|User|>{{QUERY}}<|Assistant|>']
            separator = ['<|end_of_text|>']
        else:
            meta = self.template_meta
            user_prompt = meta.prompt
            separator = meta.chat_sep

        # Copy so the separator list held by the template is never mutated.
        pieces = separator.copy()
        tool_text = self._get_tool_responses(tool_messages)
        # Only string pieces can hold the placeholder; others pass through untouched.
        pieces.extend(
            piece.replace('{{QUERY}}', tool_text) if isinstance(piece, str) else piece for piece in user_prompt)
        return assistant_content, pieces

    def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
        """Compose the system text that advertises the available tools.

        Args:
            tools: Tool definitions; each one is passed through
                ``self.wrap_tool`` and serialized to JSON.
            system: Optional system prompt. When non-empty it is placed
                first, followed by a blank line.
            user_message: Accepted for interface compatibility; not used here.

        Returns:
            The tool description section, with one fenced JSON block per tool
            and the required tool-call format.
        """
        serialized_tools = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools]
        system = f'{system}\n\n' if system else ''
        fenced_tools = '\n'.join(f'```json\n{desc}\n```' for desc in serialized_tools)
        return f"""{system}<|begin_of_tool_description|>Tool calling capabilities.
You may call one or more functions to assist with the user query. You have the following functions available:
""" + fenced_tools + """
For tool call returns, you MUST use the following format:
<tool_call>{"name": "function-name", "arguments": {"param1": "value1", "param2": "value2"}}</tool_call>
<|end_of_tool_description|>"""

    def _format_tool_calls(self, tool_call_messages):
        """Serialize tool-call messages into adjacent <tool_call> elements.

        Args:
            tool_call_messages: Messages whose 'content' entry is handed to
                ``self._parse_tool_call``.

        Returns:
            The JSON-encoded calls, each wrapped in <tool_call> tags and
            concatenated with no separator.
        """
        parsed_calls = (self._parse_tool_call(message['content']) for message in tool_call_messages)
        return ''.join(
            f'<tool_call>{json.dumps(tool_call, ensure_ascii=False)}</tool_call>' for tool_call in parsed_calls)
0 commit comments