Skip to content

Commit ae44769

Browse files
authored
Merge pull request #272 from cubist38/feat/gemma4
Feat/gemma4
2 parents 3fbb47a + c7e03a5 commit ae44769

File tree

3 files changed

+256
-1
lines changed

3 files changed

+256
-1
lines changed

app/parsers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from .function_parameter import FunctionParameterToolParser
1515
from .functiongemma import FunctionGemmaToolParser
16+
from .gemma4 import Gemma4ReasoningParser, Gemma4ToolParser
1617
from .glm4_moe import GLM4MoEReasoningParser, GLM4MoEToolParser
1718
from .harmony import HarmonyParser
1819
from .hermes import HermesReasoningParser, HermesToolParser
@@ -37,6 +38,7 @@
3738
"minimax_m2": Qwen3MoEReasoningParser, # use Qwen3MoEReasoningParser for MiniMax M2
3839
"nemotron3_nano": Qwen3MoEReasoningParser, # use Qwen3MoEReasoningParser for Nemotron3 Nano
3940
"solar_open": SolarOpenReasoningParser,
41+
"gemma4": Gemma4ReasoningParser,
4042
"kimi_k2": HermesReasoningParser,
4143
"mixed_think_tool_handoff": MixedThinkToolHandoffReasoningParser,
4244
"step_35": Step35ReasoningParser, # backward-compatible parser with legacy implicit-open behavior
@@ -53,6 +55,7 @@
5355
"minimax_m2": MiniMaxM2ToolParser,
5456
"nemotron3_nano": FunctionParameterToolParser, # use FunctionParameterToolParser for Nemotron3 Nano
5557
"functiongemma": FunctionGemmaToolParser,
58+
"gemma4": Gemma4ToolParser,
5659
"iquest_coder_v1": HermesToolParser, # use HermesToolParser for IQuest Coder V1
5760
"solar_open": SolarOpenToolParser,
5861
"longcat_flash_lite": LongCatFlashLiteToolParser,
@@ -309,12 +312,14 @@ def is_unified_parser(parser_name: str | None) -> bool:
309312
"Qwen3ReasoningParser",
310313
"Qwen3MoEReasoningParser",
311314
"Qwen35ReasoningParser",
315+
"Gemma4ReasoningParser",
312316
"GLM4MoEReasoningParser",
313317
"SolarOpenReasoningParser",
314318
"MixedThinkToolHandoffReasoningParser",
315319
"Step35ReasoningParser",
316320
# Tool parsers
317321
"HermesToolParser",
322+
"Gemma4ToolParser",
318323
"GLM4MoEToolParser",
319324
"MiniMaxM2ToolParser",
320325
"FunctionGemmaToolParser",

app/parsers/gemma4.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import re
5+
6+
from loguru import logger
7+
8+
from .abstract_parser import (
9+
AbstractReasoningParser,
10+
AbstractToolParser,
11+
ReasoningParserState,
12+
_suffix_prefix_overlap,
13+
)
14+
15+
# Delimiters that wrap the model's reasoning ("thought") channel.
REASONING_OPEN = "<|channel>thought\n"
REASONING_CLOSE = "<channel|>"

# Delimiters that wrap one serialized tool call.
TOOL_OPEN = "<|tool_call>"
TOOL_CLOSE = "<tool_call|>"

