Skip to content

Commit 96322f1

Browse files
authored
Fix Sec-WebSocket-Key handling during H2 => H1 WebSocket downgrades (#2806)
* Fix Sec-WebSocket-Key handling during H2 => H1 WebSocket downgrades * Fix tests
1 parent 6b1876b commit 96322f1

File tree

7 files changed

+150
-13
lines changed

7 files changed

+150
-13
lines changed

src/ReverseProxy/Forwarder/HttpForwarder.cs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ public async ValueTask<ForwarderError> SendAsync(
201201
(destinationRequest, requestContent, _) = await CreateRequestMessageAsync(
202202
context, destinationPrefix, transformer, config, isStreamingRequest, activityCancellationSource);
203203

204+
// Transforms generated a response, do not proxy.
205+
if (RequestUtilities.IsResponseSet(context.Response))
206+
{
207+
Log.NotProxying(_logger, context.Response.StatusCode);
208+
return ForwarderError.None;
209+
}
210+
204211
destinationResponse = await httpClient.SendAsync(destinationRequest, activityCancellationSource.Token);
205212
}
206213
}
@@ -471,8 +478,22 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage
471478
{
472479
request.Headers.TryAddWithoutValidation(HeaderNames.Connection, HeaderNames.Upgrade);
473480
request.Headers.TryAddWithoutValidation(HeaderNames.Upgrade, WebSocketName);
474-
var key = ProtocolHelper.CreateSecWebSocketKey();
475-
request.Headers.TryAddWithoutValidation(HeaderNames.SecWebSocketKey, key);
481+
482+
// The client shouldn't be sending a Sec-WebSocket-Key header with H2 WebSockets, but if it did, let's use it.
483+
if (RequestUtilities.TryGetValues(request.Headers, HeaderNames.SecWebSocketKey, out var clientKey))
484+
{
485+
if (!ProtocolHelper.CheckSecWebSocketKey(clientKey))
486+
{
487+
Log.InvalidSecWebSocketKeyHeader(_logger, clientKey);
488+
// The request will not be forwarded if we change the status code.
489+
context.Response.StatusCode = StatusCodes.Status400BadRequest;
490+
}
491+
}
492+
else
493+
{
494+
var key = ProtocolHelper.CreateSecWebSocketKey();
495+
request.Headers.TryAddWithoutValidation(HeaderNames.SecWebSocketKey, key);
496+
}
476497
}
477498
// H1->H1, re-add the original Connection, Upgrade headers.
478499
else

src/ReverseProxy/Forwarder/ProtocolHelper.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,20 @@ internal static bool CheckSecWebSocketKey(string? key)
6868
/// </summary>
6969
internal static string CreateSecWebSocketAccept(string? key)
7070
{
71-
Debug.Assert(CheckSecWebSocketKey(key)); // This should have already been validated elsewhere.
71+
if (!CheckSecWebSocketKey(key))
72+
{
73+
// This could happen if a custom message handler modified headers incorrectly.
74+
Debug.Fail("This should have already been validated elsewhere");
75+
throw new InvalidOperationException("Unexpected Sec-WebSocket-Key header format.");
76+
}
77+
7278
// GUID appended by the server as part of the security key response. Defined in the RFC.
7379
var wsServerGuidBytes = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8;
7480
Span<byte> bytes = stackalloc byte[24 /* Base64 guid length */ + wsServerGuidBytes.Length];
7581

7682
// Get the corresponding ASCII bytes for seckey+wsServerGuidBytes
7783
var encodedSecKeyLength = Encoding.ASCII.GetBytes(key, bytes);
84+
Debug.Assert(encodedSecKeyLength == 24);
7885
wsServerGuidBytes.CopyTo(bytes.Slice(encodedSecKeyLength));
7986

8087
// Hash the seckey+wsServerGuidBytes bytes

test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.Extensions.DependencyInjection;
1818
using Microsoft.Extensions.Hosting;
1919
using Microsoft.Extensions.Logging;
20+
using Xunit;
2021
using Xunit.Abstractions;
2122
using Yarp.ReverseProxy.Configuration;
2223
using Yarp.Tests.Common;
@@ -73,7 +74,8 @@ public async Task Invoke(Func<string, Task> clientFunc, CancellationToken cancel
7374
ConfigureDestinationServices, ConfigureDestinationApp, UseHttpSysOnDestination);
7475
await destination.StartAsync(cancellationToken);
7576

76-
using var proxy = CreateProxy(destination.GetAddress());
77+
Exception proxyException = null;
78+
using var proxy = CreateProxy(destination.GetAddress(), ex => proxyException = ex);
7779
await proxy.StartAsync(cancellationToken);
7880

7981
try
@@ -85,9 +87,11 @@ public async Task Invoke(Func<string, Task> clientFunc, CancellationToken cancel
8587
await proxy.StopAsync(cancellationToken);
8688
await destination.StopAsync(cancellationToken);
8789
}
90+
91+
Assert.Null(proxyException);
8892
}
8993

