Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.taobao.arthas.mcp.server.CommandExecutor;
import com.taobao.arthas.mcp.server.protocol.spec.*;
import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager;
import com.taobao.arthas.mcp.server.task.*;
import com.taobao.arthas.mcp.server.util.Assert;
import com.taobao.arthas.mcp.server.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.*;
import java.util.function.BiFunction;

/**
* A Netty-based MCP server implementation that provides access to tools, resources, and prompts.
* A Netty-based MCP server implementation that provides access to tools,
* resources, and prompts.
*
* @author Yeaury
*/
Expand All @@ -42,40 +42,68 @@ public class McpNettyServer {

private final CopyOnWriteArrayList<McpServerFeatures.ToolSpecification> tools = new CopyOnWriteArrayList<>();

private final ConcurrentHashMap<String, McpServerFeatures.ToolSpecification> toolsByName = new ConcurrentHashMap<>();

private final CopyOnWriteArrayList<McpSchema.ResourceTemplate> resourceTemplates = new CopyOnWriteArrayList<>();

private final ConcurrentHashMap<String, McpServerFeatures.ResourceSpecification> resources = new ConcurrentHashMap<>();

private final ConcurrentHashMap<String, McpServerFeatures.PromptSpecification> prompts = new ConcurrentHashMap<>();

private final ServerTaskToolHandler serverTaskToolHandler;

private final ArthasCommandSessionManager sessionManager;

private McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.DEBUG;

private List<String> protocolVersions;

McpNettyServer(McpStreamableServerTransportProvider mcpTransportProvider,
ObjectMapper objectMapper, Duration requestTimeout,
McpServerFeatures.McpServerConfig features,
CommandExecutor commandExecutor) {
ObjectMapper objectMapper, Duration requestTimeout,
McpServerFeatures.McpServerConfig features,
CommandExecutor commandExecutor,
ArthasCommandSessionManager sessionManager) {
this.mcpTransportProvider = mcpTransportProvider;
this.objectMapper = objectMapper;
this.serverInfo = features.getServerInfo();
this.serverCapabilities = features.getServerCapabilities();
this.instructions = features.getInstructions();
this.tools.addAll(features.getTools());

for (McpServerFeatures.ToolSpecification tool : features.getTools()) {
this.toolsByName.put(tool.getTool().getName(), tool);
}

this.resources.putAll(features.getResources());
this.resourceTemplates.addAll(features.getResourceTemplates());
this.prompts.putAll(features.getPrompts());

this.sessionManager = sessionManager;

this.serverTaskToolHandler = new ServerTaskToolHandler(
features.getTaskTools(),
features.getTaskOptions(),
objectMapper,
this::notifyAllClients,
Duration.ofSeconds(30),
sessionManager);

Map<String, McpRequestHandler<?>> requestHandlers = prepareRequestHandlers();
Map<String, McpNotificationHandler> notificationHandlers = prepareNotificationHandlers(features);

this.protocolVersions = mcpTransportProvider.protocolVersions();

TaskStore<McpSchema.ServerTaskPayloadResult> taskStore = this.serverTaskToolHandler
.getTaskStore();
TaskMessageQueue taskMessageQueue = this.serverTaskToolHandler.getTaskMessageQueue();

mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout,
this::initializeRequestHandler, requestHandlers, notificationHandlers, commandExecutor));
this::initializeRequestHandler, requestHandlers, notificationHandlers, commandExecutor,
taskStore, taskMessageQueue));
}

private Map<String, McpNotificationHandler> prepareNotificationHandlers(McpServerFeatures.McpServerConfig features) {
private Map<String, McpNotificationHandler> prepareNotificationHandlers(
McpServerFeatures.McpServerConfig features) {
Map<String, McpNotificationHandler> notificationHandlers = new HashMap<>();

notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED,
Expand All @@ -86,9 +114,9 @@ private Map<String, McpNotificationHandler> prepareNotificationHandlers(McpServe

if (Utils.isEmpty(rootsChangeConsumers)) {
rootsChangeConsumers = Collections.singletonList(
(exchange, roots) -> CompletableFuture.runAsync(() ->
logger.warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))
);
(exchange, roots) -> CompletableFuture.runAsync(() -> logger.warn(
"Roots list changed notification, but no consumers provided. Roots list changed: {}",
roots)));
}

notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
Expand Down Expand Up @@ -129,10 +157,13 @@ private Map<String, McpRequestHandler<?>> prepareRequestHandlers() {
requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler());
}

