Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 88 additions & 7 deletions omlx/admin/templates/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
<!-- Messages -->
<div x-show="messages.length > 0" class="max-w-4xl mx-auto px-4 py-6 space-y-4">
<template x-for="(msg, index) in messages" :key="index">
<div class="message-fade-in">
<div class="message-fade-in" x-show="msg._ui !== false">
<!-- User Message -->
<div x-show="msg.role === 'user'" class="flex justify-end">
<div class="user-message">
Expand Down Expand Up @@ -562,6 +562,15 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
</div>
<div class="message-body markdown-content" x-html="renderMarkdown(msg.content)"></div>
</div>
<!-- Tool call indicator (shown while model is fetching tool results) -->
<div x-show="msg.role === 'tool_call'" class="assistant-message" style="border-left: 2px solid var(--border-normal); opacity: 0.8;">
<div class="message-header">
<div class="flex items-center gap-2 text-sm" style="color: var(--text-tertiary);">
<i data-lucide="wrench" class="w-3.5 h-3.5"></i>
<span x-text="msg.content"></span>
</div>
</div>
</div>
</div>
</template>

Expand Down Expand Up @@ -1120,11 +1129,15 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
this.abortController = new AbortController();
this.autoScrollEnabled = true; // Enable auto-scroll at start of streaming

// Build messages for API - strip metadata, keep content as-is (already OpenAI format)
const messagesForApi = this.messages.map(msg => ({
role: msg.role,
content: msg.content
}));
// Build messages for API - pass through tool_calls/tool_call_id for MCP loop
const messagesForApi = this.messages
.filter(msg => ['user', 'assistant', 'tool', 'system'].includes(msg.role))
.map(msg => {
const m = { role: msg.role, content: msg.content ?? null };
if (msg.tool_calls) m.tool_calls = msg.tool_calls;
if (msg.tool_call_id) m.tool_call_id = msg.tool_call_id;
return m;
});

