diff --git a/lua/codecompanion/config.lua b/lua/codecompanion/config.lua index 6a3cf899a..768e5ede8 100644 --- a/lua/codecompanion/config.lua +++ b/lua/codecompanion/config.lua @@ -482,6 +482,9 @@ If you are providing code changes, use the insert_edit_into_file tool (if availa }, }, }, + mcp = { + servers = {}, + }, keymaps = { options = { modes = { n = "?" }, diff --git a/lua/codecompanion/interactions/chat/init.lua b/lua/codecompanion/interactions/chat/init.lua index 34979f4cf..bf83d7c7a 100644 --- a/lua/codecompanion/interactions/chat/init.lua +++ b/lua/codecompanion/interactions/chat/init.lua @@ -527,6 +527,8 @@ function Chat.new(args) self:update_metadata() + require("codecompanion.interactions.chat.mcp").start_servers() + -- Likely this hasn't been set by the time the user opens the chat buffer if not _G.codecompanion_current_context then _G.codecompanion_current_context = self.buffer_context.bufnr diff --git a/lua/codecompanion/interactions/chat/mcp/client.lua b/lua/codecompanion/interactions/chat/mcp/client.lua new file mode 100644 index 000000000..72829d05d --- /dev/null +++ b/lua/codecompanion/interactions/chat/mcp/client.lua @@ -0,0 +1,636 @@ +local adapter_utils = require("codecompanion.utils.adapters") +local log = require("codecompanion.utils.log") +local tool_bridge = require("codecompanion.interactions.chat.mcp.tool_bridge") +local utils = require("codecompanion.utils") + +local CONSTANTS = { + GRACEFUL_SHUTDOWN_TIMEOUT_MS = 3000, + SIGTERM_TIMEOUT_MS = 2000, -- After SIGTERM before SIGKILL + MAX_TOOLS_PER_SERVER = 100, -- Maximum tools per server to avoid infinite pagination + + JSONRPC = { + ERROR_PARSE = -32700, + ERROR_INVALID_REQUEST = -32600, + ERROR_METHOD_NOT_FOUND = -32601, + ERROR_INVALID_PARAMS = -32602, + ERROR_INTERNAL = -32603, + }, +} + +local last_msg_id = 0 + +---Increment and return the next unique message id used for JSON-RPC requests. +---@return number next_id +local function next_msg_id() + last_msg_id = last_msg_id + 1 + return last_msg_id +end + +---Transform static methods for easier testing +---@param class table The class with static.methods definition +---@param methods? table Optional method overrides for testing +---@return table methods Transformed methods with overrides applied +local function transform_static_methods(class, methods) + local ret = {} + for k, v in pairs(class.static.methods) do + ret[k] = (methods and methods[k]) or v.default + end + return ret +end + +---Abstraction over the IO transport to a MCP server +---@class CodeCompanion.MCP.Transport +---@field start fun(self: CodeCompanion.MCP.Transport, on_line_read: fun(line: string), on_close: fun(err?: string)) +---@field started fun(self: CodeCompanion.MCP.Transport): boolean +---@field write fun(self: CodeCompanion.MCP.Transport, lines?: string[]) +---@field stop fun(self: CodeCompanion.MCP.Transport) + +---Default Transport implementation backed by vim.system +---@class CodeCompanion.MCP.StdioTransport : CodeCompanion.MCP.Transport +---@field name string +---@field cmd string[] +---@field env? table +---@field _proc? vim.SystemObj +---@field _last_tail? string +---@field _on_line_read? fun(line: string) +---@field _on_close? fun(err?: string) +---@field methods table +local StdioTransport = {} +StdioTransport.__index = StdioTransport + +StdioTransport.static = {} +StdioTransport.static.methods = { + system = { default = vim.system }, + schedule_wrap = { default = vim.schedule_wrap }, +} + +---Create a new StdioTransport for the given server configuration. +---@param name string +---@param cfg CodeCompanion.MCP.ServerConfig +---@param methods? table Optional method overrides for testing +---@return CodeCompanion.MCP.StdioTransport +function StdioTransport:new(name, cfg, methods) + return setmetatable({ + name = name, + cmd = cfg.cmd, + env = cfg.env, + methods = transform_static_methods(StdioTransport, methods), + }, self) +end + +---Start the underlying process and attach stdout/stderr callbacks. +---@param on_line_read fun(line: string) +---@param on_close fun(err?: string) +function StdioTransport:start(on_line_read, on_close) + assert(not self._proc, "StdioTransport: start called when already started") + self._on_line_read = on_line_read + self._on_close = on_close + + adapter_utils.get_env_vars(self) + self._proc = self.methods.system( + self.cmd, + { + env = self.env_replaced or self.env, + text = true, + stdin = true, + stdout = self.methods.schedule_wrap(function(err, data) + self:_handle_stdout(err, data) + end), + stderr = self.methods.schedule_wrap(function(err, data) + self:_handle_stderr(err, data) + end), + }, + self.methods.schedule_wrap(function(out) + self:_handle_exit(out) + end) + ) +end + +---Return whether the transport process has been started. +---@return boolean +function StdioTransport:started() + return self._proc ~= nil +end + +---Handle stdout stream chunks, buffer incomplete lines and deliver complete lines to the on_line_read callback. +---@param err? string +---@param data? string +function StdioTransport:_handle_stdout(err, data) + if err then + log:error("StdioTransport stdout error: %s", err) + return + end + if not data or data == "" then + return + end + + local combined = "" + if self._last_tail then + combined = self._last_tail .. data + self._last_tail = nil + else + combined = data + end + + local last_newline_pos = combined:match(".*()\n") + if last_newline_pos == nil then + self._last_tail = combined + return + elseif last_newline_pos < #combined then + self._last_tail = combined:sub(last_newline_pos + 1) + combined = combined:sub(1, last_newline_pos) + end + + for line in vim.gsplit(combined, "\n", { plain = true, trimempty = true }) do + if line ~= "" and self._on_line_read then + local ok, _ = pcall(self._on_line_read, line) + if not ok then + log:error("StdioTransport on_line_read callback failed for line: %s", line) + end + end + end +end + +---Handle stderr output from the process. +---@param err? string +---@param data? string +function StdioTransport:_handle_stderr(err, data) + if err then + log:error("StdioTransport stderr error: %s", err) + return + end + if data then + log:info("[MCP.%s] stderr: %s", self.name, data) + end +end + +---Handle process exit and invoke the on_close callback with an optional error message. +---@param out vim.SystemCompleted The output object from vim.system containing code and signal fields. +function StdioTransport:_handle_exit(out) + local err_msg = nil + if out and (out.code ~= 0) then + err_msg = string.format("exit code %s, signal %s", tostring(out.code), tostring(out.signal)) + end + self._proc = nil + if self._on_close then + local ok, _ = pcall(self._on_close, err_msg) + if not ok then + log:error("StdioTransport on_close callback failed") + end + end +end + +---Write lines to the process stdin. +---@param lines string[] +function StdioTransport:write(lines) + if not self._proc then + error("StdioTransport: write called before start") + end + self._proc:write(lines) +end + +---Stop the MCP server process. +function StdioTransport:stop() + if not self._proc then + return + end + log:debug("[MCP.%s] initiating graceful shutdown", self.name) + + -- Step 1: Close stdin to signal the server to exit gracefully + local ok, err = pcall(function() + self._proc:write(nil) -- Close stdin + end) + if not ok then + log:warn("[MCP.%s] failed to close stdin: %s", self.name, err) + end + + -- Step 2: Schedule SIGTERM if process doesn't exit within timeout + self.methods.defer_fn(function() + if self._proc then + log:warn("[MCP.%s] process did not exit gracefully, sending SIGTERM", self.name) + pcall(function() + self._proc:kill(vim.uv.constants.SIGTERM) + end) + + -- Step 3: Schedule SIGKILL as last resort + self.methods.defer_fn(function() + if self._proc then + log:error("[MCP.%s] process still alive after SIGTERM, sending SIGKILL", self.name) + pcall(function() + self._proc:kill(vim.uv.constants.SIGKILL) + end) + end + end, CONSTANTS.SIGTERM_TIMEOUT_MS) + end + end, CONSTANTS.GRACEFUL_SHUTDOWN_TIMEOUT_MS) +end + +---@alias ServerRequestHandler fun(cli: CodeCompanion.MCP.Client, params: table?): "result" | "error", table +---@alias ResponseHandler fun(resp: MCP.JSONRPCResultResponse | MCP.JSONRPCErrorResponse) + +---@class CodeCompanion.MCP.Client +---@field name string +---@field cfg CodeCompanion.MCP.ServerConfig +---@field ready boolean +---@field transport CodeCompanion.MCP.Transport +---@field resp_handlers table +---@field server_request_handlers table +---@field server_capabilities? table +---@field server_instructions? string +---@field methods table +local Client = {} +Client.__index = Client + +Client.static = {} +Client.static.methods = { + new_transport = { + default = function(name, cfg, methods) + return StdioTransport:new(name, cfg, methods) + end, + }, + json_decode = { default = vim.json.decode }, + json_encode = { default = vim.json.encode }, + schedule_wrap = { default = vim.schedule_wrap }, + defer_fn = { default = vim.defer_fn }, +} + +---Create a new MCP client instance bound to the provided server configuration. +---@param name string +---@param cfg CodeCompanion.MCP.ServerConfig +---@param methods? table Optional method overrides for testing +---@return CodeCompanion.MCP.Client +function Client:new(name, cfg, methods) + local static_methods = transform_static_methods(Client, methods) + return setmetatable({ + name = name, + cfg = cfg, + ready = false, + transport = static_methods.new_transport(name, cfg, methods), + resp_handlers = {}, + server_request_handlers = { + ["ping"] = self._handle_server_ping, + ["roots/list"] = self._handler_server_roots_list, + }, + methods = static_methods, + }, self) +end + +---Start the client. +function Client:start() + if self.transport:started() then + return + end + log:info("[MCP.%s] Starting with command: %s", self.name, table.concat(self.cfg.cmd, " ")) + + self.transport:start(function(line) + self:_on_transport_line_read(line) + end, function(err) + self:_on_transport_close(err) + end) + utils.fire("MCPServerStart", { name = self.name }) + + self:_start_initialization() +end + +---Stop the client +---@return nil +function Client:stop() + if not self.transport:started() then + return + end + + log:info("[MCP.%s] stopping server", self.name) + self.transport:stop() +end + +---Start the MCP initialization procedure +---@return nil +function Client:_start_initialization() + assert(self.transport:started(), "MCP Server process is not running.") + assert(not self.ready, "MCP Server is already initialized.") + + local capabilities = vim.empty_dict() + if self.cfg.roots then + capabilities.roots = { listChanged = self.cfg.register_roots_list_changed ~= nil } + end + + self:request("initialize", { + protocolVersion = "2025-11-25", + clientInfo = { + name = "CodeCompanion.nvim", + version = "NO VERSION", --MCP Spec explicitly requires a version + }, + capabilities = capabilities, + }, function(resp) + if resp.error then + log:error("[MCP.%s] initialization failed: %s", self.name, resp) + self:stop() + return + end + log:info("[MCP.%s] initialized successfully.", self.name) + log:info("[MCP.%s] protocol version: %s", self.name, resp.result.protocolVersion) + log:info("[MCP.%s] info: %s", self.name, resp.result.serverInfo) + log:info("[MCP.%s] capabilities: %s", self.name, resp.result.capabilities) + self:notify("notifications/initialized") + self.server_capabilities = resp.result.capabilities + self.server_instructions = resp.result.instructions + self.ready = true + if self.cfg.register_roots_list_changed then + self.cfg.register_roots_list_changed(function() + self:notify_roots_list_changed() + end) + end + utils.fire("MCPServerReady", { name = self.name }) + self:refresh_tools() + end) +end + +---Handle transport close events. +---@param err string|nil +function Client:_on_transport_close(err) + self.ready = false + if not err then + log:info("[MCP.%s] exited.", self.name) + else + log:warn("[MCP.%s] exited with error: %s", self.name, err) + end + for id, handler in pairs(self.resp_handlers) do + -- Notify all pending requests of the transport closure + pcall(handler, { + jsonrpc = "2.0", + id = id, + error = { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = "MCP server connection closed" }, + }) + end + utils.fire("MCPServerExit", { name = self.name, err = err }) +end + +---Process a single JSON-RPC line received from the MCP server. +---@param line string +function Client:_on_transport_line_read(line) + if not line or line == "" then + return + end + local ok, msg = pcall(self.methods.json_decode, line, { luanil = { object = true } }) + if not ok then + log:error("[MCP.%s] failed to decode received line [%s]: %s", self.name, msg, line) + return + end + if type(msg) ~= "table" or msg.jsonrpc ~= "2.0" then + log:error("[MCP.%s] received invalid MCP message: %s", self.name, line) + return + end + if msg.id == nil then + log:info("[MCP.%s] received notification: %s", self.name, line) + return + end + + if msg.method then + self:_handle_server_request(msg) + else + local handler = self.resp_handlers[msg.id] + if handler then + self.resp_handlers[msg.id] = nil + local handle_ok, handle_result = pcall(handler, msg) + if handle_ok then + log:debug("[MCP.%s] response handler succeeded for request %s", self.name, msg.id) + else + log:error("[MCP.%s] response handler failed for request %s: %s", self.name, msg.id, handle_result) + end + else + log:warn("[MCP.%s] received response with unknown id %s: %s", self.name, msg.id, line) + end + end +end + +---Handle an incoming JSON-RPC request from the MCP server. +---@param msg MCP.JSONRPCRequest +function Client:_handle_server_request(msg) + assert(self.transport:started(), "MCP Server process is not running.") + local resp = { + jsonrpc = "2.0", + id = msg.id, + } + local handler = self.server_request_handlers[msg.method] + if not handler then + log:warn("[MCP.%s] received request %s with unknown method %s", self.name, msg.id, msg.method) + resp.error = { code = CONSTANTS.JSONRPC.ERROR_METHOD_NOT_FOUND, message = "Method not found" } + else + local ok, status, body = pcall(handler, self, msg.params) + if not ok then + log:error("[MCP.%s] handler for method %s failed for request %s: %s", self.name, msg.method, msg.id, status) + resp.error = { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = status } + elseif status == "error" then + log:error("[MCP.%s] handler for method %s returned error for request %s: %s", self.name, msg.method, msg.id, body) + resp.error = body + elseif status == "result" then + log:debug("[MCP.%s] handler for method %s returned result for request %s", self.name, msg.method, msg.id) + resp.result = body + else + log:error( + "[MCP.%s] handler for method %s returned invalid status %s for request %s", + self.name, + msg.method, + status, + msg.id + ) + resp.error = { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = "Internal server error" } + end + end + local resp_str = self.methods.json_encode(resp) + self.transport:write({ resp_str }) +end + +---Get the server instructions, applying any overrides from the config +---@return string? +function Client:get_server_instructions() + assert(self.ready, "MCP Server is not ready.") + local override = self.cfg.server_instructions + if type(override) == "function" then + return override(self.server_instructions) + elseif type(override) == "string" then + return override + else + return self.server_instructions + end +end + +---Send a JSON-RPC notification to the MCP server. +---@param method string +---@param params? table +function Client:notify(method, params) + assert(self.transport:started(), "MCP Server process is not running.") + if params and vim.tbl_isempty(params) then + params = vim.empty_dict() + end + local notif = { + jsonrpc = "2.0", + method = method, + params = params, + } + local notif_str = self.methods.json_encode(notif) + log:debug("[MCP.%s] sending notification: %s", self.name, notif_str) + self.transport:write({ notif_str }) +end + +---Send a JSON-RPC request to the MCP server. +---@param method string +---@param params? table +---@param resp_handler ResponseHandler +---@param opts? table { timeout_ms? number } +---@return number req_id +function Client:request(method, params, resp_handler, opts) + assert(self.transport:started(), "MCP Server process is not running.") + local req_id = next_msg_id() + if params and vim.tbl_isempty(params) then + params = vim.empty_dict() + end + local req = { + jsonrpc = "2.0", + id = req_id, + method = method, + params = params, + } + if resp_handler then + self.resp_handlers[req_id] = resp_handler + end + local req_str = self.methods.json_encode(req) + log:debug("[MCP.%s] sending request %s: %s", self.name, req_id, req_str) + self.transport:write({ req_str }) + + local timeout_ms = opts and opts.timeout_ms + if timeout_ms then + self.methods.defer_fn(function() + if self.resp_handlers[req_id] then + self.resp_handlers[req_id] = nil + self:cancel_request(req_id, "timeout") + local timeout_msg = string.format("Request timeout after %dms", timeout_ms) + if resp_handler then + local ok, _ = pcall(resp_handler, { + jsonrpc = "2.0", + id = req_id, + error = { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = timeout_msg }, + }) + if not ok then + log:error("[MCP.%s] response handler failed to handle timeout for request %s", self.name, req_id) + end + end + end + end, timeout_ms) + end + + return req_id +end + +---Handler for 'ping' server requests. +---@param params any +---@return "result", table +function Client:_handle_server_ping(params) + return "result", {} +end + +---Handler for 'roots/list' server requests. +---@param params any +---@return "result" | "error", table +function Client:_handler_server_roots_list(params) + if not self.cfg.roots then + return "error", { code = CONSTANTS.JSONRPC.ERROR_METHOD_NOT_FOUND, message = "roots capability not enabled" } + end + + local ok, roots = pcall(self.cfg.roots) + if not ok then + log:error("[MCP.%s] roots function failed: %s", self.name, roots) + return "error", { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = "roots function failed" } + end + + if not roots or type(roots) ~= "table" then + log:error("[MCP.%s] roots function returned invalid result: %s", self.name, roots) + return "error", { code = CONSTANTS.JSONRPC.ERROR_INTERNAL, message = "roots function returned invalid result" } + end + + return "result", { roots = roots } +end + +---Send a notification that the roots list changed. +---@return nil +function Client:notify_roots_list_changed() + self:notify("notifications/roots/list_changed") +end + +---Cancel a pending request to the MCP server and notify the server of cancellation. +---@param req_id number The ID of the request to cancel +---@param reason? string The reason for cancellation +---@return nil +function Client:cancel_request(req_id, reason) + log:info("[MCP.%s] cancelling request %s: %s", self.name, req_id, reason or "") + self.resp_handlers[req_id] = nil + self:notify("notifications/cancelled", { + requestId = req_id, + reason = reason, + }) +end + +---Call a tool on the MCP server +---@param name string The name of the tool to call +---@param args? table The arguments to pass to the tool +---@param callback fun(ok: boolean, result_or_error: MCP.CallToolResult | string) Callback function that receives (ok, result_or_error) +---@param opts? table { timeout_ms? number } +---@return number req_id +function Client:call_tool(name, args, callback, opts) + assert(self.ready, "MCP Server is not ready.") + + return self:request("tools/call", { + name = name, + arguments = args, + }, function(resp) + if resp.error then + log:error("[MCP.%s] call_tool request failed for [%s]: %s", self.name, name, resp) + callback(false, string.format("MCP JSONRPC error: [%s] %s", resp.error.code, resp.error.message)) + return + end + + callback(true, resp.result) + end, opts) +end + +---Refresh the list of tools available from the MCP server. +---@return nil +function Client:refresh_tools() + assert(self.ready, "MCP Server is not ready.") + if not self.server_capabilities.tools then + log:warn("[MCP.%s] does not support tools", self.name) + return + end + + local all_tools = {} ---@type MCP.Tool[] + local function load_tools(cursor) + self:request("tools/list", { cursor = cursor }, function(resp) + if resp.error then + log:error("[MCP.%s] tools/list request failed: %s", self.name, resp) + return + end + + local tools = resp.result and resp.result.tools or {} + for _, tool in ipairs(tools) do + log:info("[MCP.%s] provides tool `%s`: %s", self.name, tool.name, tool.title or "") + table.insert(all_tools, tool) + end + + -- pagination handling + local next_cursor = resp.result and resp.result.nextCursor + if next_cursor and #all_tools >= CONSTANTS.MAX_TOOLS_PER_SERVER then + log:warn("[MCP.%s] returned too many tools (%d), stopping further loading", self.name, #all_tools) + elseif next_cursor then + log:info("[MCP.%s] loading more tools with cursor: %s", self.name, next_cursor) + return load_tools(next_cursor) + end + + local installed_tools = tool_bridge.setup_tools(self, all_tools) + utils.fire("ChatMCPToolsLoaded", { server = self.name, tools = installed_tools }) + end) + end + + load_tools() +end + +return Client diff --git a/lua/codecompanion/interactions/chat/mcp/init.lua b/lua/codecompanion/interactions/chat/mcp/init.lua new file mode 100644 index 000000000..ee903462c --- /dev/null +++ b/lua/codecompanion/interactions/chat/mcp/init.lua @@ -0,0 +1,43 @@ +local Client = require("codecompanion.interactions.chat.mcp.client") + +local M = {} + +---@class CodeCompanion.MCP.ToolOverride +---@field opts? table +---@field enabled nil | boolean | fun(): boolean +---@field system_prompt? string +---@field output? table +---@field timeout_ms? number + +---@class CodeCompanion.MCP.ServerConfig +---@field cmd string[] +---@field env? table +---@field server_instructions nil | string | fun(orig_server_instructions: string): string +---@field default_tool_opts? table +---@field tool_overrides? table +---@field roots? fun(): { name?: string, uri: string }[] +---@field register_roots_list_changed? fun(notify: fun()) + +---@class CodeCompanion.MCPConfig +---@field servers? table + +---@type table +local clients = {} + +---Start all configured MCP servers if not already started +---@return nil +function M.start_servers() + local mcp_cfg = require("codecompanion.config").interactions.chat.mcp + for name, cfg in pairs(mcp_cfg.servers or {}) do + if not clients[name] then + local client = Client:new(name, cfg) + clients[name] = client + end + end + + for _, client in pairs(clients) do + client:start() + end +end + +return M diff --git a/lua/codecompanion/interactions/chat/mcp/tool_bridge.lua b/lua/codecompanion/interactions/chat/mcp/tool_bridge.lua new file mode 100644 index 000000000..0f08988ac --- /dev/null +++ b/lua/codecompanion/interactions/chat/mcp/tool_bridge.lua @@ -0,0 +1,194 @@ +local log = require("codecompanion.utils.log") + +local CONSTANTS = { + TOOL_PREFIX = "mcp:", + MESSAGES = { + TOOL_ACCESS = "I'm giving you access to tools from an MCP server", + TOOL_GROUPS = "Tools from MCP Server `%s`", + }, +} + +local fmt = string.format + +local M = {} + +---Format the output content from an MCP tool +---@param content string | MCP.ContentBlock[] +---@return string +function M.format_tool_result_content(content) + if type(content) == "table" then + if #content == 1 and content[1].type == "text" then + return content[1].text + end + return vim.inspect(content) + end + return content or "" +end + +---Default tool output callbacks that may be overridden by user config +---@class CodeCompanion.Tool.MCPToolBridge: CodeCompanion.Tools.Tool +local default_output = { + ---@param self CodeCompanion.Tool.MCPToolBridge + ---@param tools CodeCompanion.Tools + ---@param cmd table The command that was executed + ---@param stdout table The output from the command + success = function(self, tools, cmd, stdout) + local chat = tools.chat + local output = M.format_tool_result_content(stdout and stdout[#stdout]) + local args = vim.inspect(self.args) + local for_user = fmt( + [[MCP Tool [%s] executed successfully. +Arguments: +%s +Output: +%s]], + self.name, + args, + output + ) + chat:add_tool_output(self, output, for_user) + end, + + ---@param self CodeCompanion.Tool.MCPToolBridge + ---@param tools CodeCompanion.Tools + ---@param cmd table + ---@param stderr table The error output from the command + error = function(self, tools, cmd, stderr) + local chat = tools.chat + local err_msg = M.format_tool_result_content(stderr and stderr[#stderr] or "") + local for_user = fmt( + [[MCP Tool `%s` execution failed. +Arguments: +%s +Error Message: +%s]], + self.name, + vim.inspect(self.args), + err_msg + ) + chat:add_tool_output(self, "MCP Tool execution failed:\n" .. err_msg, for_user) + end, + + ---The message which is shared with the user when asking for their approval + ---@param self CodeCompanion.Tool.MCPToolBridge + ---@param tools CodeCompanion.Tools + ---@return nil|string + prompt = function(self, tools) + return fmt("Execute the `%s` MCP tool?", self.name) + end, +} + +---Build a CodeCompanion tool from an MCP tool specification +---@param client CodeCompanion.MCP.Client +---@param mcp_tool MCP.Tool +---@return string? tool_name +---@return table? tool_config +function M.build(client, mcp_tool) + if mcp_tool.execution and mcp_tool.execution.taskSupport == "required" then + return log:warn( + "[MCP.%s] tool `%s` requires task execution support, which is not supported", + client.name, + mcp_tool.name + ) + end + + local prefixed_name = fmt("mcp_%s_%s", client.name, mcp_tool.name) + local override = (client.cfg.tool_overrides and client.cfg.tool_overrides[mcp_tool.name]) or {} + local tool_opts = vim.tbl_deep_extend("force", client.cfg.default_tool_opts or {}, override.opts or {}) + local output_callback = vim.tbl_deep_extend("force", default_output, override.output or {}) + + local tool = { + name = prefixed_name, + opts = tool_opts, + schema = { + type = "function", + ["function"] = { + name = prefixed_name, + description = mcp_tool.description, + parameters = mcp_tool.inputSchema, + strict = true, + }, + }, + system_prompt = override.system_prompt, + cmds = { + ---Execute the MCP tool + ---@param self CodeCompanion.Tool.MCPToolBridge + ---@param args table The arguments from the LLM's tool call + ---@param input? any The output from the previous function call + ---@param output_handler function Async callback for completion + ---@return nil|table + function(self, args, input, output_handler) + client:call_tool(mcp_tool.name, args, function(ok, result_or_error) + local output + if not ok then -- RPC failure + output = { status = "error", data = result_or_error } + else + local result = result_or_error + if result.isError then -- Tool execution error + output = { status = "error", data = result.content } + else + output = { status = "success", data = result.content } + end + end + output_handler(output) + end, { timeout_ms = override.timeout_ms }) + end, + }, + output = output_callback, + } + + local tool_cfg = { + description = mcp_tool.title or mcp_tool.name, + callback = tool, + enabled = override.enabled, + -- User should use the generated tool group instead of individual tools + visible = false, + -- `_mcp_info` marks the tool as originating from an MCP server + opts = { _mcp_info = { server = client.name } }, + } + + return prefixed_name, tool_cfg +end + +---Setup tools from an MCP server into CodeCompanion +---@param client CodeCompanion.MCP.Client +---@param mcp_tools MCP.Tool[] +---@return string[] tools +function M.setup_tools(client, mcp_tools) + local chat_tools = require("codecompanion.config").interactions.chat.tools + local tools = {} ---@type string[] + + for _, tool in ipairs(mcp_tools) do + local name, tool_cfg = M.build(client, tool) + if name and tool_cfg then + chat_tools[name] = tool_cfg + table.insert(tools, name) + end + end + + if #tools == 0 then + log:warn("[MCP.%s] has no valid tools to configure", client.name) + return {} + end + + local server_prompt = { + fmt("%s `%s`: %s.", CONSTANTS.MESSAGES.TOOL_ACCESS, client.name, table.concat(tools, ", ")), + } + + -- The prompt should also contain instructions from the server, if any. + local server_instructions = client:get_server_instructions() + if server_instructions and server_instructions ~= "" then + table.insert(server_prompt, "Detailed instructions for this MCP server:") + table.insert(server_prompt, server_instructions) + end + + chat_tools.groups[fmt("%s%s", CONSTANTS.TOOL_PREFIX, client.name)] = { + description = string.format("Tools from MCP Server '%s'", client.name), + tools = tools, + prompt = table.concat(server_prompt, "\n"), + opts = { collapse_tools = true }, + } + return tools +end + +return M diff --git a/lua/codecompanion/types.lua b/lua/codecompanion/types.lua index dadb5e277..5476710d8 100644 --- a/lua/codecompanion/types.lua +++ b/lua/codecompanion/types.lua @@ -23,6 +23,41 @@ ---@alias ACP.availableCommands ACP.AvailableCommand[] +---@meta Model Context Protocol + +---@class MCP.JSONRPCRequest +---@field jsonrpc "2.0" +---@field id integer | string +---@field method string +---@field params table? + +---@class MCP.JSONRPCResultResponse +---@field jsonrpc "2.0" +---@field id integer | string +---@field result table? + +---@class MCP.JSONRPCErrorResponse +---@field jsonrpc "2.0" +---@field id integer | string +---@field error { code: integer, message: string, data: any? } + +---@class MCP.Tool +---@field name string +---@field inputSchema table +---@field description? string +---@field title? string +---@field execution? table + +---@class MCP.TextContent +---@field type "text" +---@field text string + +---@alias MCP.ContentBlock MCP.TextContent|any + +---@class MCP.CallToolResult +---@field isError? boolean +---@field content MCP.ContentBlock[] + ---@meta Tree-sitter ---@class vim.treesitter.LanguageTree diff --git a/tests/config.lua b/tests/config.lua index 4e29cb269..008f7b0a0 100644 --- a/tests/config.lua +++ b/tests/config.lua @@ -253,7 +253,6 @@ return { "cmd", }, }, - ["tool_group"] = { description = "Tool Group", system_prompt = "My tool group system prompt", @@ -362,6 +361,7 @@ return { }, }, }, + mcp = {}, opts = { blank_prompt = "", wait_timeout = 3000, diff --git a/tests/interactions/chat/mcp/test_mcp_client.lua b/tests/interactions/chat/mcp/test_mcp_client.lua new file mode 100644 index 000000000..c4a4ec59f --- /dev/null +++ b/tests/interactions/chat/mcp/test_mcp_client.lua @@ -0,0 +1,411 @@ +local h = require("tests.helpers") + +local child = MiniTest.new_child_neovim() + +local T = MiniTest.new_set({ + hooks = { + pre_case = function() + h.child_start(child) + child.lua([[ + Client = require("codecompanion.interactions.chat.mcp.client") + MockMCPClientTransport = require("tests.mocks.mcp_client_transport") + TRANSPORT = MockMCPClientTransport:new() + function mock_new_transport() + return TRANSPORT + end + + function read_mcp_tools() + return vim + .iter(vim.fn.readfile("tests/stubs/mcp/tools.jsonl")) + :map(function(s) + return s ~= "" and vim.json.decode(s) or nil + end) + :totable() + end + + function setup_default_initialization() + TRANSPORT:expect_jsonrpc_call("initialize", function(params) + return "result", { + protocolVersion = params.protocolVersion, + capabilities = { tools = {} }, + serverInfo = { name = "Test MCP Server", version = "1.0.0" }, + } + end) + TRANSPORT:expect_jsonrpc_notify("notifications/initialized", function() end) + end + + function setup_tool_list(tools) + TRANSPORT:expect_jsonrpc_call("tools/list", function() + return "result", { tools = tools or read_mcp_tools() } + end) + end + + function start_client_and_wait_loaded() + local tools_loaded + vim.api.nvim_create_autocmd("User", { + pattern = "CodeCompanionMCPToolsLoaded", + once = true, + callback = function() tools_loaded = true end, + }) + + CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport }) + CLI:start() + vim.wait(1000, function() return tools_loaded end) + end + ]]) + end, + post_case = function() + h.is_true(child.lua_get("TRANSPORT:all_handlers_consumed()")) + end, + post_once = child.stop, + }, +}) + +T["MCP Client"] = MiniTest.new_set() +T["MCP Client"]["start() starts and initializes the client once"] = function() + child.lua([[ + READY = false + INIT_PARAMS = {} + + vim.api.nvim_create_autocmd("User", { + pattern = "CodeCompanionMCPServerReady", + once = true, + callback = function() READY = true end, + }) + + TRANSPORT:expect_jsonrpc_call("initialize", function(params) + table.insert(INIT_PARAMS, params) + return "result", { + protocolVersion = params.protocolVersion, + capabilities = { tools = {} }, + serverInfo = { name = "Test MCP Server", version = "1.0.0" }, + } + end) + TRANSPORT:expect_jsonrpc_notify("notifications/initialized", function() end) + + setup_tool_list() + CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport }) + CLI:start() + CLI:start() -- repeated call should be no-op + CLI:start() + vim.wait(1000, function() return READY end) + CLI:start() -- repeated call should be no-op + CLI:start() + ]]) + + h.is_true(child.lua_get("READY")) + h.eq(child.lua_get("INIT_PARAMS[1]"), { + protocolVersion = "2025-11-25", + clientInfo = { + name = "CodeCompanion.nvim", + version = "NO VERSION", + }, + capabilities = {}, + }) + h.is_true(child.lua_get("CLI.ready")) +end + +T["MCP Client"]["tools are loaded in pages"] = function() + local result = child.lua([[ + setup_default_initialization() + + local mcp_tools = read_mcp_tools() + local page_size = 2 + TRANSPORT:expect_jsonrpc_call("tools/list", function(params) + local start_idx = tonumber(params.cursor) or 1 + local end_idx = math.min(start_idx + page_size - 1, #mcp_tools) + local page_tools = {} + for i = start_idx, end_idx do + table.insert(page_tools, mcp_tools[i]) + end + local next_cursor = end_idx < #mcp_tools and tostring(end_idx + 1) or nil + return "result", { tools = page_tools, nextCursor = next_cursor } + end, { repeats = math.ceil(#mcp_tools / page_size) }) + + start_client_and_wait_loaded() + + local chat_tools = require("codecompanion.config").interactions.chat.tools + local group = chat_tools.groups["mcp:testMcp"] + local tools = vim + .iter(chat_tools) + :filter(function(_, v) + return vim.tbl_get(v, "opts", "_mcp_info", "server") == "testMcp" + end) + :fold({}, function(acc, k, v) + v = vim.deepcopy(v) + -- functions cannot cross process boundary + v.callback.cmds = nil + v.callback.output = nil + acc[k] = v + return acc + end) + return { + mcp_tools = read_mcp_tools(), + group = group, + tools = tools, + } + ]]) + + local mcp_tools = result.mcp_tools + local tools = result.tools + local group = result.group + + h.eq(vim.tbl_count(tools), #mcp_tools) + h.eq(#group.tools, #mcp_tools) + for _, mcp_tool in ipairs(mcp_tools) do + local cc_tool_name = "mcp_testMcp_" .. mcp_tool.name + h.expect_tbl_contains(cc_tool_name, group.tools) + local cc_tool = tools[cc_tool_name] + h.expect_truthy(cc_tool) + h.eq(mcp_tool.title or mcp_tool.name, cc_tool.description) + h.is_false(cc_tool.visible) + h.eq({ + type = "function", + ["function"] = { + name = cc_tool_name, + description = mcp_tool.description, + parameters = mcp_tool.inputSchema, + strict = true, + }, + }, cc_tool.callback.schema) + end + + h.expect_contains("testMcp", group.prompt) +end + +T["MCP Client"]["can process tool calls"] = function() + local result = child.lua([[ + setup_default_initialization() + setup_tool_list() + start_client_and_wait_loaded() + + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + if params.name == "echo" then + local value = params.arguments.value + if value == nil then + return "result", { isError = true, content = { { type = "text", text = "No value" } } } + end + return "result", { content = { { type = "text", text = params.arguments.value } } } + else + return "error", { code = -32601, message = "Tool not found" } + end + end, { repeats = 3 }) + + local call_results = {} + local function append_call_result(ok, result_or_error) + table.insert(call_results, { ok, result_or_error }) + end + CLI:call_tool("echo", { value = "xxxyyyzzz" }, append_call_result) + CLI:call_tool("echo", {}, append_call_result) + CLI:call_tool("nonexistent_tool", {}, append_call_result) + vim.wait(1000, function() return #call_results == 3 end) + return call_results + ]]) + + h.eq({ + { true, { content = { { type = "text", text = "xxxyyyzzz" } } } }, + { true, { isError = true, content = { { type = "text", text = "No value" } } } }, + { false, "MCP JSONRPC error: [-32601] Tool not found" }, + }, result) +end + +T["MCP Client"]["can handle reordered tool call responses"] = function() + local result = child.lua([[ + setup_default_initialization() + setup_tool_list() + start_client_and_wait_loaded() + + local latencies = { 300, 50, 150, 400 } + for _, latency in ipairs(latencies) do + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + return "result", { content = { { type = "text", text = params.arguments.value } } } + end, { latency_ms = latency }) + end + + local call_results = {} + local function append_call_result(ok, result_or_error) + table.insert(call_results, { ok, result_or_error }) + end + for i, latency in ipairs(latencies) do + CLI:call_tool("echo", { value = string.format("%d_%d", i, latency) }, append_call_result) + end + vim.wait(1000, function() return #call_results == #latencies end) + return call_results + ]]) + + h.eq({ + { true, { content = { { type = "text", text = "2_50" } } } }, + { true, { content = { { type = "text", text = "3_150" } } } }, + { true, { content = { { type = "text", text = "1_300" } } } }, + { true, { content = { { type = "text", text = "4_400" } } } }, + }, result) +end + +T["MCP Client"]["respects timeout option for tool calls"] = function() + local result = child.lua([[ + setup_default_initialization() + setup_tool_list() + start_client_and_wait_loaded() + + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + return "result", { content = { { type = "text", text = "fast response" } } } + end) + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + return "result", { content = { { type = "text", text = "slow response" } } } + end, { latency_ms = 200 }) + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + return "result", { content = { { type = "text", text = "very slow response" } } } + end, { latency_ms = 200 }) + + local call_results = {} + local function append_call_result(ok, result_or_error) + table.insert(call_results, { ok, result_or_error }) + end + + CLI:call_tool("echo", { value = "no_timeout" }, append_call_result) + CLI:call_tool("echo", { value = "short_timeout" }, append_call_result, { timeout_ms = 100 }) + CLI:call_tool("echo", { value = "long_timeout" }, append_call_result, { timeout_ms = 1000 }) + + vim.wait(2000, function() return #call_results == 3 end) + return call_results + ]]) + + h.is_true(result[1][1]) + h.eq(result[1][2].content[1].text, "fast response") + + h.is_false(result[2][1]) + h.expect_contains("timeout", result[2][2]) + + h.is_true(result[3][1]) + h.eq(result[3][2].content[1].text, "very slow response") +end + +T["MCP Client"]["roots capability is declared when roots config is provided"] = function() + local result = child.lua([[ + setup_default_initialization() + setup_tool_list() + + local roots = { + { uri = "file:///home/user/project1", name = "Project 1" }, + { uri = "file:///home/user/project2", name = "Project 2" }, + } + + CLI = Client:new("testMcp", { + cmd = { "test-mcp" }, + roots = function() return roots end, + }, { new_transport = mock_new_transport }) + CLI:start() + vim.wait(1000, function() return CLI.ready end) + + local received_resp + TRANSPORT:send_request_to_client("roots/list", nil, function(status, result) + received_resp = { status, result } + end) + + vim.wait(1000, function() return received_resp ~= nil end) + return { roots = roots, received_resp = received_resp } + ]]) + + h.eq(result.received_resp[1], "result") + h.eq(result.received_resp[2], { roots = result.roots }) +end + +T["MCP Client"]["roots list changed notification is sent when roots change"] = function() + local result = child.lua([[ + setup_default_initialization() + setup_tool_list() + + local root_lists = { + {}, + { + { uri = "file:///home/user/projectA", name = "Project A" }, + }, + { + { uri = "file:///home/user/projectA", name = "Project A" }, + { uri = "file:///home/user/projectB", name = "Project B" }, + }, + { + { uri = "file:///home/user/projectC", name = "Project C" }, + }, + } + local current_roots + + local notify_roots_list_changed + CLI = Client:new("testMcp", { + cmd = { "test-mcp" }, + roots = function() return current_roots end, + register_roots_list_changed = function(notify) + notify_roots_list_changed = notify + end, + }, { new_transport = mock_new_transport }) + CLI:start() + vim.wait(1000, function() return CLI.ready end) + + local received_resps = {} + for i = 1, #root_lists do + if current_roots ~= nil then + TRANSPORT:expect_jsonrpc_notify("roots/listChanged", function() end) + notify_roots_list_changed() + vim.wait(1000, function() return TRANSPORT:all_handlers_consumed() end) + end + current_roots = root_lists[i] + TRANSPORT:send_request_to_client("roots/list", nil, function(status, result) + received_resps[i] = { status, result } + end) + vim.wait(1000, function() return received_resps[i] ~= nil end) + end + + return { received_resps = received_resps, root_lists = root_lists } + ]]) + + for i, roots in ipairs(result.root_lists) do + h.eq(result.received_resps[i][1], "result") + h.eq(result.received_resps[i][2], { roots = roots }) + end +end + +T["MCP Client"]["transport closed automatically on initialization failure"] = function() + child.lua([[ + TRANSPORT:expect_jsonrpc_call("initialize", function(params) + return "error", { code = -32603, message = "Initialization failed" } + end) + + CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport }) + CLI:start() + vim.wait(1000, function() return TRANSPORT:all_handlers_consumed() end) + vim.wait(1000, function() return not CLI.ready end) + ]]) + + h.is_false(child.lua_get("CLI.ready")) + h.is_false(child.lua_get("TRANSPORT:started()")) +end + +T["MCP Client"]["stop() cleans up pending requests"] = function() + local call_result = child.lua([[ + setup_default_initialization() + setup_tool_list() + start_client_and_wait_loaded() + + -- initiate a SLOW tool call that won't respond before stop() + TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + return "result", { content = { { type = "text", text = "slow response" } } } + end, { latency_ms = 1000 }) + + local call_result + CLI:call_tool("echo", { value = "will be cancelled" }, function(ok, result_or_error) + call_result = { ok, result_or_error } + end) + + vim.wait(50, function() return call_result ~= nil end) + CLI:stop() + vim.wait(1000, function() return call_result ~= nil end) + return call_result + ]]) + + h.is_false(child.lua_get("CLI.ready")) + h.is_false(child.lua_get("TRANSPORT:started()")) + h.is_false(call_result[1]) + h.expect_contains("close", call_result[2]) +end + +return T diff --git a/tests/interactions/chat/mcp/test_mcp_tools.lua b/tests/interactions/chat/mcp/test_mcp_tools.lua new file mode 100644 index 000000000..e0c3638c0 --- /dev/null +++ b/tests/interactions/chat/mcp/test_mcp_tools.lua @@ -0,0 +1,357 @@ +local h = require("tests.helpers") + +local child = MiniTest.new_child_neovim() + +local T = MiniTest.new_set({ + hooks = { + pre_case = function() + h.child_start(child) + child.lua([[ + local h = require("tests.helpers") + Client = require("codecompanion.interactions.chat.mcp.client") + MockMCPClientTransport = require("tests.mocks.mcp_client_transport") + + MCP_TOOLS = vim + .iter(vim.fn.readfile("tests/stubs/mcp/tools.jsonl")) + :map(function(s) + if s ~= "" then + return vim.json.decode(s) + end + end) + :totable() + + MATH_MCP_TRANSPORT = MockMCPClientTransport:new() + MATH_MCP_TOOLS = vim.iter(MCP_TOOLS):filter(function(tool) + return vim.startswith(tool.name, "math_") + end):totable() + + OTHER_MCP_TRANSPORT = MockMCPClientTransport:new() + OTHER_MCP_TOOLS = vim.iter(MCP_TOOLS):filter(function(tool) + return not vim.startswith(tool.name, "math_") + end):totable() + + Client.static.methods.new_transport.default = function(name, cfg) + local transport + local tools + if cfg.cmd[1] == "math_mcp" then + transport = MATH_MCP_TRANSPORT + tools = MATH_MCP_TOOLS + else + transport = OTHER_MCP_TRANSPORT + tools = OTHER_MCP_TOOLS + end + + transport:expect_jsonrpc_call("initialize", function(params) + return "result", { + protocolVersion = params.protocolVersion, + capabilities = { tools = {} }, + serverInfo = { name = "Test MCP Server", version = "1.0.0" }, + instructions = "Test MCP server instructions.", + } + end) + transport:expect_jsonrpc_notify("notifications/initialized", function(params) end) + transport:expect_jsonrpc_call("tools/list", function() + return "result", { tools = tools } + end) + return transport + end + + local adapter = { + name = "test_adapter_for_mcp_tools", + roles = { llm = "assistant", user = "user" }, + features = {}, + opts = { tools = true }, + url = "http://0.0.0.0", + schema = { model = { default = "dummy" } }, + handlers = { + response = { + parse_chat = function(self, data, tools) + for _, tool in ipairs(data.tools or {}) do + table.insert(tools, tool) + end + return { + status = "success", + output = { role = "assistant", content = data.content } + } + end + }, + tools = { + format_calls = function(self, llm_tool_calls) + return llm_tool_calls + end, + format_response = function(self, llm_tool_call, mcp_output) + return { role = "tool", content = mcp_output } + end, + } + }, + } + + function create_chat(mcp_cfg) + mcp_cfg = mcp_cfg or { + servers = { + math_mcp = { cmd = { "math_mcp" } }, + other_mcp = { cmd = { "other_mcp" } }, + }, + } + local loading = vim.tbl_count(mcp_cfg.servers) + vim.api.nvim_create_autocmd("User", { + pattern = "CodeCompanionMCPToolsLoaded", + callback = function() + loading = loading - 1 + return loading == 0 + end, + }) + local chat = h.setup_chat_buffer({ + interactions = { chat = { mcp = mcp_cfg } }, + adapters = { + http = { [adapter.name] = adapter }, + }, + }, { name = adapter.name }) + vim.wait(1000, function() return loading == 0 end) + return chat + end + ]]) + end, + post_case = function() + h.is_true(child.lua_get("MATH_MCP_TRANSPORT:all_handlers_consumed()")) + h.is_true(child.lua_get("OTHER_MCP_TRANSPORT:all_handlers_consumed()")) + end, + post_once = child.stop, + }, +}) + +T["MCP Tools"] = MiniTest.new_set() + +T["MCP Tools"]["MCP tools can be used as CodeCompanion tools"] = function() + h.mock_http(child) + h.queue_mock_http_response(child, { + content = "Call some tools", + tools = { + { ["function"] = { name = "mcp_math_mcp_math_add", arguments = { a = 1, b = 3 } } }, + { ["function"] = { name = "mcp_math_mcp_math_mul", arguments = { a = 4, b = 2 } } }, + { ["function"] = { name = "mcp_math_mcp_math_add", arguments = { a = 2, b = -3 } } }, + }, + }) + local chat_msgs = child.lua([[ + local chat = create_chat() + MATH_MCP_TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + local retval + if params.name == "math_add" then + retval = params.arguments.a + params.arguments.b + elseif params.name == "math_mul" then + retval = params.arguments.a * params.arguments.b + else + return "error", { code = -32601, message = "Unknown tool: " .. params.name } + end + return "result", { + content = { { type = "text", text = tostring(retval) } } + } + end, { repeats = 3 }) + + chat:add_buf_message({ + role = "user", + content = "@{mcp:math_mcp} Use some tools.", + }) + chat:submit() + vim.wait(1000, function() return vim.bo[chat.bufnr].modifiable end) + return chat.messages + ]]) + + local tool_output_msgs = vim + .iter(chat_msgs) + :map(function(msg) + if msg.role == "tool" then + return msg.content + end + end) + :totable() + h.eq({ "4", "8", "-1" }, tool_output_msgs) + + local llm_req = child.lua_get("_G.mock_client:get_last_request().payload") + local has_prompt = vim.iter(llm_req.messages):any(function(msg) + return msg.content:find("math_mcp") + and msg.content:find("mcp_math_mcp_math_add") + and msg.content:find("mcp_math_mcp_math_mul") + and msg.content:find("Test MCP server instructions.") + end) + h.is_true(has_prompt) + + local math_mcp_tools = child.lua_get("MATH_MCP_TOOLS") + local llm_tool_schemas = llm_req.tools[1] + h.eq(#math_mcp_tools, vim.tbl_count(llm_tool_schemas)) + for _, mcp_tool in ipairs(math_mcp_tools) do + local cc_tool_name = "mcp_math_mcp_" .. mcp_tool.name + local llm_tool_schema = llm_tool_schemas[string.format("%s", cc_tool_name)] + h.eq(llm_tool_schema.type, "function") + h.eq(llm_tool_schema["function"].name, cc_tool_name) + h.eq(llm_tool_schema["function"].description, mcp_tool.description) + h.eq(llm_tool_schema["function"].parameters, mcp_tool.inputSchema) + end +end + +T["MCP Tools"]["MCP tools should handle errors correctly"] = function() + h.mock_http(child) + h.queue_mock_http_response(child, { + content = "Should fail", + tools = { + { ["function"] = { name = "mcp_other_mcp_make_list", arguments = { count = -1, item = "y" } } }, + }, + }) + + local chat_msgs = child.lua([[ + local chat = create_chat() + OTHER_MCP_TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + if params.name == "echo" then + return "error", { code = -32603, message = "test jsonrpc error" } + elseif params.name == "make_list" then + if params.arguments.count < 0 then + return "result", { + isError = true, + content = { { type = "text", text = "count must be non-negative" } }, + } + end + local list = {} + for i = 1, params.arguments.count do + table.insert(list, { type = "text", text = params.arguments.item }) + end + return "result", { content = list } + end + end) + + chat:add_buf_message({ role = "user", content = "@{mcp.other_mcp} Should have errors" }) + chat:submit() + vim.wait(1000, function() return vim.bo[chat.bufnr].modifiable end) + return chat.messages + ]]) + + local tool_output_msgs = vim + .iter(chat_msgs) + :map(function(msg) + if msg.role == "tool" then + return msg.content + end + end) + :totable() + h.eq({ "MCP Tool execution failed:\ncount must be non-negative" }, tool_output_msgs) +end + +T["MCP Tools"]["allows overriding tool options and behavior"] = function() + h.mock_http(child) + h.queue_mock_http_response(child, { + content = "Call some tools", + tools = { + { ["function"] = { name = "mcp_other_mcp_say_hi" } }, + { ["function"] = { name = "mcp_other_mcp_make_list", arguments = { count = 3, item = "xyz" } } }, + { ["function"] = { name = "mcp_other_mcp_echo", arguments = { value = "ECHO REQ" } } }, + }, + }) + + local result = child.lua([[ + require("tests.log") + local chat = create_chat({ + servers = { + other_mcp = { + cmd = { "other_mcp" }, + server_instructions = function(orig) + return orig .. "\nAdditional instructions for other_mcp." + end, + default_tool_opts = { + require_approval_before = true, + }, + tool_overrides = { + echo = { + timeout_ms = 100, + output = { + prompt = function(self, tools) + return "Custom confirmation prompt for echo tool: " .. self.args.value + end, + }, + }, + say_hi = { + opts = { + require_approval_before = false, + }, + system_prompt = "TEST SYSTEM PROMPT FOR SAY_HI", + }, + make_list = { + output = { + success = function(self, tools, cmd, stdout) + local output = vim.iter(stdout[#stdout]):map(function(block) + assert(block.type == "text") + return block.text + end):join(",") + tools.chat:add_tool_output(self, output) + end + }, + }, + } + }, + } + }) + + OTHER_MCP_TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + assert(params.name == "say_hi") + return "result", { content = { { type = "text", text = "Hello there!" } } } + end) + OTHER_MCP_TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + assert(params.name == "make_list") + local content = {} + for i = 1, params.arguments.count do + table.insert(content, { type = "text", text = params.arguments.item }) + end + return "result", { content = content } + end) + OTHER_MCP_TRANSPORT:expect_jsonrpc_call("tools/call", function(params) + assert(params.name == "echo") + return "result", { content = { { type = "text", text = params.arguments.value } } } + end, { latency_ms = 10 * 1000 }) + + chat:add_buf_message({ role = "user", content = "@{mcp:other_mcp}" }) + + local confirmations = {} + local ui = require("codecompanion.utils.ui") + ui.confirm = function(prompt, choices) + table.insert(confirmations, prompt) + for i, choice in ipairs(choices) do + if choice:find("Allow once") then + return i + end + end + assert(false, "No 'Allow once' choice found") + end + + chat:submit() + vim.wait(1000, function() return vim.bo[chat.bufnr].modifiable end) + return { chat_msgs = chat.messages, confirmations = confirmations } + ]]) + + local has_server_instructions = vim.iter(result.chat_msgs):any(function(msg) + return msg.content:find("Test MCP server instructions.\nAdditional instructions for other_mcp.") + end) + h.is_true(has_server_instructions) + + local has_custom_tool_prompt = vim.iter(result.chat_msgs):any(function(msg) + return msg.content:find("TEST SYSTEM PROMPT FOR SAY_HI") + end) + h.is_true(has_custom_tool_prompt) + + local tool_output_msgs = vim + .iter(result.chat_msgs) + :map(function(msg) + if msg.role == "tool" then + return msg.content + end + end) + :totable() + h.eq(tool_output_msgs, { + "Hello there!", + "xyz,xyz,xyz", + "MCP Tool execution failed:\nMCP JSONRPC error: [-32603] Request timeout after 100ms", + }) + + h.eq(#result.confirmations, 2) + h.expect_contains("make_list", result.confirmations[1]) + h.eq(result.confirmations[2], "Custom confirmation prompt for echo tool: ECHO REQ") +end + +return T diff --git a/tests/mocks/mcp_client_transport.lua b/tests/mocks/mcp_client_transport.lua new file mode 100644 index 000000000..8dafdc043 --- /dev/null +++ b/tests/mocks/mcp_client_transport.lua @@ -0,0 +1,169 @@ +local log = require("codecompanion.utils.log") + +---A mock implementation of `Transport` +---@class CodeCompanion.MCP.MockMCPClientTransport : CodeCompanion.MCP.Transport +---@field private _started boolean +---@field private _on_line_read? fun(line: string) +---@field private _on_close? fun() +---@field private _line_handlers (fun(line: string): boolean)[] +local MockMCPClientTransport = {} +MockMCPClientTransport.__index = MockMCPClientTransport + +function MockMCPClientTransport:new() + return setmetatable({ + _started = false, + _line_handlers = {}, + }, self) +end + +function MockMCPClientTransport:start(on_line_read, on_close) + assert(not self._started, "Transport already started") + self._on_line_read = on_line_read + self._on_close = on_close + self._started = true +end + +function MockMCPClientTransport:started() + return self._started +end + +function MockMCPClientTransport:write(lines) + assert(self._started, "Transport not started") + if lines == nil then + self:stop() + return + end + vim.schedule(function() + for _, line in ipairs(lines) do + log:info("MockMCPClientTransport received line: %s", line) + assert(#self._line_handlers > 0, "No pending line handlers") + local handler = self._line_handlers[1] + local keep = handler(line) + if not keep then + table.remove(self._line_handlers, 1) + end + end + end) +end + +function MockMCPClientTransport:write_line_to_client(line, latency_ms) + assert(self._started, "Transport not started") + vim.defer_fn(function() + log:info("MockMCPClientTransport sending line to client: %s", line) + self._on_line_read(line) + end, latency_ms or 0) +end + +---@param handler fun(line: string): boolean handle a client written line; return true to preserve this handler for next line +---@return CodeCompanion.MCP.MockMCPClientTransport self +function MockMCPClientTransport:expect_client_write_line(handler) + table.insert(self._line_handlers, handler) + return self +end + +---@param method string +---@param handler fun(params?: table): "result"|"error", table +---@param opts? { repeats?: integer, latency_ms?: integer } +---@return CodeCompanion.MCP.MockMCPClientTransport self +function MockMCPClientTransport:expect_jsonrpc_call(method, handler, opts) + local remaining_repeats = opts and opts.repeats or 1 + return self:expect_client_write_line(function(line) + local function get_response() + local resp = { jsonrpc = "2.0" } + local ok, req = pcall(vim.json.decode, line, { luanil = { object = true } }) + if not ok then + resp.error = { code = -32700, message = string.format("Parse error: %s", req) } + return resp + end + resp.id = req.id + if req.jsonrpc ~= "2.0" then + resp.error = { code = -32600, message = string.format("Invalid JSON-RPC version: %s", req.jsonrpc) } + return resp + end + if req.method ~= method then + resp.error = { code = -32601, message = string.format("Expected method '%s', got '%s'", method, req.method) } + return resp + end + local status, result = handler(req.params) + if status == "result" then + resp.result = result + elseif status == "error" then + resp.error = result + else + error("Handler must return 'result' or 'error'") + end + return resp + end + + self:write_line_to_client(vim.json.encode(get_response()), opts and opts.latency_ms) + remaining_repeats = remaining_repeats - 1 + return remaining_repeats > 0 + end) +end + +---@param method string +---@param handler? fun(params?: table) +---@param opts? { repeats: integer } +---@return CodeCompanion.MCP.MockMCPClientTransport self +function MockMCPClientTransport:expect_jsonrpc_notify(method, handler, opts) + local remaining_repeats = opts and opts.repeats or 1 + return self:expect_client_write_line(function(line) + local ok, req = pcall(vim.json.decode, line, { luanil = { object = true } }) + if not ok then + log:error("Failed to parse JSON-RPC notification: %s", line) + elseif req.jsonrpc ~= "2.0" then + log:error("Invalid JSON-RPC version: %s", req.jsonrpc) + elseif req.method ~= method then + log:error("Unexpected JSON-RPC method. Expected: %s, Got: %s", method, req.method) + elseif handler then + handler(req.params) + end + remaining_repeats = remaining_repeats - 1 + return remaining_repeats > 0 + end) +end + +---@param method string +---@param params? table +---@param resp_handler fun(status: "result"|"error", result_or_error: table) +function MockMCPClientTransport:send_request_to_client(method, params, resp_handler) + assert(self:all_handlers_consumed(), "Cannot send request to client: pending line handlers exist") + local req_id = math.random(1, 1e9) + local req = { jsonrpc = "2.0", id = req_id, method = method, params = params } + self:expect_client_write_line(function(line) + local ok, resp = pcall(vim.json.decode, line, { luanil = { object = true } }) + if not ok then + log:error("Failed to parse JSON-RPC response: %s", line) + elseif resp.id ~= req_id then + log:error("Mismatched JSON-RPC response ID. Expected: %d, Got: %s", req_id, tostring(resp.id)) + elseif resp.result then + resp_handler("result", resp.result) + elseif resp.error then + resp_handler("error", resp.error) + else + log:error("Invalid JSON-RPC response: %s", line) + end + return false + end) + self:write_line_to_client(vim.json.encode(req)) +end + +function MockMCPClientTransport:all_handlers_consumed() + return #self._line_handlers == 0 +end + +function MockMCPClientTransport:stop() + vim.schedule(function() + self._started = false + self._on_close() + end) +end + +function MockMCPClientTransport:expect_transport_stop() + return self:expect_client_write_line(function(line) + assert(line == nil, "Expected transport to be stopped") + return false + end) +end + +return MockMCPClientTransport diff --git a/tests/stubs/mcp/tools.jsonl b/tests/stubs/mcp/tools.jsonl new file mode 100644 index 000000000..3b2133382 --- /dev/null +++ b/tests/stubs/mcp/tools.jsonl @@ -0,0 +1,5 @@ +{"name":"echo","description":"Echoes back the input","inputSchema":{"type":"object","properties":{"value":{"type":"string","description":"A string value to echo back"}},"required":["value"]}} +{"name":"say_hi","description":"Say Hi to you","inputSchema":{"type":"object","properties":{"name":{"type":"string","description":"Who are you?"}},"required":[]}} +{"name":"make_list","description":"Creates a list of items","inputSchema":{"type":"object","properties":{"count":{"type":"number","description":"Number of items"},"item":{"oneOf":[{"type":"string"},{"type":"number"}],"description":"The item to repeat in the list"}},"required":["count","item"]}} +{"name":"math_add","title":"Math/Add","description":"Adds two numbers","inputSchema":{"type":"object","properties":{"a":{"type":"number","description":"The first number to add"},"b":{"type":"number","description":"The second number to add"}},"required":["a","b"]}} +{"name":"math_mul","title":"Math/Mul","description":"Multiply two numbers","inputSchema":{"type":"object","properties":{"a":{"type":"number","description":"The first number to multiply"},"b":{"type":"number","description":"The second number to multiply"}},"required":["a","b"]}}