11using Microsoft . AspNetCore . Builder ;
22using Microsoft . AspNetCore . Http ;
33using Microsoft . Extensions . DependencyInjection ;
4+ using Microsoft . Extensions . Logging ;
45using ModelContextProtocol . AspNetCore . Tests . Utils ;
56using ModelContextProtocol . Client ;
7+ using ModelContextProtocol . Protocol ;
68using ModelContextProtocol . Server ;
9+ using ModelContextProtocol . Tests . Utils ;
710using System . ComponentModel ;
811using System . Net ;
912using System . Security . Claims ;
@@ -20,18 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options)
2023 options . Stateless = Stateless ;
2124 }
2225
23- protected async Task < IMcpClient > ConnectAsync ( string ? path = null , SseClientTransportOptions ? options = null )
26+ protected async Task < IMcpClient > ConnectAsync (
27+ string ? path = null ,
28+ SseClientTransportOptions ? transportOptions = null ,
29+ McpClientOptions ? clientOptions = null )
2430 {
2531 // Default behavior when no options are provided
2632 path ??= UseStreamableHttp ? "/" : "/sse" ;
2733
28- await using var transport = new SseClientTransport ( options ?? new SseClientTransportOptions ( )
34+ await using var transport = new SseClientTransport ( transportOptions ?? new SseClientTransportOptions ( )
2935 {
3036 Endpoint = new Uri ( $ "http://localhost{ path } ") ,
3137 TransportMode = UseStreamableHttp ? HttpTransportMode . StreamableHttp : HttpTransportMode . Sse ,
3238 } , HttpClient , LoggerFactory ) ;
3339
34- return await McpClientFactory . CreateAsync ( transport , loggerFactory : LoggerFactory , cancellationToken : TestContext . Current . CancellationToken ) ;
40+ return await McpClientFactory . CreateAsync ( transport , clientOptions , LoggerFactory , TestContext . Current . CancellationToken ) ;
3541 }
3642
3743 [ Fact ]
@@ -71,7 +77,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT
7177
7278 await app . StartAsync ( TestContext . Current . CancellationToken ) ;
7379
74- var mcpClient = await ConnectAsync ( ) ;
80+ await using var mcpClient = await ConnectAsync ( ) ;
7581
7682 var response = await mcpClient . CallToolAsync (
7783 "EchoWithUserName" ,
@@ -111,13 +117,90 @@ public async Task Messages_FromNewUser_AreRejected()
111117 Assert . Equal ( HttpStatusCode . Forbidden , httpRequestException . StatusCode ) ;
112118 }
113119
114- protected ClaimsPrincipal CreateUser ( string name )
120+ [ Fact ]
121+ public async Task Sampling_DoesNotCloseStream_Prematurely ( )
122+ {
123+ Assert . SkipWhen ( Stateless , "Sampling is not supported in stateless mode." ) ;
124+
125+ Builder . Services . AddMcpServer ( ) . WithHttpTransport ( ConfigureStateless ) . WithTools < SamplingRegressionTools > ( ) ;
126+
127+ var mockLoggerProvider = new MockLoggerProvider ( ) ;
128+ Builder . Logging . AddProvider ( mockLoggerProvider ) ;
129+ Builder . Logging . SetMinimumLevel ( LogLevel . Debug ) ;
130+
131+ await using var app = Builder . Build ( ) ;
132+
133+ // Reset the LoggerFactory used by the client to use the MockLoggerProvider as well.
134+ LoggerFactory = app . Services . GetRequiredService < ILoggerFactory > ( ) ;
135+
136+ app . MapMcp ( ) ;
137+
138+ await app . StartAsync ( TestContext . Current . CancellationToken ) ;
139+
140+ var sampleCount = 0 ;
141+ var clientOptions = new McpClientOptions
142+ {
143+ Capabilities = new ( )
144+ {
145+ Sampling = new ( )
146+ {
147+ SamplingHandler = async ( parameters , _ , _ ) =>
148+ {
149+ Assert . NotNull ( parameters ? . Messages ) ;
150+ var message = Assert . Single ( parameters . Messages ) ;
151+ Assert . Equal ( Role . User , message . Role ) ;
152+ Assert . Equal ( "text" , message . Content . Type ) ;
153+ Assert . Equal ( "Test prompt for sampling" , message . Content . Text ) ;
154+
155+ sampleCount ++ ;
156+ return new CreateMessageResult
157+ {
158+ Model = "test-model" ,
159+ Role = Role . Assistant ,
160+ Content = new Content
161+ {
162+ Type = "text" ,
163+ Text = "Sampling response from client"
164+ }
165+ } ;
166+ } ,
167+ } ,
168+ } ,
169+ } ;
170+
171+ await using var mcpClient = await ConnectAsync ( clientOptions : clientOptions ) ;
172+
173+ var result = await mcpClient . CallToolAsync ( "sampling-tool" , new Dictionary < string , object ? >
174+ {
175+ [ "prompt" ] = "Test prompt for sampling"
176+ } , cancellationToken : TestContext . Current . CancellationToken ) ;
177+
178+ Assert . NotNull ( result ) ;
179+ Assert . False ( result . IsError ) ;
180+ var textContent = Assert . Single ( result . Content ) ;
181+ Assert . Equal ( "text" , textContent . Type ) ;
182+ Assert . Equal ( "Sampling completed successfully. Client responded: Sampling response from client" , textContent . Text ) ;
183+
184+ Assert . Equal ( 2 , sampleCount ) ;
185+
186+ // Verify that the tool call and the sampling request both used the same ID to ensure we cover against regressions.
187+ // https://github.com/modelcontextprotocol/csharp-sdk/issues/464
188+ Assert . Single ( mockLoggerProvider . LogMessages , m =>
189+ m . Category == "ModelContextProtocol.Client.McpClient" &&
190+ m . Message . Contains ( "request '2' for method 'tools/call'" ) ) ;
191+
192+ Assert . Single ( mockLoggerProvider . LogMessages , m =>
193+ m . Category == "ModelContextProtocol.Server.McpServer" &&
194+ m . Message . Contains ( "request '2' for method 'sampling/createMessage'" ) ) ;
195+ }
196+
197+ private ClaimsPrincipal CreateUser ( string name )
115198 => new ClaimsPrincipal ( new ClaimsIdentity (
116199 [ new Claim ( "name" , name ) , new Claim ( ClaimTypes . NameIdentifier , name ) ] ,
117200 "TestAuthType" , "name" , "role" ) ) ;
118201
119202 [ McpServerToolType ]
120- protected class EchoHttpContextUserTools ( IHttpContextAccessor contextAccessor )
203+ private class EchoHttpContextUserTools ( IHttpContextAccessor contextAccessor )
121204 {
122205 [ McpServerTool , Description ( "Echoes the input back to the client with their user name." ) ]
123206 public string EchoWithUserName ( string message )
@@ -127,4 +210,37 @@ public string EchoWithUserName(string message)
127210 return $ "{ userName } : { message } ";
128211 }
129212 }
213+
214+ [ McpServerToolType ]
215+ private class SamplingRegressionTools
216+ {
217+ [ McpServerTool ( Name = "sampling-tool" ) ]
218+ public static async Task < string > SamplingToolAsync ( IMcpServer server , string prompt , CancellationToken cancellationToken )
219+ {
220+ // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464
221+ // 1. The client calls tool with request ID 2, because it's the first request after the initialize request.
222+ // 2. This tool makes two sampling requests which use IDs 1 and 2.
223+ // 3. In the old buggy Streamable HTTP transport code, this would close the SSE response stream,
224+ // because the second sampling request used an ID matching the tool call.
225+ var samplingRequest = new CreateMessageRequestParams
226+ {
227+ Messages = [
228+ new SamplingMessage
229+ {
230+ Role = Role . User ,
231+ Content = new Content
232+ {
233+ Type = "text" ,
234+ Text = prompt
235+ } ,
236+ }
237+ ] ,
238+ } ;
239+
240+ await server . SampleAsync ( samplingRequest , cancellationToken ) ;
241+ var samplingResult = await server . SampleAsync ( samplingRequest , cancellationToken ) ;
242+
243+ return $ "Sampling completed successfully. Client responded: { samplingResult . Content . Text } ";
244+ }
245+ }
130246}
0 commit comments