Skip to content

Commit 920a338

Browse files
committed
feat: implement MCP Tasks support for long-running tool operations
1 parent a8a6a2f commit 920a338

File tree

54 files changed

+5962
-1352
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+5962
-1352
lines changed

arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServer.java

Lines changed: 134 additions & 198 deletions
Large diffs are not rendered by default.

arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServerExchange.java

Lines changed: 191 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,17 @@
1010
import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.LoggingLevel;
1111
import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.LoggingMessageNotification;
1212
import com.taobao.arthas.mcp.server.protocol.spec.McpSession;
13+
import com.taobao.arthas.mcp.server.task.QueuedMessage;
14+
import com.taobao.arthas.mcp.server.task.TaskDefaults;
15+
import com.taobao.arthas.mcp.server.task.TaskMessageQueue;
16+
import com.taobao.arthas.mcp.server.task.TaskStore;
1317
import com.taobao.arthas.mcp.server.util.Assert;
1418
import org.slf4j.Logger;
1519
import org.slf4j.LoggerFactory;
1620

21+
import java.time.Duration;
1722
import java.util.concurrent.CompletableFuture;
23+
import java.util.concurrent.atomic.AtomicLong;
1824

1925
/**
2026
* Represents the interaction between MCP server and client. Provides methods for communication, logging, and context management.
@@ -47,8 +53,16 @@ public class McpNettyServerExchange {
4753

4854
private final McpTransportContext transportContext;
4955

56+
private final TaskMessageQueue taskMessageQueue;
57+
58+
private final TaskStore<McpSchema.ServerTaskPayloadResult> taskStore;
59+
5060
private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO;
5161

62+
private final AtomicLong sideChannelRequestCounter = new AtomicLong(0);
63+
64+
private static final Duration SIDE_CHANNEL_TIMEOUT = Duration.ofMinutes(TaskDefaults.DEFAULT_SIDE_CHANNEL_TIMEOUT_MINUTES);
65+
5266
private static final TypeReference<McpSchema.CreateMessageResult> CREATE_MESSAGE_RESULT_TYPE_REF =
5367
new TypeReference<McpSchema.CreateMessageResult>() {
5468
};
@@ -60,18 +74,48 @@ public class McpNettyServerExchange {
6074
private static final TypeReference<McpSchema.ElicitResult> ELICIT_USER_INPUT_RESULT_TYPE_REF =
6175
new TypeReference<McpSchema.ElicitResult>() {
6276
};
77+
78+
private static final TypeReference<McpSchema.GetTaskResult> GET_TASK_RESULT_TYPE_REF =
79+
new TypeReference<McpSchema.GetTaskResult>() {
80+
};
81+
82+
private static final TypeReference<McpSchema.CreateTaskResult> CREATE_TASK_RESULT_TYPE_REF =
83+
new TypeReference<McpSchema.CreateTaskResult>() {
84+
};
85+
86+
private static final TypeReference<McpSchema.ListTasksResult> LIST_TASKS_RESULT_TYPE_REF =
87+
new TypeReference<McpSchema.ListTasksResult>() {
88+
};
89+
90+
private static final TypeReference<McpSchema.CancelTaskResult> CANCEL_TASK_RESULT_TYPE_REF =
91+
new TypeReference<McpSchema.CancelTaskResult>() {
92+
};
6393

6494
public static final TypeReference<Object> OBJECT_TYPE_REF = new TypeReference<Object>() {
6595
};
6696

6797
public McpNettyServerExchange(String sessionId, McpSession session,
68-
McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
69-
McpTransportContext transportContext) {
98+
McpSchema.ClientCapabilities clientCapabilities,
99+
McpSchema.Implementation clientInfo,
100+
McpTransportContext transportContext,
101+
TaskMessageQueue taskMessageQueue) {
102+
this(sessionId, session, clientCapabilities, clientInfo, transportContext, taskMessageQueue, null);
103+
}
104+
105+
106+
public McpNettyServerExchange(String sessionId, McpSession session,
107+
McpSchema.ClientCapabilities clientCapabilities,
108+
McpSchema.Implementation clientInfo,
109+
McpTransportContext transportContext,
110+
TaskMessageQueue taskMessageQueue,
111+
TaskStore<McpSchema.ServerTaskPayloadResult> taskStore) {
70112
this.sessionId = sessionId;
71113
this.session = session;
72114
this.clientCapabilities = clientCapabilities;
73115
this.clientInfo = clientInfo;
74116
this.transportContext = transportContext;
117+
this.taskMessageQueue = taskMessageQueue;
118+
this.taskStore = taskStore;
75119
}
76120

77121
public CompletableFuture<Void> sendNotification(String method, Object params) {
@@ -109,6 +153,9 @@ public McpTransportContext getTransportContext() {
109153
return this.transportContext;
110154
}
111155

156+
public String sessionId() {
157+
return this.sessionId;
158+
}
112159
/**
113160
* Create a new message using client sampling capability. MCP provides a standardized way for servers to request
114161
* LLM sampling ("completion" or "generation") through the client. This flow allows clients to maintain control
@@ -120,19 +167,33 @@ public McpTransportContext getTransportContext() {
120167
*/
121168
public CompletableFuture<McpSchema.CreateMessageResult> createMessage(
122169
McpSchema.CreateMessageRequest createMessageRequest) {
170+
return createMessage(createMessageRequest, null);
171+
}
172+
173+
public CompletableFuture<McpSchema.CreateMessageResult> createMessage(
174+
McpSchema.CreateMessageRequest createMessageRequest,
175+
String taskId) {
123176
if (this.clientCapabilities == null) {
124177
logger.error("Client not initialized, cannot create message");
125178
CompletableFuture<McpSchema.CreateMessageResult> future = new CompletableFuture<>();
126-
future.completeExceptionally(new McpError("Client must be initialized first. Please call initialize method!"));
179+
future.completeExceptionally(new McpError("Client must be initialized. Call the initialize method first!"));
127180
return future;
128181
}
129182
if (this.clientCapabilities.getSampling() == null) {
130183
logger.error("Client not configured with sampling capability, cannot create message");
131184
CompletableFuture<McpSchema.CreateMessageResult> future = new CompletableFuture<>();
132-
future.completeExceptionally(new McpError("Client must be configured with sampling capability"));
185+
future.completeExceptionally(new McpError("Client must be configured with sampling capabilities"));
133186
return future;
134187
}
135188

189+
// Side-channel flow: enqueue request and wait for response via tasks/result
190+
if (taskId != null && this.taskMessageQueue != null && this.taskStore != null) {
191+
return sideChannelRequest(taskId, McpSchema.METHOD_SAMPLING_CREATE_MESSAGE,
192+
createMessageRequest, McpSchema.CreateMessageResult.class,
193+
"Waiting for sampling response");
194+
}
195+
196+
// No task context: send immediately
136197
logger.debug("Creating client message, session ID: {}", this.sessionId);
137198
return this.session
138199
.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, CREATE_MESSAGE_RESULT_TYPE_REF)
@@ -175,18 +236,28 @@ public CompletableFuture<McpSchema.ListRootsResult> listRoots(String cursor) {
175236
}
176237