# String values inside tool-call arguments use the same marker on both sides.
STRING_OPEN = '<|"|>'
STRING_CLOSE = '<|"|>'
23+
24+
25+
# ---------------------------------------------------------------------------
26+
# Gemma 4 value format parser
27+
# ---------------------------------------------------------------------------
28+
# The model serialises tool-call arguments in a custom format:
29+
# - strings: <|"|>text<|"|>
30+
# - booleans: true / false
31+
# - null: null
32+
# - numbers: raw digits (int or float)
33+
# - objects: {key:value,key:value} (keys are bare identifiers)
34+
# - arrays: [value,value]
35+
# ---------------------------------------------------------------------------
36+
37+
38+
def _parse_value(text: str, pos: int) -> tuple[object, int]:
39+
"""Parse a single value starting at *pos* and return (value, new_pos)."""
40+
if pos >= len(text):
41+
return None, pos
42+
43+
# String
44+
if text[pos : pos + len(STRING_OPEN)] == STRING_OPEN:
45+
start = pos + len(STRING_OPEN)
46+
end = text.index(STRING_CLOSE, start)
47+
return text[start:end], end + len(STRING_CLOSE)
48+
49+
# Object
50+
if text[pos] == "{":
51+
return _parse_object(text, pos)
52+
53+
# Array
54+
if text[pos] == "[":
55+
return _parse_array(text, pos)
56+
57+
# Boolean / null
58+
if text[pos : pos + 4] == "true":
59+
return True, pos + 4
60+
if text[pos : pos + 5] == "false":
61+
return False, pos + 5
62+
if text[pos : pos + 4] == "null":
63+
return None, pos + 4
64+
65+
# Number – consume until delimiter
66+
end = pos
67+
while end < len(text) and text[end] not in ",}]":
68+
end += 1
69+
num_str = text[pos:end]
70+
try:
71+
return (float(num_str) if "." in num_str else int(num_str)), end
72+
except ValueError:
73+
return num_str, end
74+
75+
76+
def _parse_object(text: str, pos: int) -> tuple[dict, int]:
77+
"""Parse ``{key:value, ...}`` starting at *pos*."""
78+
pos += 1 # skip '{'
79+
result: dict = {}
80+
while pos < len(text) and text[pos] != "}":
81+
# Key – bare identifier (letters, digits, underscores)
82+
key_end = pos
83+
while key_end < len(text) and text[key_end] not in ":}":
84+
key_end += 1
85+
key = text[pos:key_end]
86+
pos = key_end
87+
if pos < len(text) and text[pos] == ":":
88+
pos += 1 # skip ':'
89+
value, pos = _parse_value(text, pos)
90+
result[key] = value
91+
if pos < len(text) and text[pos] == ",":
92+
pos += 1 # skip ','
93+
if pos < len(text) and text[pos] == "}":
94+
pos += 1
95+
return result, pos
96+
97+
98+
def _parse_array(text: str, pos: int) -> tuple[list, int]:
99+
"""Parse ``[value, ...]`` starting at *pos*."""
100+
pos += 1 # skip '['
101+
result: list = []
102+
while pos < len(text) and text[pos] != "]":
103+
value, pos = _parse_value(text, pos)
104+
result.append(value)
105+
if pos < len(text) and text[pos] == ",":
106+
pos += 1
107+
if pos < len(text) and text[pos] == "]":
108+
pos += 1
109+
return result, pos
110+
111+
112+
def _parse_tool_call_body(body: str) -> dict | None:
113+
"""Parse ``call:func_name{args}`` into ``{name, arguments}``."""
114+
m = re.match(r"call:(\w+)", body.strip())
115+
if not m:
116+
return None
117+
name = m.group(1)
118+
brace_idx = body.find("{", m.end())
119+
if brace_idx == -1:
120+
return {"name": name, "arguments": "{}"}
121+
try:
122+
args, _ = _parse_object(body, brace_idx)
123+
except (ValueError, IndexError):
124+
logger.warning(f"Failed to parse Gemma4 tool call arguments: {body[:120]}")
125+
return None
126+
return {"name": name, "arguments": json.dumps(args, ensure_ascii=False)}
127+
128+
129+
# ---------------------------------------------------------------------------
130+
# Reasoning parser
131+
# ---------------------------------------------------------------------------
132+
133+
134+
class Gemma4ReasoningParser(AbstractReasoningParser):
    r"""Reasoning parser for Gemma 4 models.

    Thinking content is wrapped in:
        <|channel>thought\n ... <channel|>
    """

    def __init__(
        self,
        reasoning_open: str = REASONING_OPEN,
        reasoning_close: str = REASONING_CLOSE,
    ) -> None:
        """Store the delimiters and compile the pattern for one reasoning block."""
        super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
        self.reasoning_regex = re.compile(
            re.escape(reasoning_open) + r"(.*?)" + re.escape(reasoning_close),
            re.DOTALL,
        )

    def respects_enable_thinking(self) -> bool:
        """Return ``True``: this parser reports that it honours ``enable_thinking``."""
        return True

    def extract_reasoning(self, model_output: str) -> dict[str, str] | None:
        """Split a complete model output into reasoning and trailing content.

        Returns ``{"content": model_output}`` when no complete reasoning block
        is present.  Otherwise returns the text of the first block as
        ``reasoning_content`` and everything after the *last* closing
        delimiter as ``after_reasoning_close_content``.
        """
        first_block = self.reasoning_regex.search(model_output)
        if first_block is None:
            return {"content": model_output}
        last_close_idx = model_output.rfind(self.reasoning_close)
        after = model_output[last_close_idx + len(self.reasoning_close) :]
        return {
            "reasoning_content": first_block.group(1),
            "after_reasoning_close_content": after,
        }

    def extract_reasoning_streaming(self, chunk: str) -> tuple[dict[str, str] | str | None, bool]:
        """Process one streamed chunk.

        Returns ``(payload, done)``.  ``payload`` is a dict carrying
        ``reasoning_content`` (plus ``after_reasoning_close_content`` once the
        closing delimiter has been seen), ``{"content": chunk}`` for text
        outside a reasoning block, or ``None`` when the whole chunk was held
        back.  ``done`` is ``True`` only for the call that sees the closing
        delimiter.

        NOTE(review): the opening delimiter is only recognised when it arrives
        whole inside a single chunk - confirm the caller's chunking guarantees
        that.
        """
        if self.reasoning_open in chunk:
            self.state = ReasoningParserState.FOUND_PREFIX
            start_idx = chunk.find(self.reasoning_open)
            return self._emit_reasoning_until_close(chunk[start_idx + len(self.reasoning_open) :])

        if self.state == ReasoningParserState.FOUND_PREFIX:
            # Prepend whatever was held back from the previous chunk.
            return self._emit_reasoning_until_close(self.buffer + chunk)

        return {"content": chunk}, False

    def _emit_reasoning_until_close(self, pending: str) -> tuple[dict[str, str] | None, bool]:
        """Emit reasoning text from *pending*, stopping at the closing delimiter."""
        close_idx = pending.find(self.reasoning_close)
        if close_idx != -1:
            # Reasoning is finished: leave reasoning mode on every path so a
            # later call does not treat ordinary content as reasoning.
            self.buffer = ""
            self.state = ReasoningParserState.NORMAL
            return {
                "reasoning_content": pending[:close_idx],
                "after_reasoning_close_content": pending[close_idx + len(self.reasoning_close) :],
            }, True

        # Hold back a tail that might be the start of the closing delimiter.
        # (presumably the helper returns the length of that tail - verify
        # against abstract_parser)
        overlap = _suffix_prefix_overlap(pending, self.reasoning_close)
        if overlap > 0:
            emitted = pending[:-overlap]
            self.buffer = pending[-overlap:]
        else:
            emitted = pending
            self.buffer = ""

        if emitted:
            return {"reasoning_content": emitted}, False
        return None, False
