Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ public async ValueTask<ForwarderError> SendAsync(
(destinationRequest, requestContent, _) = await CreateRequestMessageAsync(
context, destinationPrefix, transformer, config, isStreamingRequest, activityCancellationSource);

// Transforms generated a response, do not proxy.
if (RequestUtilities.IsResponseSet(context.Response))
{
Log.NotProxying(_logger, context.Response.StatusCode);
return ForwarderError.None;
}

destinationResponse = await httpClient.SendAsync(destinationRequest, activityCancellationSource.Token);
}
}
Expand Down Expand Up @@ -471,8 +478,22 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage
{
request.Headers.TryAddWithoutValidation(HeaderNames.Connection, HeaderNames.Upgrade);
request.Headers.TryAddWithoutValidation(HeaderNames.Upgrade, WebSocketName);
var key = ProtocolHelper.CreateSecWebSocketKey();
request.Headers.TryAddWithoutValidation(HeaderNames.SecWebSocketKey, key);

// The client shouldn't be sending a Sec-WebSocket-Key header with H2 WebSockets, but if it did, let's use it.
if (RequestUtilities.TryGetValues(request.Headers, HeaderNames.SecWebSocketKey, out var clientKey))
{
if (!ProtocolHelper.CheckSecWebSocketKey(clientKey))
{
Log.InvalidSecWebSocketKeyHeader(_logger, clientKey);
// The request will not be forwarded if we change the status code.
context.Response.StatusCode = StatusCodes.Status400BadRequest;
}
}
else
{
var key = ProtocolHelper.CreateSecWebSocketKey();
request.Headers.TryAddWithoutValidation(HeaderNames.SecWebSocketKey, key);
}
}
// H1->H1, re-add the original Connection, Upgrade headers.
else
Expand Down
9 changes: 8 additions & 1 deletion src/ReverseProxy/Forwarder/ProtocolHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,20 @@ internal static bool CheckSecWebSocketKey(string? key)
/// </summary>
internal static string CreateSecWebSocketAccept(string? key)
{
Debug.Assert(CheckSecWebSocketKey(key)); // This should have already been validated elsewhere.
if (!CheckSecWebSocketKey(key))
{
// This could happen if a custom message handler modified headers incorrectly.
Debug.Fail("This should have already been validated elsewhere");
throw new InvalidOperationException("Unexpected Sec-WebSocket-Key header format.");
}

// GUID appended by the server as part of the security key response. Defined in the RFC.
var wsServerGuidBytes = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8;
Span<byte> bytes = stackalloc byte[24 /* Base64 guid length */ + wsServerGuidBytes.Length];

// Get the corresponding ASCII bytes for seckey+wsServerGuidBytes
var encodedSecKeyLength = Encoding.ASCII.GetBytes(key, bytes);
Debug.Assert(encodedSecKeyLength == 24);
wsServerGuidBytes.CopyTo(bytes.Slice(encodedSecKeyLength));

// Hash the seckey+wsServerGuidBytes bytes
Expand Down
24 changes: 21 additions & 3 deletions test/ReverseProxy.FunctionalTests/Common/TestEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Xunit;
using Xunit.Abstractions;
using Yarp.ReverseProxy.Configuration;
using Yarp.Tests.Common;
Expand Down Expand Up @@ -73,7 +74,8 @@ public async Task Invoke(Func<string, Task> clientFunc, CancellationToken cancel
ConfigureDestinationServices, ConfigureDestinationApp, UseHttpSysOnDestination);
await destination.StartAsync(cancellationToken);

using var proxy = CreateProxy(destination.GetAddress());
Exception proxyException = null;
using var proxy = CreateProxy(destination.GetAddress(), ex => proxyException = ex);
await proxy.StartAsync(cancellationToken);

try
Expand All @@ -85,9 +87,11 @@ public async Task Invoke(Func<string, Task> clientFunc, CancellationToken cancel
await proxy.StopAsync(cancellationToken);
await destination.StopAsync(cancellationToken);
}

Assert.Null(proxyException);
}