// Add tasks API handlers via ServerTaskToolHandler
this.serverTaskToolHandler.logCapabilityMismatches(this.serverCapabilities.getTasks());
requestHandlers.putAll(this.serverTaskToolHandler.getRequestHandlers(this.serverCapabilities.getTasks()));

return requestHandlers;
}


// ---------------------------------------
// Lifecycle Management
// ---------------------------------------
Expand All @@ -150,8 +181,7 @@ private CompletableFuture<McpSchema.InitializeResult> initializeRequestHandler(

if (protocolVersions.contains(initializeRequest.getProtocolVersion())) {
serverProtocolVersion = initializeRequest.getProtocolVersion();
}
else {
} else {
logger.warn(
"Client requested unsupported protocol version: {}, " + "so the server will suggest {} instead",
initializeRequest.getProtocolVersion(), serverProtocolVersion);
Expand All @@ -170,10 +200,19 @@ public McpSchema.Implementation getServerInfo() {
}

public CompletableFuture<Void> closeGracefully() {
return this.mcpTransportProvider.closeGracefully();
if (this.serverTaskToolHandler != null) {
return this.serverTaskToolHandler.closeGracefully()
.thenCompose(v -> mcpTransportProvider.closeGracefully());
}
return mcpTransportProvider.closeGracefully();
}

public void close() {
try {
closeGracefully().get(5, TimeUnit.SECONDS);
} catch (Exception e) {
logger.warn("Error during graceful close", e);
}
this.mcpTransportProvider.close();
}

Expand Down Expand Up @@ -225,7 +264,8 @@ public CompletableFuture<Void> addTool(McpServerFeatures.ToolSpecification toolS

return CompletableFuture.supplyAsync(() -> {
// Check for duplicate tool names
if (this.tools.stream().anyMatch(th -> th.getTool().getName().equals(toolSpecification.getTool().getName()))) {
if (this.tools.stream()
.anyMatch(th -> th.getTool().getName().equals(toolSpecification.getTool().getName()))) {
throw new CompletionException(
new McpError("Tool with name '" + toolSpecification.getTool().getName() + "' already exists"));
}
Expand Down Expand Up @@ -280,33 +320,94 @@ public CompletableFuture<Void> notifyToolsListChanged() {
return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null);
}


private CompletableFuture<Void> notifyAllClients(String method, Object notification) {
return this.mcpTransportProvider.notifyClients(method, notification);
}

public CompletableFuture<Void> addTaskTool(TaskAwareToolSpecification taskToolSpecification) {
if (this.serverCapabilities.getTools() == null) {
CompletableFuture<Void> future = new CompletableFuture<>();
future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND)
.message("Server must be configured with tool capabilities")
.build());
return future;
}
return this.serverTaskToolHandler.addTaskTool(taskToolSpecification, this.serverCapabilities.getTools());
}

/**
* Remove a task-aware tool at runtime.
* @param toolName The name of the task-aware tool to remove
* @return Mono that completes when clients have been notified of the change
*/
public CompletableFuture<Void> removeTaskTool(String toolName) {
if (this.serverCapabilities.getTools() == null) {
CompletableFuture<Void> future = new CompletableFuture<>();
future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND)
.message("Server must be configured with tool capabilities")
.build());
return future;
}
return this.serverTaskToolHandler.removeTaskTool(toolName, this.serverCapabilities.getTools());
}

private McpRequestHandler<McpSchema.ListToolsResult> toolsListRequestHandler() {
return (exchange, commandContext, params) -> {
List<McpSchema.Tool> tools = new ArrayList<>();
List<McpSchema.Tool> toolList = new ArrayList<>();
for (McpServerFeatures.ToolSpecification toolSpec : this.tools) {
tools.add(toolSpec.getTool());
toolList.add(toolSpec.getTool());
}
toolList.addAll(this.serverTaskToolHandler.getToolDefinitions());

return CompletableFuture.completedFuture(new McpSchema.ListToolsResult(tools, null));
return CompletableFuture.completedFuture(new McpSchema.ListToolsResult(toolList, null));
};
}

private McpRequestHandler<McpSchema.CallToolResult> toolsCallRequestHandler() {
private McpRequestHandler<Object> toolsCallRequestHandler() {
return (exchange, commandContext, params) -> {
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params,
new TypeReference<McpSchema.CallToolRequest>() {
});
Optional<McpServerFeatures.ToolSpecification> toolSpecification = this.tools.stream()
.filter(tr -> callToolRequest.getName().equals(tr.getTool().getName()))
.findAny();

if (!toolSpecification.isPresent()) {
CompletableFuture<McpSchema.CallToolResult> future = new CompletableFuture<>();
future.completeExceptionally(new McpError("no tool found: " + callToolRequest.getName()));
return future;
String toolName = callToolRequest.getName();

McpServerFeatures.ToolSpecification normalTool = this.toolsByName.get(toolName);
if (normalTool != null) {
// Normal tools don't support task enhancement requests
if (callToolRequest.getTask() != null) {
CompletableFuture<Object> future = new CompletableFuture<>();
future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND)
.message("Tool '" + toolName + "' does not support task-augmented requests")
.data("Remove the 'task' parameter or use a task-aware tool")
.build());
return future;
}
return normalTool.getCall().apply(exchange, commandContext, callToolRequest)
.thenApply(result -> (Object) result);
}

// task aware tools are delegated to ServerTaskToolHandler
CompletableFuture<Object> taskToolResult = this.serverTaskToolHandler.handleToolCall(
exchange, commandContext, callToolRequest);
if (taskToolResult != null) {
return taskToolResult.thenApply(result -> {
if (result == null) {
throw McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS)
.message("Unknown tool: " + callToolRequest.getName())
.data("Tool not found: " + callToolRequest.getName())
.build();
}
return result;
});
}

return toolSpecification.get().getCall().apply(exchange, commandContext, callToolRequest);
CompletableFuture<Object> future = new CompletableFuture<>();
future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS)
.message("Unknown tool: " + callToolRequest.getName())
.data("Tool not found: " + callToolRequest.getName())
.build());
return future;
};
}

