diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpHttpRequestHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpHttpRequestHandler.java index 176ca13051..f7e73dce8b 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpHttpRequestHandler.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpHttpRequestHandler.java @@ -4,26 +4,19 @@ import com.taobao.arthas.mcp.server.protocol.config.McpServerProperties.ServerProtocol; import com.taobao.arthas.mcp.server.protocol.server.McpTransportContextExtractor; import com.taobao.arthas.mcp.server.protocol.spec.McpError; -import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; -import com.taobao.arthas.mcp.server.tool.ToolCallback; import com.taobao.arthas.mcp.server.util.Assert; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.*; import io.netty.util.CharsetUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import static com.taobao.arthas.mcp.server.protocol.spec.McpSchema.*; - /** * MCP HTTP请求处理器,分发请求到无状态或流式处理器。 * @@ -146,7 +139,7 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes()); response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*"); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } catch (Exception e) { logger.error("Failed to send error response: {}", e.getMessage()); FullHttpResponse response = new DefaultFullHttpResponse( @@ -154,7 +147,7 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp HttpResponseStatus.INTERNAL_SERVER_ERROR ); response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java index c48b05018f..4c6b3e1772 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java @@ -352,7 +352,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques response.headers().set(HttpHeaders.MCP_SESSION_ID, init.session().getId()); response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*"); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } catch (Exception e) { logger.error("Failed to serialize init response: {}", e.getMessage()); sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, @@ -402,7 +402,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques HttpResponseStatus.ACCEPTED ); response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); }) .exceptionally(e -> { logger.error("Failed to accept response: {}", e.getMessage()); @@ -418,7 +418,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques HttpResponseStatus.ACCEPTED ); response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); }) .exceptionally(e -> { logger.error("Failed to accept notification: {}", e.getMessage()); @@ -442,14 +442,15 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques try { session.responseStream(jsonrpcRequest, sessionTransport, transportContext) - .exceptionally(e -> { - logger.error("Failed to handle request stream: {}", e.getMessage()); - ctx.close(); - return null; + .whenComplete((result, e) -> { + if (e != null) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + sessionTransport.close(); + } }); } catch (Exception e) { logger.error("Failed to handle request stream: {}", e.getMessage()); - ctx.close(); + sessionTransport.close(); } } else { sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, @@ -532,7 +533,7 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes()); response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*"); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } catch (Exception e) { logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage()); FullHttpResponse response = new DefaultFullHttpResponse( @@ -540,7 +541,7 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp HttpResponseStatus.INTERNAL_SERVER_ERROR ); response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0); - ctx.writeAndFlush(response); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java index 12c6a68d19..e40ded4dc9 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java @@ -167,7 +167,8 @@ public CompletableFuture responseStream(McpSchema.JSONRPCRequest jsonrpcRe logger.warn("Failed to store error response event: {}", e.getMessage()); } - return transport.sendMessage(errorResponse, null); + return transport.sendMessage(errorResponse, null) + .thenCompose(v -> transport.closeGracefully()); } ArthasCommandContext commandContext = createCommandContext(transportContext.get(MCP_AUTH_SUBJECT_KEY)); diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java index b2b7c288a3..924d074649 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java @@ -15,6 +15,10 @@ public class ArthasCommandSessionManager { private static final Logger logger = LoggerFactory.getLogger(ArthasCommandSessionManager.class); + // Arthas 默认 session 超时时间是 30 分钟,这里设置一个稍短的时间作为预判断 + // 如果距离上次访问超过这个时间,认为 session 可能已过期,主动重建 + private static final long SESSION_EXPIRY_THRESHOLD_MS = 25 * 60 * 1000; // 25 分钟 + private final CommandExecutor commandExecutor; private final ConcurrentHashMap sessionBindings = new ConcurrentHashMap<>(); @@ -26,11 +30,15 @@ public static class CommandSessionBinding { private final String mcpSessionId; private final String arthasSessionId; private final String consumerId; + private final long createdTime; + private volatile long lastAccessTime; public CommandSessionBinding(String mcpSessionId, String arthasSessionId, String consumerId) { this.mcpSessionId = mcpSessionId; this.arthasSessionId = arthasSessionId; this.consumerId = consumerId; + this.createdTime = System.currentTimeMillis(); + this.lastAccessTime = this.createdTime; } public String getMcpSessionId() { @@ -44,6 +52,18 @@ public String getArthasSessionId() { public String getConsumerId() { return consumerId; } + + public long getCreatedTime() { + return createdTime; + } + + public long getLastAccessTime() { + return lastAccessTime; + } + + public void updateAccessTime() { + this.lastAccessTime = System.currentTimeMillis(); + } } public CommandSessionBinding createCommandSession(String mcpSessionId) { @@ -90,6 +110,8 @@ public CommandSessionBinding getCommandSession(String mcpSessionId, Object authS logger.debug("Using existing valid session: MCP={}, Arthas={}", mcpSessionId, binding.getArthasSessionId()); } + binding.updateAccessTime(); + if (authSubject != null) { try { commandExecutor.setSessionAuth(binding.getArthasSessionId(), authSubject); @@ -109,32 +131,15 @@ public CommandSessionBinding getCommandSession(String mcpSessionId, Object authS * 通过尝试获取结果来验证session和consumer是否仍然存在 */ private boolean isSessionValid(CommandSessionBinding binding) { - try { - Map result = commandExecutor.pullResults(binding.getArthasSessionId(), binding.getConsumerId()); - - if (result == null) { - return true; - } - - Boolean success = (Boolean) result.get("success"); - if (Boolean.TRUE.equals(success)) { - return true; - } - - String errorMessage = (String) result.get("error"); - if (errorMessage != null && (errorMessage.contains("Session not found") || - errorMessage.contains("Consumer not found") || - errorMessage.contains("session is inactive"))) { - logger.debug("Session validation failed: {}", errorMessage); - return false; - } - - return true; - - } catch (Exception e) { - logger.debug("Session validation error: {}", e.getMessage()); + long timeSinceLastAccess = System.currentTimeMillis() - binding.getLastAccessTime(); + + if (timeSinceLastAccess > SESSION_EXPIRY_THRESHOLD_MS) { + logger.debug("Session possibly expired (inactive for {} ms): MCP={}, Arthas={}", + timeSinceLastAccess, binding.getMcpSessionId(), binding.getArthasSessionId()); return false; } + + return true; } public void closeCommandSession(String mcpSessionId) {