Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,19 @@
import com.taobao.arthas.mcp.server.protocol.config.McpServerProperties.ServerProtocol;
import com.taobao.arthas.mcp.server.protocol.server.McpTransportContextExtractor;
import com.taobao.arthas.mcp.server.protocol.spec.McpError;
import com.taobao.arthas.mcp.server.protocol.spec.McpSchema;
import com.taobao.arthas.mcp.server.tool.ToolCallback;
import com.taobao.arthas.mcp.server.util.Assert;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.*;
import io.netty.util.CharsetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.taobao.arthas.mcp.server.protocol.spec.McpSchema.*;

/**
* MCP HTTP请求处理器,分发请求到无状态或流式处理器。
*
Expand Down Expand Up @@ -146,15 +139,15 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");

ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} catch (Exception e) {
logger.error("Failed to send error response: {}", e.getMessage());
FullHttpResponse response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.INTERNAL_SERVER_ERROR
);
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques
response.headers().set(HttpHeaders.MCP_SESSION_ID, init.session().getId());
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");

ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} catch (Exception e) {
logger.error("Failed to serialize init response: {}", e.getMessage());
sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -402,7 +402,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques
HttpResponseStatus.ACCEPTED
);
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
})
.exceptionally(e -> {
logger.error("Failed to accept response: {}", e.getMessage());
Expand All @@ -418,7 +418,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques
HttpResponseStatus.ACCEPTED
);
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
})
.exceptionally(e -> {
logger.error("Failed to accept notification: {}", e.getMessage());
Expand All @@ -442,14 +442,15 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques

try {
session.responseStream(jsonrpcRequest, sessionTransport, transportContext)
.exceptionally(e -> {
logger.error("Failed to handle request stream: {}", e.getMessage());
ctx.close();
return null;
.whenComplete((result, e) -> {
if (e != null) {
logger.error("Failed to handle request stream: {}", e.getMessage());
sessionTransport.close();
}
});
} catch (Exception e) {
logger.error("Failed to handle request stream: {}", e.getMessage());
ctx.close();
sessionTransport.close();
}
} else {
sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -532,15 +533,15 @@ private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status, Mcp
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");

ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} catch (Exception e) {
logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e.getMessage());
FullHttpResponse response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.INTERNAL_SERVER_ERROR
);
response.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0);
ctx.writeAndFlush(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ public CompletableFuture<Void> responseStream(McpSchema.JSONRPCRequest jsonrpcRe
logger.warn("Failed to store error response event: {}", e.getMessage());
}

return transport.sendMessage(errorResponse, null);
return transport.sendMessage(errorResponse, null)
.thenCompose(v -> transport.closeGracefully());
}
ArthasCommandContext commandContext = createCommandContext(transportContext.get(MCP_AUTH_SUBJECT_KEY));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
public class ArthasCommandSessionManager {
private static final Logger logger = LoggerFactory.getLogger(ArthasCommandSessionManager.class);

// Arthas 默认 session 超时时间是 30 分钟,这里设置一个稍短的时间作为预判断
// 如果距离上次访问超过这个时间,认为 session 可能已过期,主动重建
private static final long SESSION_EXPIRY_THRESHOLD_MS = 25 * 60 * 1000; // 25 分钟

private final CommandExecutor commandExecutor;
private final ConcurrentHashMap<String, CommandSessionBinding> sessionBindings = new ConcurrentHashMap<>();

Expand All @@ -26,11 +30,15 @@ public static class CommandSessionBinding {
private final String mcpSessionId;
private final String arthasSessionId;
private final String consumerId;
private final long createdTime;
private volatile long lastAccessTime;

public CommandSessionBinding(String mcpSessionId, String arthasSessionId, String consumerId) {
this.mcpSessionId = mcpSessionId;
this.arthasSessionId = arthasSessionId;
this.consumerId = consumerId;
this.createdTime = System.currentTimeMillis();
this.lastAccessTime = this.createdTime;
}

public String getMcpSessionId() {
Expand All @@ -44,6 +52,18 @@ public String getArthasSessionId() {
public String getConsumerId() {
return consumerId;
}

public long getCreatedTime() {
return createdTime;
}

public long getLastAccessTime() {
return lastAccessTime;
}

public void updateAccessTime() {
this.lastAccessTime = System.currentTimeMillis();
}
}

public CommandSessionBinding createCommandSession(String mcpSessionId) {
Expand Down Expand Up @@ -90,6 +110,8 @@ public CommandSessionBinding getCommandSession(String mcpSessionId, Object authS
logger.debug("Using existing valid session: MCP={}, Arthas={}", mcpSessionId, binding.getArthasSessionId());
}

binding.updateAccessTime();

if (authSubject != null) {
try {
commandExecutor.setSessionAuth(binding.getArthasSessionId(), authSubject);
Expand All @@ -109,32 +131,15 @@ public CommandSessionBinding getCommandSession(String mcpSessionId, Object authS
* 通过尝试获取结果来验证session和consumer是否仍然存在
*/
private boolean isSessionValid(CommandSessionBinding binding) {
try {
Map<String, Object> result = commandExecutor.pullResults(binding.getArthasSessionId(), binding.getConsumerId());

if (result == null) {
return true;
}

Boolean success = (Boolean) result.get("success");
if (Boolean.TRUE.equals(success)) {
return true;
}

String errorMessage = (String) result.get("error");
if (errorMessage != null && (errorMessage.contains("Session not found") ||
errorMessage.contains("Consumer not found") ||
errorMessage.contains("session is inactive"))) {
logger.debug("Session validation failed: {}", errorMessage);
return false;
}

return true;

} catch (Exception e) {
logger.debug("Session validation error: {}", e.getMessage());
long timeSinceLastAccess = System.currentTimeMillis() - binding.getLastAccessTime();

if (timeSinceLastAccess > SESSION_EXPIRY_THRESHOLD_MS) {
logger.debug("Session possibly expired (inactive for {} ms): MCP={}, Arthas={}",
timeSinceLastAccess, binding.getMcpSessionId(), binding.getArthasSessionId());
return false;
}

return true;
}

public void closeCommandSession(String mcpSessionId) {
Expand Down
Loading