Skip to content

Commit a35e87b

Browse files
[tools][wip] Add support for Llama 3.2 tool-call injection: batch tool calls, user message integration, and enhanced response parsing.
1 parent bf18441 commit a35e87b

5 files changed

Lines changed: 236 additions & 49 deletions

File tree

src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,54 @@ default ChatTokens chatTokens() {
3939

4040
/**
4141
* Returns plain text to append to the system message content when tools are available.
42-
* The returned string is concatenated to the system message before encoding, so the
43-
* normal {@link #encodeMessage} path handles tokenization.
42+
* Used by formats that inject tool definitions into the <em>system</em> message.
4443
*
45-
* @param toolsJson JSON array of tool definitions, e.g.
46-
* {@code [{"type":"function","function":{...}}]}
44+
* <p>Formats that inject tools into the <em>user</em> message instead should override
45+
* {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and
46+
* {@link #toolFirstUserMessagePrefix(String)} rather than this method.
47+
*
48+
* @param toolsJson JSON array of tool definitions
4749
*/
4850
default String toolSystemPromptSuffix(String toolsJson) {
4951
throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName());
5052
}
5153

54+
/**
55+
* Returns {@code true} when this format injects tool definitions into the
56+
* <em>first user message</em> instead of the system message.
57+
*
58+
* <p>When this returns {@code true}, callers should:
59+
* <ol>
60+
* <li>Prepend {@link #toolSystemMessagePrefix()} to the system message content.</li>
61+
* <li>Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.</li>
62+
* </ol>
63+
* When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to
64+
* the system message as before.
65+
*/
66+
default boolean injectsToolsInUserMessage() {
67+
return false;
68+
}
69+
70+
/**
71+
* Returns text to <em>prepend</em> to the system message content when tools are active
72+
* and {@link #injectsToolsInUserMessage()} is {@code true}.
73+
* Default: empty string (no prefix).
74+
*/
75+
default String toolSystemMessagePrefix() {
76+
return "";
77+
}
78+
79+
/**
80+
* Returns the preamble to <em>prepend</em> to the first user message when
81+
* {@link #injectsToolsInUserMessage()} is {@code true}.
82+
* The preamble should include the tool definitions and usage instructions.
83+
*
84+
* @param toolsJson JSON array of tool definitions
85+
*/
86+
default String toolFirstUserMessagePrefix(String toolsJson) {
87+
return "";
88+
}
89+
5290
/**
5391
* Re-encodes a prior assistant tool-call turn into the conversation token stream.
5492
* Used when replaying multi-turn history that contains a previous tool call.
@@ -80,6 +118,18 @@ default Optional<ToolCallExtract> extractToolCall(String responseText) {
80118
return Optional.empty();
81119
}
82120

121+
/**
122+
* Extracts ALL tool calls from a response. Models may emit multiple
123+
* {@code <tool_call>} blocks in a single turn (batch tool calls).
124+
* The default delegates to {@link #extractToolCall} for formats that
125+
* do not support batch calls.
126+
*
127+
* @param responseText the fully decoded response from the model
128+
*/
129+
default List<ToolCallExtract> extractAllToolCalls(String responseText) {
130+
return extractToolCall(responseText).map(List::of).orElse(List.of());
131+
}
132+
83133
/**
84134
* Stop tokens to use when tool calling is enabled.
85135
* Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>})

src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -78,32 +78,58 @@ public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<Messag
7878
// ── Tool calling ──────────────────────────────────────────────────────────
7979

8080
/**
81-
* LLaMA 3.1 tool calling system prompt suffix.
82-
* Instructs the model to respond with JSON using the {"name":…,"parameters":{…}} format.
81+
* Llama 3.2 Instruct injects tool definitions into the <em>first user message</em>
82+
* (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default).
83+
* The system message receives only an environment prefix; the tools and usage instructions
84+
* go in the user turn.
8385
*/
8486
@Override
85-
public String toolSystemPromptSuffix(String toolsJson) {
86-
return "\n\n# Tools\n\n"
87-
+ "You may call one or more functions to assist with the user query.\n\n"
88-
+ "You are provided with function signatures within <tools></tools> XML tags:\n\n"
89-
+ "<tools>\n" + toolsJson + "\n</tools>\n\n"
90-
+ "IMPORTANT: the \"name\" field in your tool call MUST be exactly one of the function names "
91-
+ "listed inside <tools> above — not a path, not a word from the user's message.\n\n"
92-
+ "For each function call, return a json object with function name and arguments "
93-
+ "within <tool_call></tool_call> XML tags:\n\n"
94-
+ "<tool_call>\n"
95-
+ "{\"name\": <function-name>, \"arguments\": <args-json-object>}\n"
96-
+ "</tool_call>";
87+
public boolean injectsToolsInUserMessage() {
88+
return true;
9789
}
9890

9991
/**
100-
* Re-encodes a prior assistant tool-call turn for multi-turn history.
101-
* Format: {@code <|start_header_id|>assistant<|end_header_id|>\n<|python_tag|>JSON<|eom_id|>}
92+
* System-message prefix that signals tool availability to Llama 3.2.
93+
* Matches the template's {@code "Environment: ipython\n"} line.
94+
*/
95+
@Override
96+
public String toolSystemMessagePrefix() {
97+
return "Environment: ipython\n\n";
98+
}
99+
100+
/**
101+
* Prepends tool definitions and usage instructions to the first user message,
102+
* matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}).
103+
*
104+
* <p>Format mirrors:
105+
* <pre>
106+
* Given the following functions, please respond with a JSON for a function call
107+
* with its proper arguments that best answers the given prompt.
108+
*
109+
* Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
110+
* Do not use variables.
111+
*
112+
* {toolsJson}
113+
*
114+
* </pre>
115+
*/
116+
@Override
117+
public String toolFirstUserMessagePrefix(String toolsJson) {
118+
return "Given the following functions, please respond with a JSON for a function call "
119+
+ "with its proper arguments that best answers the given prompt.\n\n"
120+
+ "Respond in the format {\"name\": function name, \"parameters\": dictionary of "
121+
+ "argument name and its value}. Do not use variables.\n\n"
122+
+ toolsJson + "\n\n";
123+
}
124+
125+
/**
126+
* Re-encodes a prior assistant tool-call turn for multi-turn history using the
127+
* Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}.
102128
*/
103129
@Override
104130
public List<Integer> encodeToolCallAssistantTurn(ToolCallExtract toolCall) {
105131
List<Integer> tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, "")));
106-
String json = "<tool_call>\n{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}\n</tool_call>";
132+
String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}";
107133
tokens.addAll(tokenizer.encodeAsList(json));
108134
tokens.add(endOfTurn);
109135
return tokens;
@@ -136,6 +162,11 @@ public Optional<ToolCallExtract> extractToolCall(String responseText) {
136162
return ToolCallParserUtils.parseLlamaResponse(responseText);
137163
}
138164