218+
219+
220+
# ---------------------------------------------------------------------------
221+
# Tool call parser
222+
# ---------------------------------------------------------------------------
223+
224+
225+
class Gemma4ToolParser(AbstractToolParser):
    """Tool parser for Gemma 4 models.

    Tool calls use the format:
        <|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
    """

    def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None:
        """Store the delimiters and compile the pattern that captures one call body."""
        super().__init__(tool_open=tool_open, tool_close=tool_close)
        # Build the pattern from the arguments rather than the module
        # constants so that custom delimiters are actually honoured.
        self.tool_call_regex = re.compile(
            re.escape(tool_open) + r"(.*?)" + re.escape(tool_close),
            re.DOTALL,
        )

    def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
        """Extract every well-formed tool call from a complete model output.

        Returns ``{"tool_calls": [...]}`` when at least one call parses, and
        ``{"content": model_output}`` when the output contains no tool call
        or none of the calls can be parsed.
        """
        bodies = self.tool_call_regex.findall(model_output)
        if not bodies:
            return {"content": model_output}
        parsed_calls = (_parse_tool_call_body(body) for body in bodies)
        tool_calls = [call for call in parsed_calls if call]
        if not tool_calls:
            return {"content": model_output}
        return {"tool_calls": tool_calls}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"loguru>=0.7.3,<0.8",
3333
"mlx-embeddings>=0.0.5,<0.1",
3434
"mlx-lm>=0.31.0,<0.32",
35-
"mlx-vlm>=0.4.2,<0.5",
35+
"mlx-vlm>=0.4.3,<0.5",
3636
"mlx-whisper>=0.4.3,<0.5",
3737
"mlx>=0.31.0",
3838
"numpy>=2.2.0,<2.4",

0 commit comments

Comments
 (0)