Skip to content

Don't conflate request and response IDs in Streamable HTTP transports #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,30 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
}

var rpcRequest = message as JsonRpcRequest;
JsonRpcMessage? rpcResponseCandidate = null;
JsonRpcMessageWithId? rpcResponseOrError = null;

if (response.Content.Headers.ContentType?.MediaType == "application/json")
{
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
rpcResponseCandidate = await ProcessMessageAsync(responseContent, cancellationToken).ConfigureAwait(false);
rpcResponseOrError = await ProcessMessageAsync(responseContent, rpcRequest, cancellationToken).ConfigureAwait(false);
}
else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream")
{
using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken);
rpcResponseCandidate = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
}

if (rpcRequest is null)
{
return response;
}

if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id)
if (rpcResponseOrError is null)
{
throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}");
}

if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseCandidate is JsonRpcResponse)
if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse)
{
// We've successfully initialized! Copy session-id and start GET request if any.
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
Expand Down Expand Up @@ -193,20 +193,20 @@ private async Task ReceiveUnsolicitedMessagesAsync()
continue;
}

var message = await ProcessMessageAsync(sseEvent.Data, cancellationToken).ConfigureAwait(false);
var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);

// The server SHOULD end the response here anyway, but we won't leave it to chance. This transport makes
// The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes
// a GET request for any notifications that might need to be sent after the completion of each POST.
if (message is JsonRpcMessageWithId messageWithId && relatedRpcRequest?.Id == messageWithId.Id)
if (rpcResponseOrError is not null)
{
return messageWithId;
return rpcResponseOrError;
}
}

return null;
}

private async Task<JsonRpcMessage?> ProcessMessageAsync(string data, CancellationToken cancellationToken)
private async Task<JsonRpcMessageWithId?> ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
{
try
{
Expand All @@ -218,7 +218,12 @@ private async Task ReceiveUnsolicitedMessagesAsync()
}

await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
return message;
if (message is JsonRpcResponse or JsonRpcError &&
message is JsonRpcMessageWithId rpcResponseOrError &&
rpcResponseOrError.Id == relatedRpcRequest?.Id)
{
return rpcResponseOrError;
}
}
catch (JsonException ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public async ValueTask DisposeAsync()
{
yield return message;

if (message.Data is JsonRpcMessageWithId response && response.Id == _pendingRequest)
if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest)
{
// Complete the SSE response stream now that all pending requests have been processed.
break;
Expand Down
8 changes: 4 additions & 4 deletions tests/Common/Utils/LoggedTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ public LoggedTest(ITestOutputHelper testOutputHelper)
{
CurrentTestOutputHelper = testOutputHelper,
};
LoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper);
XunitLoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper);
LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(LoggerProvider);
builder.AddProvider(XunitLoggerProvider);
});
}

public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper;
public ILoggerFactory LoggerFactory { get; }
public ILoggerProvider LoggerProvider { get; }
public ILoggerFactory LoggerFactory { get; set; }
public ILoggerProvider XunitLoggerProvider { get; }

public virtual void Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat

await app.StartAsync(TestContext.Current.CancellationToken);

var mcpClient = await ConnectAsync(requestPath);
await using var mcpClient = await ConnectAsync(requestPath);

Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat

await app.StartAsync(TestContext.Current.CancellationToken);

var mcpClient = await ConnectAsync(requestPath);
await using var mcpClient = await ConnectAsync(requestPath);

Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name);
}
Expand Down Expand Up @@ -135,7 +135,7 @@ public async Task SseMode_Works_WithSseEndpoint()

await app.StartAsync(TestContext.Current.CancellationToken);

await using var mcpClient = await ConnectAsync(options: new()
await using var mcpClient = await ConnectAsync(transportOptions: new()
{
Endpoint = new Uri("http://localhost/sse"),
TransportMode = HttpTransportMode.Sse
Expand Down
128 changes: 122 additions & 6 deletions tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.AspNetCore.Tests.Utils;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.ComponentModel;
using System.Net;
using System.Security.Claims;
Expand All @@ -20,18 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options)
options.Stateless = Stateless;
}

protected async Task<IMcpClient> ConnectAsync(string? path = null, SseClientTransportOptions? options = null)
protected async Task<IMcpClient> ConnectAsync(
string? path = null,
SseClientTransportOptions? transportOptions = null,
McpClientOptions? clientOptions = null)
{
// Default behavior when no options are provided
path ??= UseStreamableHttp ? "/" : "/sse";

await using var transport = new SseClientTransport(options ?? new SseClientTransportOptions()
await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions()
{
Endpoint = new Uri($"http://localhost{path}"),
TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse,
}, HttpClient, LoggerFactory);

return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken);
return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken);
}

