Skip to content

Commit 5f95192

Browse files
MihaZupanMiha Zupan
andauthored
Improve the TlsFilter sample (dotnet#2879)
Co-authored-by: Miha Zupan <mizupan@microsoft.com>
1 parent 452f1a7 commit 5f95192

File tree

1 file changed

+142
-57
lines changed

1 file changed

+142
-57
lines changed

testassets/ReverseProxy.Direct/TlsFilter.cs

Lines changed: 142 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
using System;
55
using System.Buffers;
6+
using System.IO.Pipelines;
7+
using System.Security.Authentication;
8+
using System.Threading;
69
using System.Threading.Tasks;
710
using Microsoft.AspNetCore.Connections;
811
using Microsoft.Extensions.Logging;
@@ -12,99 +15,181 @@ namespace Yarp.ReverseProxy.Sample;
1215

1316
public static class TlsFilter
1417
{
18+
// Use reasonable limits. Parsing across multiple segments has an O(N^2) worst case, so limit the N.
19+
private const int ClientHelloTimeoutMs = 10_000;
20+
private const int MaxClientHelloSize = 10 * 1024; // 10 KB
21+
1522
// This sniffs the TLS handshake and rejects requests that meat specific criteria.
1623
internal static async Task ProcessAsync(ConnectionContext connectionContext, Func<Task> next, ILogger logger)
1724
{
18-
var input = connectionContext.Transport.Input;
19-
// Count how many bytes we've examined so we never go backwards, Pipes don't allow that.
20-
var minBytesExamined = 0L;
21-
while (true)
25+
using (var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(connectionContext.ConnectionClosed))
2226
{
23-
var result = await input.ReadAsync();
24-
var buffer = result.Buffer;
25-
26-
if (result.IsCompleted)
27-
{
28-
return;
29-
}
30-
31-
if (buffer.Length == 0)
32-
{
33-
continue;
34-
}
27+
timeoutCts.CancelAfter(ClientHelloTimeoutMs);
3528

36-
if (!TryReadHello(buffer, logger, out var abort))
37-
{
38-
minBytesExamined = buffer.Length;
39-
input.AdvanceTo(buffer.Start, buffer.End);
40-
continue;
41-
}
29+
var input = connectionContext.Transport.Input;
4230

43-
var examined = buffer.Slice(buffer.Start, minBytesExamined).End;
44-
input.AdvanceTo(buffer.Start, examined);
31+
// Count how many bytes we've examined so we never go backwards, Pipes don't allow that.
32+
var minBytesExamined = 0L;
4533

46-
if (abort)
34+
while (true)
4735
{
48-
// Close the connection.
49-
return;
36+
var result = await input.ReadAsync(timeoutCts.Token);
37+
var buffer = result.Buffer;
38+
39+
if (result.IsCompleted || result.IsCanceled)
40+
{
41+
return;
42+
}
43+
44+
if (buffer.Length == 0)
45+
{
46+
continue;
47+
}
48+
49+
if (!TryReadTlsFrame(buffer, logger, out var frameInfo) && frameInfo.ParsingStatus == TlsFrameHelper.ParsingStatus.IncompleteFrame)
50+
{
51+
// We didn't find a TLS frame, we need to read more data.
52+
minBytesExamined = buffer.Length;
53+
54+
if (minBytesExamined >= MaxClientHelloSize)
55+
{
56+
logger.LogInformation("Client Hello too large. Aborting.");
57+
return;
58+
}
59+
60+
input.AdvanceTo(buffer.Start, buffer.End);
61+
continue;
62+
}
63+
64+
// We're done. We either have a frame we can analyze, or we're giving up.
65+
var examined = buffer.Slice(buffer.Start, minBytesExamined).End;
66+
input.AdvanceTo(buffer.Start, examined);
67+
68+
if (frameInfo.ParsingStatus != TlsFrameHelper.ParsingStatus.Ok || frameInfo.HandshakeType != TlsHandshakeType.ClientHello)
69+
{
70+
logger.LogInformation("Invalid or unexpected TLS frame. Aborting.");
71+
return;
72+
}
73+
74+
// Perform any additional validation on the Client Hello here.
75+
// Rate limiting, throttling checks, J4A fingerprinting, logging, etc. can be performed here as well.
76+
77+
if (!TryProcessClientHello(frameInfo, logger))
78+
{
79+
// Abort the connection.
80+
return;
81+
}
82+
83+
// All checks passed, we can continue processing the request.
84+
85+
#if !NET10_0_OR_GREATER
86+
// Workaround for https://github.com/dotnet/runtime/issues/107213, which was fixed in .NET 10.
87+
if (minBytesExamined > 0)
88+
{
89+
connectionContext.Transport = new DuplexPipe(
90+
PipeReader.Create(input.AsStream(), new StreamPipeReaderOptions(bufferSize: Math.Max(4096, (int)minBytesExamined))),
91+
connectionContext.Transport.Output);
92+
}
93+
#endif
94+
95+
break;
5096
}
51-
52-
break;
5397
}
5498

5599
await next();
56100
}
57101

58-
private static bool TryReadHello(ReadOnlySequence<byte> buffer, ILogger logger, out bool abort)
102+
/// <summary>Process the Client Hello and returns whether it passed validation.</summary>
103+
private static bool TryProcessClientHello(TlsFrameHelper.TlsFrameInfo clientHello, ILogger logger)
59104
{
60-
abort = false;
105+
// This is a sample demonstrating several checks you can perform on the Client Hello.
106+
// Replace the logic in this method with your own validation logic.
61107

62-
if (!buffer.IsSingleSegment)
63-
{
64-
throw new NotImplementedException("Multiple buffer segments");
65-
}
66-
var data = buffer.First.Span;
108+
string sni = clientHello.TargetName;
67109

68-
TlsFrameHelper.TlsFrameInfo info = default;
69-
if (!TlsFrameHelper.TryGetFrameInfo(data, ref info))
110+
if (string.IsNullOrEmpty(sni))
70111
{
71-
if (info.ParsingStatus == TlsFrameHelper.ParsingStatus.InvalidFrame)
72-
{
73-
logger.LogInformation("Invalid TLS frame");
74-
abort = true;
75-
}
112+
logger.LogInformation("Expected SNI to be specified.");
76113
return false;
77114
}
78115

79-
if (!info.SupportedVersions.HasFlag(System.Security.Authentication.SslProtocols.Tls12))
116+
if (!AllowHost(sni))
80117
{
81-
logger.LogInformation("Unsupported versions: {versions}", info.SupportedVersions);
82-
abort = true;
83-
}
84-
else
85-
{
86-
logger.LogInformation("Protocol versions: {versions}", info.SupportedVersions);
118+
logger.LogInformation("Unexpected SNI: {sni}.", sni);
119+
return false;
87120
}
88121

89-
if (!AllowHost(info.TargetName))
122+
if (!clientHello.SupportedVersions.HasFlag(SslProtocols.Tls12) && !clientHello.SupportedVersions.HasFlag(SslProtocols.Tls13))
90123
{
91-
logger.LogInformation("Disallowed host: {host}", info.TargetName);
92-
abort = true;
124+
logger.LogInformation("Client for '{sni}' does not support TLS 1.2 or 1.3.", sni);
125+
return false;
93126
}
94-
else
127+
128+
if (!clientHello.ApplicationProtocols.HasFlag(TlsFrameHelper.ApplicationProtocolInfo.Http2))
95129
{
96-
logger.LogInformation("SNI: {host}", info.TargetName);
130+
logger.LogInformation("Client for '{sni}' does not support HTTP/2.", sni);
131+
return false;
97132
}
98133

134+
// All checks passed, we can continue processing the request.
99135
return true;
100136
}
101137

102138
private static bool AllowHost(string targetName)
103139
{
104-
if (string.Equals("localhost", targetName, StringComparison.OrdinalIgnoreCase))
140+
return
141+
targetName.Equals("localhost", StringComparison.OrdinalIgnoreCase) ||
142+
targetName.Equals("contoso.com", StringComparison.OrdinalIgnoreCase);
143+
}
144+
145+
/// <summary>Attempt to parse the first TLS frame from the <paramref name="buffer"/> and indicate whether more data is needed.</summary>
146+
private static bool TryReadTlsFrame(ReadOnlySequence<byte> buffer, ILogger logger, out TlsFrameHelper.TlsFrameInfo frame)
147+
{
148+
frame = default;
149+
150+
// Try to process the first segment first.
151+
var data = buffer.First.Span;
152+
153+
if (TlsFrameHelper.TryGetFrameInfo(data, ref frame))
105154
{
155+
// This is the common fast path.
106156
return true;
107157
}
108-
return false;
158+
159+
if (frame.ParsingStatus != TlsFrameHelper.ParsingStatus.IncompleteFrame)
160+
{
161+
// The input is invalid, reading more data won't help.
162+
return false;
163+
}
164+
165+
if (buffer.IsSingleSegment)
166+
{
167+
// We only have one segment and it didn't contain a valid TLS frame. We'll have to read more data.
168+
return false;
169+
}
170+
171+
// We have multiple segments. TlsFrameHelper only works with a single span, so we need to combine them.
172+
// This may happen on every new read, which is why we limit how much data we're willing to process.
173+
174+
var pooledBuffer = ArrayPool<byte>.Shared.Rent((int)buffer.Length);
175+
buffer.CopyTo(pooledBuffer);
176+
data = pooledBuffer.AsSpan(0, (int)buffer.Length);
177+
178+
bool success = TlsFrameHelper.TryGetFrameInfo(data, ref frame);
179+
180+
ArrayPool<byte>.Shared.Return(pooledBuffer);
181+
182+
if (success)
183+
{
184+
logger.LogDebug("Parsed multi-segment TLS frame after {length} bytes", buffer.Length);
185+
}
186+
187+
return success;
188+
}
189+
190+
private sealed class DuplexPipe(PipeReader input, PipeWriter output) : IDuplexPipe
191+
{
192+
public PipeReader Input { get; } = input;
193+
public PipeWriter Output { get; } = output;
109194
}
110195
}

0 commit comments

Comments
 (0)