165+
@Override
166+
public List<ToolCallExtract> extractAllToolCalls(String responseText) {
167+
return ToolCallParserUtils.parseAllToolCalls(responseText);
168+
}
169+
139170
/**
140171
* Adds {@code <|eom_id|>} to the stop tokens when tools are enabled.
141172
* LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}.

src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,9 @@ public List<Integer> encodeToolResultTurn(String toolCallId, String toolName, St
193193
public Optional<ToolCallExtract> extractToolCall(String responseText) {
194194
return ToolCallParserUtils.parseQwen3Response(responseText);
195195
}
196+
197+
@Override
198+
public List<ToolCallExtract> extractAllToolCalls(String responseText) {
199+
return ToolCallParserUtils.parseAllToolCalls(responseText);
200+
}
196201
}

src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package org.beehive.gpullama3.model.format;
22

3+
import java.util.ArrayList;
4+
import java.util.List;
35
import java.util.Optional;
46

57
/**
@@ -41,6 +43,11 @@ public static Optional<ToolCallExtract> parseLlamaResponse(String responseText)
4143
String json = responseText.substring(tcStart + "<tool_call>".length(), tcEnd).strip();
4244
return parseLlamaJson(json);
4345
}
46+
// 2b. Unclosed <tool_call> — model stopped (eot_id / eom_id) before writing the closing tag
47+
if (tcStart != -1 && tcEnd == -1) {
48+
String json = responseText.substring(tcStart + "<tool_call>".length()).strip();
49+
return parseLlamaJson(json);
50+
}
4451

4552
// 3. Fallback: raw JSON, possibly inside markdown code fences
4653
String stripped = stripMarkdownFences(responseText.strip());
@@ -72,16 +79,66 @@ private static Optional<ToolCallExtract> parseLlamaJson(String json) {
7279

7380
// ── Qwen3 ─────────────────────────────────────────────────────────────────
7481

82+
/**
83+
* Extracts ALL tool calls from a response that may contain multiple
84+
* {@code <tool_call>…</tool_call>} blocks (Llama 3.2 and Qwen3 batch calls).
85+
*
86+
* Falls back to the raw-JSON single-call path if no tags are found.
87+
* Returns an empty list when the response contains no tool calls.
88+
*/
89+
public static List<ToolCallExtract> parseAllToolCalls(String responseText) {
90+
List<ToolCallExtract> calls = new java.util.ArrayList<>();
91+
92+
// <|python_tag|> (Llama 3.1) — single call by definition
93+
int pythonIdx = responseText.indexOf("<|python_tag|>");
94+
if (pythonIdx != -1) {
95+
parseLlamaJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip())
96+
.ifPresent(calls::add);
97+
return calls;
98+
}
99+
100+
// Scan for all <tool_call>…</tool_call> blocks
101+
int searchFrom = 0;
102+
while (true) {
103+
int start = responseText.indexOf("<tool_call>", searchFrom);
104+
if (start == -1) break;
105+
int end = responseText.indexOf("</tool_call>", start);
106+
String json;
107+
if (end != -1) {
108+
json = responseText.substring(start + "<tool_call>".length(), end).strip();
109+
searchFrom = end + "</tool_call>".length();
110+
} else {
111+
// Unclosed tag — model stopped before writing the closing tag
112+
json = responseText.substring(start + "<tool_call>".length()).strip();
113+
searchFrom = responseText.length();
114+
}
115+
parseLlamaJson(json).ifPresent(calls::add);
116+
if (end == -1) break;
117+
}
118+
119+
// Raw JSON fallback (no tags at all)
120+
if (calls.isEmpty()) {
121+
String stripped = stripMarkdownFences(responseText.strip());
122+
if (stripped.startsWith("{")) {
123+
parseLlamaJson(stripped).ifPresent(calls::add);
124+
}
125+
}
126+
127+
return calls;
128+
}
129+
75130
/**
76131
* Extracts a tool call enclosed in {@code <tool_call>…</tool_call>} tags
77132
* as produced by Qwen3 models.
78133
*/
79134
public static Optional<ToolCallExtract> parseQwen3Response(String responseText) {
80135
int start = responseText.indexOf("<tool_call>");
81136
int end = responseText.lastIndexOf("</tool_call>");
82-
if (start == -1 || end == -1 || end <= start) return Optional.empty();
137+
if (start == -1) return Optional.empty();
83138

84-
String json = responseText.substring(start + "<tool_call>".length(), end).strip();
139+
String json = (end != -1 && end > start)
140+
? responseText.substring(start + "<tool_call>".length(), end).strip()
141+
: responseText.substring(start + "<tool_call>".length()).strip();
85142

86143
String name = extractStringValue(json, "name");
87144
if (name == null) return Optional.empty();
@@ -104,7 +161,11 @@ public static String stripMarkdownFences(String text) {
104161
return body.strip();
105162
}
106163

107-
/** Extracts the string value for {@code "key": "<value>"} from a JSON object. Tolerates whitespace around {@code :}. */
164+
/**
165+
* Extracts the string value for {@code "key": "<value>"} from a JSON object.
166+
* Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"})
167+
* inside the value, so multi-line code strings with embedded {@code "} are returned intact. The value is returned verbatim: escape sequences are not decoded.
168+
*/
108169
public static String extractStringValue(String json, String key) {
109170
String marker = "\"" + key + "\"";
110171
int markerIdx = json.indexOf(marker);
@@ -113,9 +174,20 @@ public static String extractStringValue(String json, String key) {
113174
if (colonIdx == -1) return null;
114175
int quoteStart = json.indexOf('"', colonIdx + 1);
115176
if (quoteStart == -1) return null;
116-
int quoteEnd = json.indexOf('"', quoteStart + 1);
117-
if (quoteEnd == -1) return null;
118-
return json.substring(quoteStart + 1, quoteEnd);
177+
// Scan for the closing quote, honouring backslash escapes
178+
int i = quoteStart + 1;
179+
while (i < json.length()) {
180+
char c = json.charAt(i);
181+
if (c == '\\') {
182+
i += 2; // skip escape sequence (e.g. \", \\, \n)
183+
} else if (c == '"') {
184+
break;
185+
} else {
186+
i++;
187+
}
188+
}
189+
if (i >= json.length()) return null;
190+
return json.substring(quoteStart + 1, i);
119191
}
120192

121193
/**

0 commit comments

Comments
 (0)