public IHost CreateProxy(string destinationAddress)
public IHost CreateProxy(string destinationAddress, Action<Exception> onProxyException = null)
{
return CreateHost(ProxyProtocol, UseHttpsOnProxy, HeaderEncoding,
services =>
Expand Down Expand Up @@ -125,6 +129,19 @@ public IHost CreateProxy(string destinationAddress)
},
app =>
{
app.Use(async (context, next) =>
{
try
{
await next();
}
catch (Exception ex)
{
onProxyException?.Invoke(ex);
throw;
}
});

ConfigureProxyApp(app);
app.UseRouting();
app.UseEndpoints(builder =>
Expand All @@ -142,6 +159,7 @@ private IHost CreateHost(HttpProtocols protocols, bool useHttps, Encoding reques
{
config.AddInMemoryCollection(new Dictionary<string, string>()
{
{ "Logging:LogLevel:Yarp", "Trace" },
{ "Logging:LogLevel:Microsoft", "Trace" },
{ "Logging:LogLevel:Microsoft.AspNetCore.Hosting.Diagnostics", "Information" }
});
Expand All @@ -152,7 +170,7 @@ private IHost CreateHost(HttpProtocols protocols, bool useHttps, Encoding reques
loggingBuilder.AddEventSourceLogger();
if (TestOutput != null)
{
loggingBuilder.AddXunit(TestOutput);
loggingBuilder.Services.AddSingleton<ILoggerProvider>(new TestLoggerProvider(TestOutput));
}
})
.ConfigureWebHost(webHostBuilder =>
Expand Down
94 changes: 90 additions & 4 deletions test/ReverseProxy.FunctionalTests/WebSocketTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,26 @@ public async Task InvalidKeyHeader_400(HttpProtocols destinationProtocol)
var test = CreateTestEnvironment();
test.ProxyProtocol = HttpProtocols.Http1;
test.DestinationProtocol = destinationProtocol;
test.DestinationHttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
test.DestinationHttpVersion = destinationProtocol == HttpProtocols.Http1 ? HttpVersion.Version11 : HttpVersion.Version20;

test.ConfigureProxyApp = builder =>
{
builder.Use((context, next) =>
builder.Use(async (context, next) =>
{
context.Request.Headers[HeaderNames.SecWebSocketKey] = "ThisIsAnIncorrectKeyHeaderLongerThan24Bytes";
return next(context);

var logs = TestLogger.Collect();
await next(context);

if (destinationProtocol == HttpProtocols.Http1)
{
Assert.DoesNotContain(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
}
else
{
Assert.Contains(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
}
});
};

Expand All @@ -510,6 +523,79 @@ await test.Invoke(async uri =>
}, cts.Token);
}

[Fact]
public async Task WebSocket20_To_11_WithWellFormedKeyHeader_OriginalKeyIsUsed()
{
using var cts = CreateTimer();

var clientKey = ProtocolHelper.CreateSecWebSocketKey();

var test = CreateTestEnvironment();
test.ProxyProtocol = HttpProtocols.Http2;
test.DestinationProtocol = HttpProtocols.Http1;

var originalDestinationApp = test.ConfigureDestinationApp;
test.ConfigureDestinationApp = app =>
{
app.Use((context, next) =>
{
Assert.True(context.Request.Headers.TryGetValue(HeaderNames.SecWebSocketKey, out var key));
Assert.Equal(clientKey, key);
return next(context);
});
originalDestinationApp(app);
};

await test.Invoke(async uri =>
{
using var client = new ClientWebSocket();
client.Options.HttpVersion = HttpVersion.Version20;
client.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;

client.Options.SetRequestHeader(HeaderNames.SecWebSocketKey, clientKey);

await SendWebSocketRequestAsync(client, uri, "HTTP/1.1", cts.Token);
}, cts.Token);
}

[Fact]
public async Task WebSocket20_To_11_WithInvalidKeyHeader_RequestRejected()
{
using var cts = CreateTimer();

var test = CreateTestEnvironment();
test.ProxyProtocol = HttpProtocols.Http2;
test.DestinationProtocol = HttpProtocols.Http1;

test.ConfigureProxyApp = builder =>
{
builder.Use(async (context, next) =>
{
var logs = TestLogger.Collect();
await next(context);
Assert.Contains(logs, log => log.EventId == EventIds.InvalidSecWebSocketKeyHeader);
});
};

await test.Invoke(async uri =>
{
var webSocketsTarget = uri.Replace("http://", "ws://");
var targetUri = new Uri(new Uri(webSocketsTarget, UriKind.Absolute), "websockets");

using var client = new ClientWebSocket();
client.Options.HttpVersion = HttpVersion.Version20;
client.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
client.Options.CollectHttpResponseDetails = true;

client.Options.SetRequestHeader(HeaderNames.SecWebSocketKey, "Foo");

using var invoker = CreateInvoker();
var wse = await Assert.ThrowsAsync<WebSocketException>(() => client.ConnectAsync(targetUri, invoker, cts.Token));
Assert.Equal("The server returned status code '400' when status code '200' was expected.", wse.Message);
Assert.Equal(HttpStatusCode.BadRequest, client.HttpStatusCode);
}, cts.Token);
}

private async Task SendWebSocketRequestAsync(ClientWebSocket client, string uri, string destinationProtocol, CancellationToken token)
{
var webSocketsTarget = uri.Replace("https://", "wss://").Replace("http://", "ws://");
Expand All @@ -520,8 +606,8 @@ private async Task SendWebSocketRequestAsync(ClientWebSocket client, string uri,

var buffer = new byte[1024];
var textToSend = $"Hello World!";
var numBytes = Encoding.UTF8.GetBytes(textToSend, buffer.AsSpan());
await client.SendAsync(new ArraySegment<byte>(buffer, 0, numBytes),
var numBytes = Encoding.UTF8.GetBytes(textToSend, buffer);
await client.SendAsync(buffer.AsMemory(0, numBytes),
WebSocketMessageType.Text,
endOfMessage: true,
token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,11 @@ public async Task ValidateSchema_Samples()

foreach (string file in Directory.EnumerateFiles(repoRoot, "*.json", SearchOption.AllDirectories))
{
if (file.Contains("\\obj\\", StringComparison.Ordinal) || file.Contains("/obj/", StringComparison.Ordinal))
{
continue;
}

if (file.Contains("appsettings", StringComparison.OrdinalIgnoreCase))
{
var contents = await File.ReadAllTextAsync(file);
Expand All @@ -688,7 +693,7 @@ public async Task ValidateSchema_Samples()

if (contents.Contains("\"ReverseProxy\"", StringComparison.OrdinalIgnoreCase))
{
Assert.True(results.Details.Count > 5);
Assert.True(results.Details.Count > 5, $"No details for '{file}'");
}
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace Yarp.ReverseProxy.Common;

internal sealed class TestLogger(ILogger xunitLogger, string categoryName) : ILogger
public sealed class TestLogger(ILogger xunitLogger, string categoryName) : ILogger
{
public record LogEntry(string CategoryName, LogLevel LogLevel, EventId EventId, string Message, Exception Exception);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace Yarp.ReverseProxy.Common;

internal sealed class TestLoggerProvider(ITestOutputHelper output) : ILoggerProvider
public sealed class TestLoggerProvider(ITestOutputHelper output) : ILoggerProvider
{
private readonly XunitLoggerProvider _xunitLoggerProvider = new(output);

Expand Down
Loading