Expand All @@ -327,7 +428,8 @@ public CompletableFuture<Void> addResource(McpServerFeatures.ResourceSpecificati
}

return CompletableFuture.supplyAsync(() -> {
if (this.resources.putIfAbsent(resourceSpecification.getResource().getUri(), resourceSpecification) != null) {
if (this.resources.putIfAbsent(resourceSpecification.getResource().getUri(),
resourceSpecification) != null) {
throw new CompletionException(new McpError(
"Resource with URI '" + resourceSpecification.getResource().getUri() + "' already exists"));
}
Expand Down Expand Up @@ -393,7 +495,7 @@ private McpRequestHandler<McpSchema.ListResourcesResult> resourcesListRequestHan

private McpRequestHandler<McpSchema.ListResourceTemplatesResult> resourceTemplateListRequestHandler() {
return (exchange, commandContext, params) -> CompletableFuture
.completedFuture(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null));
.completedFuture(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null));
}

private McpRequestHandler<McpSchema.ReadResourceResult> resourcesReadRequestHandler() {
Expand Down Expand Up @@ -430,10 +532,11 @@ public CompletableFuture<Void> addPrompt(McpServerFeatures.PromptSpecification p

return CompletableFuture.supplyAsync(() -> {
McpServerFeatures.PromptSpecification existing = this.prompts
.putIfAbsent(promptSpecification.getPrompt().getName(), promptSpecification);
.putIfAbsent(promptSpecification.getPrompt().getName(), promptSpecification);
if (existing != null) {
throw new CompletionException(
new McpError("Prompt with name '" + promptSpecification.getPrompt().getName() + "' already exists"));
new McpError(
"Prompt with name '" + promptSpecification.getPrompt().getName() + "' already exists"));
}

logger.debug("Added prompt handler: {}", promptSpecification.getPrompt().getName());
Expand Down Expand Up @@ -533,10 +636,10 @@ private McpRequestHandler<Map<String, Object>> setLoggerRequestHandler() {
McpSchema.SetLevelRequest.class);
this.minLoggingLevel = request.getLevel();
return CompletableFuture.completedFuture(Collections.emptyMap());
}
catch (Exception e) {
} catch (Exception e) {
CompletableFuture<Map<String, Object>> future = new CompletableFuture<>();
future.completeExceptionally(new McpError("An error occurred while processing a request to set the log level: " + e.getMessage()));
future.completeExceptionally(new McpError(
"An error occurred while processing a request to set the log level: " + e.getMessage()));
return future;
}
};
Expand Down
Loading
Loading