177238
public CompletableFuture<Void> loggingNotification(LoggingMessageNotification loggingMessageNotification) {
239+
return loggingNotification(loggingMessageNotification, null);
240+
}
241+
242+
public CompletableFuture<Void> loggingNotification(LoggingMessageNotification loggingMessageNotification, String taskId) {
178243
if (loggingMessageNotification == null) {
179244
CompletableFuture<Void> future = new CompletableFuture<>();
180-
future.completeExceptionally(new McpError("log messages cannot be empty"));
245+
future.completeExceptionally(new McpError("Logging message must not be null"));
181246
return future;
182247
}
183248

184249
if (this.isNotificationForLevelAllowed(loggingMessageNotification.getLevel())) {
250+
// Side-channel flow: enqueue notification for delivery via tasks/result
251+
if (taskId != null && this.taskMessageQueue != null) {
252+
return sideChannelNotification(taskId, McpSchema.METHOD_NOTIFICATION_MESSAGE,
253+
loggingMessageNotification);
254+
}
255+
185256
return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification)
186257
.whenComplete((result, error) -> {
187258
if (error != null) {
188-
logger.error("Failed to send logging notification, level: {}, session ID: {}, error: {}", loggingMessageNotification.getLevel(),
189-
this.sessionId, error.getMessage());
259+
logger.error("Failed to send logging notification, level: {}, session ID: {}, error: {}",
260+
loggingMessageNotification.getLevel(), this.sessionId, error.getMessage());
190261
}
191262
});
192263
}
@@ -197,10 +268,15 @@ public CompletableFuture<Object> ping() {
197268
return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF);
198269
}
199270