90-
public IHost CreateProxy(string destinationAddress)
94+
public IHost CreateProxy(string destinationAddress, Action<Exception> onProxyException = null)
9195
{
9296
return CreateHost(ProxyProtocol, UseHttpsOnProxy, HeaderEncoding,
9397
services =>
@@ -125,6 +129,19 @@ public IHost CreateProxy(string destinationAddress)
125129
},
126130
app =>
127131
{
132+
app.Use(async (context, next) =>
133+
{
134+
try
135+
{
136+
await next();
137+
}
138+
catch (Exception ex)
139+
{
140+
onProxyException?.Invoke(ex);
141+
throw;
142+
}
143+
});
144+
128145
ConfigureProxyApp(app);
129146
app.UseRouting();
130147
app.UseEndpoints(builder =>
@@ -142,6 +159,7 @@ private IHost CreateHost(HttpProtocols protocols, bool useHttps, Encoding reques
142159
{
143160
config.AddInMemoryCollection(new Dictionary<string, string>()
144161
{
162+
{ "Logging:LogLevel:Yarp", "Trace" },
145163
{ "Logging:LogLevel:Microsoft", "Trace" },
146164
{ "Logging:LogLevel:Microsoft.AspNetCore.Hosting.Diagnostics", "Information" }
147165
});
@@ -152,7 +170,7 @@ private IHost CreateHost(HttpProtocols protocols, bool useHttps, Encoding reques
152170
loggingBuilder.AddEventSourceLogger();
153171
if (TestOutput != null)
154172
{
155-
loggingBuilder.AddXunit(TestOutput);
173+
loggingBuilder.Services.AddSingleton<ILoggerProvider>(new TestLoggerProvider(TestOutput));
156174
}
157175
})
158176
.ConfigureWebHost(webHostBuilder =>

test/ReverseProxy.FunctionalTests/WebSocketTests.cs

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,26 @@ public async Task InvalidKeyHeader_400(HttpProtocols destinationProtocol)
486486
var test = CreateTestEnvironment();
487487
test.ProxyProtocol = HttpProtocols.Http1;
488488
test.DestinationProtocol = destinationProtocol;
489+
test.DestinationHttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
490+
test.DestinationHttpVersion = destinationProtocol == HttpProtocols.Http1 ? HttpVersion.Version11 : HttpVersion.Version20;
489491

490492
test.ConfigureProxyApp = builder =>
491493
{
492-
builder.Use((context, next) =>
494+
builder.Use(async (context, next) =>
493495
{
494496
context.Request.Headers[HeaderNames.SecWebSocketKey] = "ThisIsAnIncorrectKeyHeaderLongerThan24Bytes";
495-
return next(context);
497+
498+
var logs = TestLogger.Collect();
499+
await next(context);
500+
501+
if (destinationProtocol == HttpProtocols.Http1)
502+
{
503+
Assert.DoesNotContain(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
504+
}
505+
else
506+
{
507+
Assert.Contains(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
508+
}
496509
});
497510
};
498511

@@ -510,6 +523,79 @@ await test.Invoke(async uri =>
510523
}, cts.Token);
511524
}
512525

526+
[Fact]
527+
public async Task WebSocket20_To_11_WithWellFormedKeyHeader_OriginalKeyIsUsed()
528+
{
529+
using var cts = CreateTimer();
530+
531+
var clientKey = ProtocolHelper.CreateSecWebSocketKey();
532+
533+
var test = CreateTestEnvironment();
534+
test.ProxyProtocol = HttpProtocols.Http2;
535+
test.DestinationProtocol = HttpProtocols.Http1;
536+
537+
var originalDestinationApp = test.ConfigureDestinationApp;
538+
test.ConfigureDestinationApp = app =>
539+
{
540+
app.Use((context, next) =>
541+
{
542+
Assert.True(context.Request.Headers.TryGetValue(HeaderNames.SecWebSocketKey, out var key));
543+
Assert.Equal(clientKey, key);
544+
return next(context);
545+
});
546+
originalDestinationApp(app);
547+
};
548+
549+
await test.Invoke(async uri =>
550+
{
551+
using var client = new ClientWebSocket();
552+
client.Options.HttpVersion = HttpVersion.Version20;
553+
client.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
554+
555+
client.Options.SetRequestHeader(HeaderNames.SecWebSocketKey, clientKey);
556+
557+
await SendWebSocketRequestAsync(client, uri, "HTTP/1.1", cts.Token);
558+
}, cts.Token);
559+
}
560+
561+
[Fact]
562+
public async Task WebSocket20_To_11_WithInvalidKeyHeader_RequestRejected()
563+
{
564+
using var cts = CreateTimer();
565+
566+
var test = CreateTestEnvironment();
567+
test.ProxyProtocol = HttpProtocols.Http2;
568+
test.DestinationProtocol = HttpProtocols.Http1;
569+
570+
test.ConfigureProxyApp = builder =>
571+
{
572+
builder.Use(async (context, next) =>
573+
{
574+
var logs = TestLogger.Collect();
575+
await next(context);
576+
Assert.Contains(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
577+
});
578+
};
579+
580+
await test.Invoke(async uri =>
581+
{
582+
var webSocketsTarget = uri.Replace("http://", "ws://");
583+
var targetUri = new Uri(new Uri(webSocketsTarget, UriKind.Absolute), "websockets");
584+
585+
using var client = new ClientWebSocket();
586+
client.Options.HttpVersion = HttpVersion.Version20;
587+
client.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
588+
client.Options.CollectHttpResponseDetails = true;
589+
590+
client.Options.SetRequestHeader(HeaderNames.SecWebSocketKey, "Foo");
591+
592+
using var invoker = CreateInvoker();
593+
var wse = await Assert.ThrowsAsync<WebSocketException>(() => client.ConnectAsync(targetUri, invoker, cts.Token));
594+
Assert.Equal("The server returned status code '400' when status code '200' was expected.", wse.Message);
595+
Assert.Equal(HttpStatusCode.BadRequest, client.HttpStatusCode);
596+
}, cts.Token);
597+
}
598+
513599
private async Task SendWebSocketRequestAsync(ClientWebSocket client, string uri, string destinationProtocol, CancellationToken token)
514600
{
515601
var webSocketsTarget = uri.Replace("https://", "wss://").Replace("http://", "ws://");
@@ -520,8 +606,8 @@ private async Task SendWebSocketRequestAsync(ClientWebSocket client, string uri,
520606

521607
var buffer = new byte[1024];
522608
var textToSend = $"Hello World!";
523-
var numBytes = Encoding.UTF8.GetBytes(textToSend, buffer.AsSpan());
524-
await client.SendAsync(new ArraySegment<byte>(buffer, 0, numBytes),
609+
var numBytes = Encoding.UTF8.GetBytes(textToSend, buffer);
610+
await client.SendAsync(buffer.AsMemory(0, numBytes),
525611
WebSocketMessageType.Text,
526612
endOfMessage: true,
527613
token);

test/ReverseProxy.Tests/Configuration/ConfigProvider/ConfigurationConfigProviderTests.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,11 @@ public async Task ValidateSchema_Samples()
676676

677677
foreach (string file in Directory.EnumerateFiles(repoRoot, "*.json", SearchOption.AllDirectories))
678678
{
679+
if (file.Contains("\\obj\\", StringComparison.Ordinal) || file.Contains("/obj/", StringComparison.Ordinal))
680+
{
681+
continue;
682+
}
683+
679684
if (file.Contains("appsettings", StringComparison.OrdinalIgnoreCase))
680685
{
681686
var contents = await File.ReadAllTextAsync(file);
@@ -688,7 +693,7 @@ public async Task ValidateSchema_Samples()
688693

689694
if (contents.Contains("\"ReverseProxy\"", StringComparison.OrdinalIgnoreCase))
690695
{
691-
Assert.True(results.Details.Count > 5);
696+
Assert.True(results.Details.Count > 5, $"No details for '{file}'");
692697
}
693698
}
694699
else

test/ReverseProxy.Tests/Common/TestLogger.cs renamed to test/Tests.Common/TestLogger.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
namespace Yarp.ReverseProxy.Common;
1010

11-
internal sealed class TestLogger(ILogger xunitLogger, string categoryName) : ILogger
11+
public sealed class TestLogger(ILogger xunitLogger, string categoryName) : ILogger
1212
{
1313
public record LogEntry(string CategoryName, LogLevel LogLevel, EventId EventId, string Message, Exception Exception);
1414

test/ReverseProxy.Tests/Common/TestLoggerProvider.cs renamed to test/Tests.Common/TestLoggerProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace Yarp.ReverseProxy.Common;
99

10-
internal sealed class TestLoggerProvider(ITestOutputHelper output) : ILoggerProvider
10+
public sealed class TestLoggerProvider(ITestOutputHelper output) : ILoggerProvider
1111
{
1212
private readonly XunitLoggerProvider _xunitLoggerProvider = new(output);
1313

0 commit comments

Comments
 (0)