[Fact]
Expand Down Expand Up @@ -71,7 +77,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT

await app.StartAsync(TestContext.Current.CancellationToken);

var mcpClient = await ConnectAsync();
await using var mcpClient = await ConnectAsync();

var response = await mcpClient.CallToolAsync(
"EchoWithUserName",
Expand Down Expand Up @@ -111,13 +117,90 @@ public async Task Messages_FromNewUser_AreRejected()
Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode);
}

protected ClaimsPrincipal CreateUser(string name)
[Fact]
public async Task Sampling_DoesNotCloseStream_Prematurely()
{
Assert.SkipWhen(Stateless, "Sampling is not supported in stateless mode.");

Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools<SamplingRegressionTools>();

var mockLoggerProvider = new MockLoggerProvider();
Builder.Logging.AddProvider(mockLoggerProvider);
Builder.Logging.SetMinimumLevel(LogLevel.Debug);

await using var app = Builder.Build();

// Reset the LoggerFactory used by the client to use the MockLoggerProvider as well.
LoggerFactory = app.Services.GetRequiredService<ILoggerFactory>();

app.MapMcp();

await app.StartAsync(TestContext.Current.CancellationToken);

var sampleCount = 0;
var clientOptions = new McpClientOptions
{
Capabilities = new()
{
Sampling = new()
{
SamplingHandler = async (parameters, _, _) =>
{
Assert.NotNull(parameters?.Messages);
var message = Assert.Single(parameters.Messages);
Assert.Equal(Role.User, message.Role);
Assert.Equal("text", message.Content.Type);
Assert.Equal("Test prompt for sampling", message.Content.Text);

sampleCount++;
return new CreateMessageResult
{
Model = "test-model",
Role = Role.Assistant,
Content = new Content
{
Type = "text",
Text = "Sampling response from client"
}
};
},
},
},
};

await using var mcpClient = await ConnectAsync(clientOptions: clientOptions);

var result = await mcpClient.CallToolAsync("sampling-tool", new Dictionary<string, object?>
{
["prompt"] = "Test prompt for sampling"
}, cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);
Assert.False(result.IsError);
var textContent = Assert.Single(result.Content);
Assert.Equal("text", textContent.Type);
Assert.Equal("Sampling completed successfully. Client responded: Sampling response from client", textContent.Text);

Assert.Equal(2, sampleCount);

// Verify that the tool call and the sampling request both used the same ID to ensure we cover against regressions.
// https://github.com/modelcontextprotocol/csharp-sdk/issues/464
Assert.Single(mockLoggerProvider.LogMessages, m =>
m.Category == "ModelContextProtocol.Client.McpClient" &&
m.Message.Contains("request '2' for method 'tools/call'"));

Assert.Single(mockLoggerProvider.LogMessages, m =>
m.Category == "ModelContextProtocol.Client.McpServer" &&
m.Message.Contains("request '2' for method 'sampling/createMessage'"));
}

private ClaimsPrincipal CreateUser(string name)
=> new ClaimsPrincipal(new ClaimsIdentity(
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],
"TestAuthType", "name", "role"));

[McpServerToolType]
protected class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
{
[McpServerTool, Description("Echoes the input back to the client with their user name.")]
public string EchoWithUserName(string message)
Expand All @@ -127,4 +210,37 @@ public string EchoWithUserName(string message)
return $"{userName}: {message}";
}
}

[McpServerToolType]
private class SamplingRegressionTools
{
[McpServerTool(Name = "sampling-tool")]
public static async Task<string> SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken)
{
// This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464
// 1. The client calls tool with request ID 2, because it's the first request after the initialize request.
// 2. This tool makes two sampling requests which use IDs 1 and 2.
// 3. In the old buggy Streamable HTTP transport code, this would close the SSE response stream,
// because the second sampling request used an ID matching the tool call.
var samplingRequest = new CreateMessageRequestParams
{
Messages = [
new SamplingMessage
{
Role = Role.User,
Content = new Content
{
Type = "text",
Text = prompt
},
}
],
};

await server.SampleAsync(samplingRequest, cancellationToken);
var samplingResult = await server.SampleAsync(samplingRequest, cancellationToken);

return $"Sampling completed successfully. Client responded: {samplingResult.Content.Text}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
Builder = WebApplication.CreateSlimBuilder();
Builder.Services.RemoveAll<IConnectionListenerFactory>();
Builder.Services.AddSingleton<IConnectionListenerFactory>(_inMemoryTransport);
Builder.Services.AddSingleton(LoggerProvider);
Builder.Services.AddSingleton(XunitLoggerProvider);

HttpClient = new HttpClient(new SocketsHttpHandler()
{
Expand Down
Loading