200-
public CompletableFuture<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest request) {
201-
if (request == null) {
271+
272+
public CompletableFuture<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest elicitRequest) {
273+
return createElicitation(elicitRequest, null);
274+
}
275+
276+
public CompletableFuture<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest elicitRequest, String taskId) {
277+
if (elicitRequest == null) {
202278
CompletableFuture<McpSchema.ElicitResult> future = new CompletableFuture<>();
203-
future.completeExceptionally(new McpError("elicit request cannot be null"));
279+
future.completeExceptionally(new McpError("Elicit request must not be null"));
204280
return future;
205281
}
206282
if (this.clientCapabilities == null) {
@@ -213,8 +289,17 @@ public CompletableFuture<McpSchema.ElicitResult> createElicitation(McpSchema.Eli
213289
future.completeExceptionally(new McpError("Client must be configured with elicitation capabilities"));
214290
return future;
215291
}
292+
293+
// Side-channel flow: enqueue request and wait for response via tasks/result
294+
if (taskId != null && this.taskMessageQueue != null && this.taskStore != null) {
295+
return sideChannelRequest(taskId, McpSchema.METHOD_ELICITATION_CREATE,
296+
elicitRequest, McpSchema.ElicitResult.class,
297+
"Waiting for user input");
298+
}
299+
300+
// No task context: send immediately
216301
return this.session
217-
.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, request, ELICIT_USER_INPUT_RESULT_TYPE_REF)
302+
.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICIT_USER_INPUT_RESULT_TYPE_REF)
218303
.whenComplete((result, error) -> {
219304
if (error != null) {
220305
logger.error("Failed to elicit user input, session ID: {}, error: {}", this.sessionId, error.getMessage());
@@ -235,12 +320,23 @@ private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) {
235320
}
236321

237322
public CompletableFuture<Void> progressNotification(McpSchema.ProgressNotification progressNotification) {
323+
return progressNotification(progressNotification, null);
324+
}
325+
326+
public CompletableFuture<Void> progressNotification(McpSchema.ProgressNotification progressNotification, String taskId) {
238327
if (progressNotification == null) {
239328
CompletableFuture<Void> future = new CompletableFuture<>();
240-
future.completeExceptionally(new McpError("progress notifications cannot be empty"));
329+
future.completeExceptionally(new McpError("Progress notification must not be null"));
241330
return future;
242331
}
243332

333+
// Side-channel flow: enqueue notification for delivery via tasks/result
334+
if (taskId != null && this.taskMessageQueue != null) {
335+
return sideChannelNotification(taskId, McpSchema.METHOD_NOTIFICATION_PROGRESS,
336+
progressNotification);
337+
}
338+
339+
// Send immediately
244340
return this.session
245341
.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification)
246342
.whenComplete((result, error) -> {
@@ -251,4 +347,87 @@ public CompletableFuture<Void> progressNotification(McpSchema.ProgressNotificati
251347
}
252348
});
253349
}
350+
351+
352+
public CompletableFuture<McpSchema.GetTaskResult> getTask(McpSchema.GetTaskRequest getTaskRequest) {
353+
return this.session.sendRequest(McpSchema.METHOD_TASKS_GET, getTaskRequest, GET_TASK_RESULT_TYPE_REF);
354+
}
355+
356+
public CompletableFuture<McpSchema.GetTaskResult> getTask(String taskId) {
357+
return this.getTask(new McpSchema.GetTaskRequest(taskId, null));
358+
}
359+
360+
public <T> CompletableFuture<T> getTaskResult(
361+
McpSchema.GetTaskPayloadRequest getTaskPayloadRequest,
362+
TypeReference<T> resultTypeRef) {
363+
return this.session.sendRequest(McpSchema.METHOD_TASKS_RESULT, getTaskPayloadRequest, resultTypeRef);
364+
}
365+
366+
public <T> CompletableFuture<T> getTaskResult(
367+
String taskId,
368+
TypeReference<T> resultTypeRef) {
369+
return this.getTaskResult(new McpSchema.GetTaskPayloadRequest(taskId, null), resultTypeRef);
370+
}
371+
372+
public CompletableFuture<McpSchema.ListTasksResult> listTasks() {
373+
return this.listTasks(null);
374+
}
375+
376+
public CompletableFuture<McpSchema.ListTasksResult> listTasks(String cursor) {
377+
return this.session.sendRequest(McpSchema.METHOD_TASKS_LIST,
378+
new McpSchema.PaginatedRequest(cursor),
379+
LIST_TASKS_RESULT_TYPE_REF);
380+
}
381+
382+
public CompletableFuture<McpSchema.CancelTaskResult> cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) {
383+
return this.session.sendRequest(McpSchema.METHOD_TASKS_CANCEL, cancelTaskRequest, CANCEL_TASK_RESULT_TYPE_REF);
384+
}
385+
386+
public CompletableFuture<McpSchema.CancelTaskResult> cancelTask(String taskId) {
387+
Assert.notNull(taskId, "Task ID must not be null");
388+
if (taskId.trim().isEmpty()) {
389+
CompletableFuture<McpSchema.CancelTaskResult> future = new CompletableFuture<>();
390+
future.completeExceptionally(new IllegalArgumentException("Task ID must not be empty"));
391+
return future;
392+
}
393+
return cancelTask(new McpSchema.CancelTaskRequest(taskId, null));
394+
}
395+
396+
// === Side-Channel Helpers ===
397+
398+
@SuppressWarnings("unchecked")
399+
private <T extends McpSchema.Result> CompletableFuture<T> sideChannelRequest(
400+
String taskId, String method, McpSchema.Request request,
401+
Class<T> resultType, String inputMessage) {
402+
403+
String requestId = "sc-" + this.sessionId + "-" + this.sideChannelRequestCounter.getAndIncrement();
404+
405+
logger.debug("Side-channel request: taskId={}, method={}, requestId={}", taskId, method, requestId);
406+
407+
// 1. Enqueue the request for the side-channel handler to pick up
408+
QueuedMessage.Request queuedRequest = new QueuedMessage.Request(requestId, method, request);
409+
410+
return this.taskMessageQueue.enqueue(taskId, queuedRequest)
411+
.thenCompose(v -> {
412+
// 2. Set task to INPUT_REQUIRED so client polls tasks/result
413+
return this.taskStore.updateTaskStatus(taskId, this.sessionId,
414+
McpSchema.TaskStatus.INPUT_REQUIRED, inputMessage);
415+
})
416+
.thenCompose(v -> {
417+
// 3. Wait for the response to arrive via the queue
418+
return this.taskMessageQueue.waitForResponse(taskId, requestId, SIDE_CHANNEL_TIMEOUT);
419+
})
420+
.thenCompose(response -> {
421+
// 4. Restore task to WORKING status
422+
return this.taskStore.updateTaskStatus(taskId, this.sessionId,
423+
McpSchema.TaskStatus.WORKING, null)
424+
.thenApply(v -> (T) response.result());
425+
});
426+
}
427+
428+
private CompletableFuture<Void> sideChannelNotification(String taskId, String method, Object notification) {
429+
logger.debug("Side-channel notification: taskId={}, method={}", taskId, method);
430+
QueuedMessage.Notification queuedNotification = new QueuedMessage.Notification(method, notification);
431+
return this.taskMessageQueue.enqueue(taskId, queuedNotification);
432+
}
254433
}

0 commit comments

Comments
 (0)