try {
const response = await fetch('/v1/chat/completions', {
Expand All @@ -1149,6 +1162,8 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
const toolCallsMap = {}; // accumulate streaming tool_call chunks by index
let finishReason = null;

while (true) {
const { done, value } = await reader.read();
Expand All @@ -1164,7 +1179,21 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
const delta = data.choices?.[0]?.delta;
const choice = data.choices?.[0];
const delta = choice?.delta;
if (choice?.finish_reason) finishReason = choice.finish_reason;

// Accumulate tool_call argument chunks
if (delta?.tool_calls) {
for (const tc of delta.tool_calls) {
const i = tc.index ?? 0;
if (!toolCallsMap[i]) toolCallsMap[i] = { id: '', type: 'function', function: { name: '', arguments: '' } };
if (tc.id) toolCallsMap[i].id = tc.id;
if (tc.function?.name) toolCallsMap[i].function.name += tc.function.name;
if (tc.function?.arguments) toolCallsMap[i].function.arguments += tc.function.arguments;
}
}

if (delta?.reasoning_content) {
// Wrap reasoning in <think> tags for UI rendering
if (!this.thinkingState.isInThinking) {
Expand Down Expand Up @@ -1199,6 +1228,58 @@ <h1 class="text-2xl font-bold">{{ t('chat.welcome_heading') }}</h1>
this.thinkingState.isInThinking = false;
}

// --- MCP tool call loop ---
const toolCalls = Object.values(toolCallsMap);
if (finishReason === 'tool_calls' && toolCalls.length > 0) {
// Store the assistant tool_calls message (hidden from UI)
this.messages.push({
role: 'assistant',
content: this.streamingContent || null,
tool_calls: toolCalls,
_ui: false,
});

// Execute all tools in parallel, show indicators while waiting
const indicatorIndices = toolCalls.map(() => {
const idx = this.messages.length;
this.messages.push({ role: 'tool_call', content: '', _ui: true });
return idx;
});

const results = await Promise.all(toolCalls.map(async (tc, i) => {
const toolName = tc.function.name;
let args = {};
try { args = JSON.parse(tc.function.arguments || '{}'); } catch(e) {}
this.messages[indicatorIndices[i]].content = `${toolName}…`;

try {
const execResp = await fetch('/v1/mcp/execute', {
method: 'POST',
headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.getApiKey()}` },
body: JSON.stringify({ tool_name: toolName, arguments: args }),
signal: this.abortController?.signal,
});
const execData = await execResp.json();
return typeof execData.content === 'string' ? execData.content : JSON.stringify(execData.content ?? execData);
} catch(e) {
return `Error: ${e.message}`;
}
}));

// Remove all indicators, push hidden tool result messages in order
indicatorIndices.slice().reverse().forEach(idx => this.messages.splice(idx, 1));
toolCalls.forEach((tc, i) => {
this.messages.push({ role: 'tool', tool_call_id: tc.id, content: results[i], _ui: false });
});

this.saveCurrentChat();
this.streamingContent = '';
// Loop back — model will now synthesize a final answer
await this.streamResponse();
return;
}
// --- end MCP tool call loop ---

// Add completed message
if (this.streamingContent) {
this.messages.push({
Expand Down
128 changes: 128 additions & 0 deletions tests/test_chat_tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for chat MCP tool call loop (chat.html streamResponse changes)."""
import json


class TestChatToolCallMessageFiltering:
"""Test the messagesForApi filtering logic (Python equivalent of the JS)."""

@staticmethod
def build_messages_for_api(messages):
"""Replicate the messagesForApi logic from streamResponse in chat.html."""
valid_roles = {"user", "assistant", "tool", "system"}
result = []
for msg in messages:
if msg["role"] not in valid_roles:
continue
m = {"role": msg["role"], "content": msg.get("content")}
if msg.get("tool_calls"):
m["tool_calls"] = msg["tool_calls"]
if msg.get("tool_call_id"):
m["tool_call_id"] = msg["tool_call_id"]
result.append(m)
return result

def test_filters_tool_call_indicator_messages(self):
"""tool_call role messages must not be sent to the API."""
messages = [
{"role": "user", "content": "Who is X?"},
{"role": "tool_call", "content": "tavily__tavily_search…", "_ui": True},
{"role": "assistant", "content": "X is...", "tool_calls": None},
]
api_msgs = self.build_messages_for_api(messages)
roles = [m["role"] for m in api_msgs]
assert "tool_call" not in roles
assert roles == ["user", "assistant"]

def test_passes_tool_calls_and_tool_call_id(self):
"""Assistant tool_calls and tool result tool_call_id must be preserved."""
tc = [{"id": "tc_1", "type": "function", "function": {"name": "t", "arguments": "{}"}}]
messages = [
{"role": "user", "content": "Search for X"},
{"role": "assistant", "content": None, "tool_calls": tc, "_ui": False},
{"role": "tool", "tool_call_id": "tc_1", "content": "result...", "_ui": False},
]
api_msgs = self.build_messages_for_api(messages)
assert len(api_msgs) == 3
assert api_msgs[1]["tool_calls"] == tc
assert api_msgs[2]["tool_call_id"] == "tc_1"

def test_normal_conversation_unchanged(self):
"""Normal user/assistant conversation with no tools is unaffected."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
api_msgs = self.build_messages_for_api(messages)
assert len(api_msgs) == 2
assert api_msgs[0] == {"role": "user", "content": "Hello"}
assert api_msgs[1] == {"role": "assistant", "content": "Hi there"}


class TestChatToolCallAccumulation:
"""Test streaming tool_call chunk accumulation (Python equivalent of the JS)."""

@staticmethod
def accumulate_tool_calls(deltas):
"""Replicate the toolCallsMap accumulation logic from streamResponse."""
tool_calls_map = {}
for delta in deltas:
if not delta.get("tool_calls"):
continue
for tc in delta["tool_calls"]:
i = tc.get("index", 0)
if i not in tool_calls_map:
tool_calls_map[i] = {"id": "", "type": "function", "function": {"name": "", "arguments": ""}}
if tc.get("id"):
tool_calls_map[i]["id"] = tc["id"]
if tc.get("function", {}).get("name"):
tool_calls_map[i]["function"]["name"] += tc["function"]["name"]
if tc.get("function", {}).get("arguments"):
tool_calls_map[i]["function"]["arguments"] += tc["function"]["arguments"]
return list(tool_calls_map.values())

def test_single_tool_call(self):
"""A single tool call split across multiple chunks is assembled correctly."""
deltas = [
{"tool_calls": [{"index": 0, "id": "tc_1", "function": {"name": "tavily__tavily_search"}}]},
{"tool_calls": [{"index": 0, "function": {"arguments": '{"que'}}]},
{"tool_calls": [{"index": 0, "function": {"arguments": 'ry":"test"}'}}]},
]
result = self.accumulate_tool_calls(deltas)
assert len(result) == 1
assert result[0]["id"] == "tc_1"
assert result[0]["function"]["name"] == "tavily__tavily_search"
assert json.loads(result[0]["function"]["arguments"]) == {"query": "test"}

def test_multiple_parallel_tool_calls(self):
"""Multiple tool calls with different indices are accumulated separately."""
deltas = [
{"tool_calls": [{"index": 0, "id": "tc_1", "function": {"name": "search"}}]},
{"tool_calls": [{"index": 1, "id": "tc_2", "function": {"name": "extract"}}]},
{"tool_calls": [{"index": 0, "function": {"arguments": '{"q":"a"}'}}]},
{"tool_calls": [{"index": 1, "function": {"arguments": '{"urls":["http://x"]}'}}]},
]
result = self.accumulate_tool_calls(deltas)
assert len(result) == 2
assert result[0]["function"]["name"] == "search"
assert result[1]["function"]["name"] == "extract"
assert json.loads(result[0]["function"]["arguments"]) == {"q": "a"}
assert json.loads(result[1]["function"]["arguments"]) == {"urls": ["http://x"]}

def test_no_tool_calls(self):
"""Deltas with no tool_calls produce empty list."""
deltas = [
{"content": "Hello"},
{"content": " world"},
]
result = self.accumulate_tool_calls(deltas)
assert result == []

def test_missing_index_defaults_to_zero(self):
"""A tool_call chunk without an index field defaults to index 0."""
deltas = [
{"tool_calls": [{"id": "tc_1", "function": {"name": "t", "arguments": "{}"}}]},
]
result = self.accumulate_tool_calls(deltas)
assert len(result) == 1
assert result[0]["id"] == "tc_1"