|
| 1 | +# Copyright 2025 Horizon RL Contributors |
| 2 | + |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | + |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | + |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tool call parser for GLM (ChatGLM) models. |
| 16 | +
|
| 17 | +GLM models use an XML key-value format for tool calls instead of the |
| 18 | +JSON format used by Hermes/Qwen models:: |
| 19 | +
|
| 20 | + <tool_call>function_name |
| 21 | + <arg_key>key1</arg_key> |
| 22 | + <arg_value>value1</arg_value> |
| 23 | + <arg_key>key2</arg_key> |
| 24 | + <arg_value>value2</arg_value> |
| 25 | + </tool_call> |
| 26 | +
|
| 27 | +Values are either plain strings or JSON-encoded (for non-string types). |
| 28 | +""" |
| 29 | + |
| 30 | +from __future__ import annotations |
| 31 | + |
| 32 | +import json |
| 33 | +import logging |
| 34 | +import re |
| 35 | +from typing import Any |
| 36 | + |
| 37 | +from .base import UNKNOWN_TOOL_NAME, ToolParser, ToolParseResult, register_tool_parser |
| 38 | + |
| 39 | +logger = logging.getLogger(__name__) |
| 40 | + |
| 41 | + |
@register_tool_parser("glm")
class GLMToolCallParser(ToolParser):
    """Parser for the GLM XML key-value tool call format.

    Format::

        <tool_call>function_name
        <arg_key>key1</arg_key>
        <arg_value>value1</arg_value>
        <arg_key>key2</arg_key>
        <arg_value>value2</arg_value>
        </tool_call>

    The function name is the text that follows ``<tool_call>``. It ends at
    the first newline or at the first ``<arg_key>`` tag on that line,
    whichever comes first, so both the multi-line layout above and a
    compact single-line layout (``<tool_call>name<arg_key>...``) are
    accepted. The name is followed by alternating ``<arg_key>`` and
    ``<arg_value>`` tags. Values that decode as JSON are returned as the
    decoded Python object; everything else is kept as a plain string.

    Think Block Handling:
        Models with reasoning capabilities may output draft tool calls
        inside ``<think>...</think>`` blocks. These are removed before
        parsing by default so that planning/reasoning tool calls are not
        returned. Pass ``think_tokens=None`` to disable this behavior.

    Chat Template Notes:
        GLM uses no explicit separator between messages.

    Attributes:
        tool_call_tokens: Start/end delimiters for tool calls.
        think_tokens: Start/end delimiters for think blocks to exclude,
            or ``None`` when think blocks are not stripped.
    """

    DEFAULT_TOOL_CALL_TOKENS = ("<tool_call>", "</tool_call>")
    DEFAULT_THINK_TOKENS = ("<think>", "</think>")

    # Opening tag of an argument key. Used to detect arguments that start on
    # the same line as the function name.
    _ARG_KEY_OPEN = "<arg_key>"

    def __init__(
        self,
        tool_call_tokens: tuple[str, str] = DEFAULT_TOOL_CALL_TOKENS,
        think_tokens: tuple[str, str] | None = DEFAULT_THINK_TOKENS,
    ) -> None:
        """Initialize the parser with optional custom tokens.

        Args:
            tool_call_tokens: (start, end) delimiters for tool calls.
            think_tokens: (start, end) delimiters for think blocks to exclude.
                Set to None to disable think block exclusion.
        """
        self.tool_call_tokens = tool_call_tokens
        self.think_tokens = think_tokens

        # Non-greedy body so that several tool calls in one output are
        # matched separately; DOTALL because a call spans multiple lines.
        self._pattern = re.compile(
            rf"{re.escape(tool_call_tokens[0])}\s*(.*?)\s*{re.escape(tool_call_tokens[1])}",
            re.DOTALL,
        )

        if think_tokens:
            self._think_pattern: re.Pattern[str] | None = re.compile(
                rf"{re.escape(think_tokens[0])}.*?{re.escape(think_tokens[1])}",
                re.DOTALL,
            )
        else:
            self._think_pattern = None

        # One match per <arg_key>/<arg_value> pair; surrounding whitespace
        # inside the tags is excluded from the captured groups.
        self._arg_pattern = re.compile(
            r"<arg_key>\s*(.*?)\s*</arg_key>\s*<arg_value>\s*(.*?)\s*</arg_value>",
            re.DOTALL,
        )

    @property
    def message_separator(self) -> str:
        """Separator between messages in the chat template.

        GLM uses no explicit separator between messages.
        """
        return ""

    def parse(self, text: str) -> list[ToolParseResult]:
        """Parse tool calls from GLM model output.

        Think blocks are stripped first (when enabled). For every
        ``<tool_call>...</tool_call>`` span the function name is extracted
        and the ``<arg_key>``/``<arg_value>`` pairs are collected into a dict.

        Args:
            text: Model output text.

        Returns:
            List of tool call results in order of appearance. A span whose
            function name is missing or contains ``<``/``>`` yields an error
            result (see ``_make_error_tool_call``) instead of being dropped.
        """
        # Remove think blocks to avoid parsing draft tool calls from reasoning
        if self._think_pattern:
            text = self._think_pattern.sub("", text)

        tool_calls: list[ToolParseResult] = []

        for i, match in enumerate(self._pattern.finditer(text)):
            raw_content = match.group(1).strip()
            tool_call_id = f"call_{i:04d}"  # Sequential IDs for sortability

            name, arg_text = self._split_name_and_args(raw_content)
            if not self._is_valid_name(name):
                tool_calls.append(self._make_error_tool_call(raw_content, tool_call_id, "missing function name"))
                continue

            tool_calls.append(
                ToolParseResult(id=tool_call_id, name=name, input=self._parse_arguments(arg_text)),
            )

        return tool_calls

    def _split_name_and_args(self, raw_content: str) -> tuple[str, str]:
        """Split a tool call body into (function name, argument text).

        The name ends at the first newline, or at an ``<arg_key>`` tag that
        appears before that newline. In the latter case the tag and the rest
        of that line are returned as part of the argument text so that no
        argument is lost.
        """
        first_line, newline, remainder = raw_content.partition("\n")
        name_part, tag, inline_args = first_line.partition(self._ARG_KEY_OPEN)
        if tag:
            # Arguments begin on the name's own line: hand them back intact,
            # including the original newline (if any) that followed.
            remainder = f"{tag}{inline_args}{newline}{remainder}"
        return name_part.strip(), remainder

    @staticmethod
    def _is_valid_name(name: str) -> bool:
        """Return True when *name* is non-empty and contains no ``<`` or ``>``.

        An angle bracket in the name means a tag was picked up where the
        function name should have been.
        """
        return bool(name) and "<" not in name and ">" not in name

    def _parse_arguments(self, arg_text: str) -> dict[str, Any]:
        """Collect ``<arg_key>``/``<arg_value>`` pairs from *arg_text*.

        If a key occurs more than once, the last occurrence wins.
        """
        arguments: dict[str, Any] = {}
        for arg_match in self._arg_pattern.finditer(arg_text):
            key = arg_match.group(1).strip()
            arguments[key] = self._decode_value(arg_match.group(2).strip())
        return arguments

    @staticmethod
    def _decode_value(value_str: str) -> Any:
        """Decode *value_str* as JSON, falling back to the raw string."""
        try:
            return json.loads(value_str)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError (plain, non-JSON text).
            # RecursionError guards against pathologically nested input so a
            # single bad value cannot abort parsing of the whole output.
            return value_str

    def _make_error_tool_call(
        self,
        raw_content: str,
        tool_call_id: str,
        error: str,
    ) -> ToolParseResult:
        """Create an error tool call for parse failures.

        Args:
            raw_content: Text found between the tool call delimiters.
            tool_call_id: ID to assign to the result.
            error: Short description of the failure; logged as a warning.

        Returns:
            A result with empty ``input`` and ``raw`` set to *raw_content*.
            Its name is the function name when one can still be extracted,
            otherwise ``UNKNOWN_TOOL_NAME``.
        """
        name, _ = self._split_name_and_args(raw_content)
        if not self._is_valid_name(name):
            name = UNKNOWN_TOOL_NAME

        logger.warning("Tool call parse error: %s", error)

        return ToolParseResult(
            id=tool_call_id,
            name=name,
            input={},
            raw=raw_content,
        )
0 commit comments