33
44using System ;
55using System . Buffers ;
6+ using System . IO . Pipelines ;
7+ using System . Security . Authentication ;
8+ using System . Threading ;
69using System . Threading . Tasks ;
710using Microsoft . AspNetCore . Connections ;
811using Microsoft . Extensions . Logging ;
@@ -12,99 +15,181 @@ namespace Yarp.ReverseProxy.Sample;
1215
1316public static class TlsFilter
1417{
18+ // Use reasonable limits. Parsing across multiple segments has an O(N^2) worst case, so limit the N.
19+ private const int ClientHelloTimeoutMs = 10_000 ;
20+ private const int MaxClientHelloSize = 10 * 1024 ; // 10 KB
21+
1522 // This sniffs the TLS handshake and rejects requests that meat specific criteria.
1623 internal static async Task ProcessAsync ( ConnectionContext connectionContext , Func < Task > next , ILogger logger )
1724 {
18- var input = connectionContext . Transport . Input ;
19- // Count how many bytes we've examined so we never go backwards, Pipes don't allow that.
20- var minBytesExamined = 0L ;
21- while ( true )
25+ using ( var timeoutCts = CancellationTokenSource . CreateLinkedTokenSource ( connectionContext . ConnectionClosed ) )
2226 {
23- var result = await input . ReadAsync ( ) ;
24- var buffer = result . Buffer ;
25-
26- if ( result . IsCompleted )
27- {
28- return ;
29- }
30-
31- if ( buffer . Length == 0 )
32- {
33- continue ;
34- }
27+ timeoutCts . CancelAfter ( ClientHelloTimeoutMs ) ;
3528
36- if ( ! TryReadHello ( buffer , logger , out var abort ) )
37- {
38- minBytesExamined = buffer . Length ;
39- input . AdvanceTo ( buffer . Start , buffer . End ) ;
40- continue ;
41- }
29+ var input = connectionContext . Transport . Input ;
4230
43- var examined = buffer . Slice ( buffer . Start , minBytesExamined ) . End ;
44- input . AdvanceTo ( buffer . Start , examined ) ;
31+ // Count how many bytes we've examined so we never go backwards, Pipes don't allow that.
32+ var minBytesExamined = 0L ;
4533
46- if ( abort )
34+ while ( true )
4735 {
48- // Close the connection.
49- return ;
36+ var result = await input . ReadAsync ( timeoutCts . Token ) ;
37+ var buffer = result . Buffer ;
38+
39+ if ( result . IsCompleted || result . IsCanceled )
40+ {
41+ return ;
42+ }
43+
44+ if ( buffer . Length == 0 )
45+ {
46+ continue ;
47+ }
48+
49+ if ( ! TryReadTlsFrame ( buffer , logger , out var frameInfo ) && frameInfo . ParsingStatus == TlsFrameHelper . ParsingStatus . IncompleteFrame )
50+ {
51+ // We didn't find a TLS frame, we need to read more data.
52+ minBytesExamined = buffer . Length ;
53+
54+ if ( minBytesExamined >= MaxClientHelloSize )
55+ {
56+ logger . LogInformation ( "Client Hello too large. Aborting." ) ;
57+ return ;
58+ }
59+
60+ input . AdvanceTo ( buffer . Start , buffer . End ) ;
61+ continue ;
62+ }
63+
64+ // We're done. We either have a frame we can analyze, or we're giving up.
65+ var examined = buffer . Slice ( buffer . Start , minBytesExamined ) . End ;
66+ input . AdvanceTo ( buffer . Start , examined ) ;
67+
68+ if ( frameInfo . ParsingStatus != TlsFrameHelper . ParsingStatus . Ok || frameInfo . HandshakeType != TlsHandshakeType . ClientHello )
69+ {
70+ logger . LogInformation ( "Invalid or unexpected TLS frame. Aborting." ) ;
71+ return ;
72+ }
73+
74+ // Perform any additional validation on the Client Hello here.
75+ // Rate limiting, throttling checks, J4A fingerprinting, logging, etc. can be performed here as well.
76+
77+ if ( ! TryProcessClientHello ( frameInfo , logger ) )
78+ {
79+ // Abort the connection.
80+ return ;
81+ }
82+
83+ // All checks passed, we can continue processing the request.
84+
85+ #if ! NET10_0_OR_GREATER
86+ // Workaround for https://github.com/dotnet/runtime/issues/107213, which was fixed in .NET 10.
87+ if ( minBytesExamined > 0 )
88+ {
89+ connectionContext . Transport = new DuplexPipe (
90+ PipeReader . Create ( input . AsStream ( ) , new StreamPipeReaderOptions ( bufferSize : Math . Max ( 4096 , ( int ) minBytesExamined ) ) ) ,
91+ connectionContext . Transport . Output ) ;
92+ }
93+ #endif
94+
95+ break ;
5096 }
51-
52- break ;
5397 }
5498
5599 await next ( ) ;
56100 }
57101
58- private static bool TryReadHello ( ReadOnlySequence < byte > buffer , ILogger logger , out bool abort )
102+ /// <summary>Process the Client Hello and returns whether it passed validation.</summary>
103+ private static bool TryProcessClientHello ( TlsFrameHelper . TlsFrameInfo clientHello , ILogger logger )
59104 {
60- abort = false ;
105+ // This is a sample demonstrating several checks you can perform on the Client Hello.
106+ // Replace the logic in this method with your own validation logic.
61107
62- if ( ! buffer . IsSingleSegment )
63- {
64- throw new NotImplementedException ( "Multiple buffer segments" ) ;
65- }
66- var data = buffer . First . Span ;
108+ string sni = clientHello . TargetName ;
67109
68- TlsFrameHelper . TlsFrameInfo info = default ;
69- if ( ! TlsFrameHelper . TryGetFrameInfo ( data , ref info ) )
110+ if ( string . IsNullOrEmpty ( sni ) )
70111 {
71- if ( info . ParsingStatus == TlsFrameHelper . ParsingStatus . InvalidFrame )
72- {
73- logger . LogInformation ( "Invalid TLS frame" ) ;
74- abort = true ;
75- }
112+ logger . LogInformation ( "Expected SNI to be specified." ) ;
76113 return false ;
77114 }
78115
79- if ( ! info . SupportedVersions . HasFlag ( System . Security . Authentication . SslProtocols . Tls12 ) )
116+ if ( ! AllowHost ( sni ) )
80117 {
81- logger . LogInformation ( "Unsupported versions: {versions}" , info . SupportedVersions ) ;
82- abort = true ;
83- }
84- else
85- {
86- logger . LogInformation ( "Protocol versions: {versions}" , info . SupportedVersions ) ;
118+ logger . LogInformation ( "Unexpected SNI: {sni}." , sni ) ;
119+ return false ;
87120 }
88121
89- if ( ! AllowHost ( info . TargetName ) )
122+ if ( ! clientHello . SupportedVersions . HasFlag ( SslProtocols . Tls12 ) && ! clientHello . SupportedVersions . HasFlag ( SslProtocols . Tls13 ) )
90123 {
91- logger . LogInformation ( "Disallowed host: {host} " , info . TargetName ) ;
92- abort = true ;
124+ logger . LogInformation ( "Client for '{sni}' does not support TLS 1.2 or 1.3. " , sni ) ;
125+ return false ;
93126 }
94- else
127+
128+ if ( ! clientHello . ApplicationProtocols . HasFlag ( TlsFrameHelper . ApplicationProtocolInfo . Http2 ) )
95129 {
96- logger . LogInformation ( "SNI: {host}" , info . TargetName ) ;
130+ logger . LogInformation ( "Client for '{sni}' does not support HTTP/2." , sni ) ;
131+ return false ;
97132 }
98133
134+ // All checks passed, we can continue processing the request.
99135 return true ;
100136 }
101137
102138 private static bool AllowHost ( string targetName )
103139 {
104- if ( string . Equals ( "localhost" , targetName , StringComparison . OrdinalIgnoreCase ) )
140+ return
141+ targetName . Equals ( "localhost" , StringComparison . OrdinalIgnoreCase ) ||
142+ targetName . Equals ( "contoso.com" , StringComparison . OrdinalIgnoreCase ) ;
143+ }
144+
145+ /// <summary>Attempt to parse the first TLS frame from the <paramref name="buffer"/> and indicate whether more data is needed.</summary>
146+ private static bool TryReadTlsFrame ( ReadOnlySequence < byte > buffer , ILogger logger , out TlsFrameHelper . TlsFrameInfo frame )
147+ {
148+ frame = default ;
149+
150+ // Try to process the first segment first.
151+ var data = buffer . First . Span ;
152+
153+ if ( TlsFrameHelper . TryGetFrameInfo ( data , ref frame ) )
105154 {
155+ // This is the common fast path.
106156 return true ;
107157 }
108- return false ;
158+
159+ if ( frame . ParsingStatus != TlsFrameHelper . ParsingStatus . IncompleteFrame )
160+ {
161+ // The input is invalid, reading more data won't help.
162+ return false ;
163+ }
164+
165+ if ( buffer . IsSingleSegment )
166+ {
167+ // We only have one segment and it didn't contain a valid TLS frame. We'll have to read more data.
168+ return false ;
169+ }
170+
171+ // We have multiple segments. TlsFrameHelper only works with a single span, so we need to combine them.
172+ // This may happen on every new read, which is why we limit how much data we're willing to process.
173+
174+ var pooledBuffer = ArrayPool < byte > . Shared . Rent ( ( int ) buffer . Length ) ;
175+ buffer . CopyTo ( pooledBuffer ) ;
176+ data = pooledBuffer . AsSpan ( 0 , ( int ) buffer . Length ) ;
177+
178+ bool success = TlsFrameHelper . TryGetFrameInfo ( data , ref frame ) ;
179+
180+ ArrayPool < byte > . Shared . Return ( pooledBuffer ) ;
181+
182+ if ( success )
183+ {
184+ logger . LogDebug ( "Parsed multi-segment TLS frame after {length} bytes" , buffer . Length ) ;
185+ }
186+
187+ return success ;
188+ }
189+
190+ private sealed class DuplexPipe ( PipeReader input , PipeWriter output ) : IDuplexPipe
191+ {
192+ public PipeReader Input { get ; } = input ;
193+ public PipeWriter Output { get ; } = output ;
109194 }
110195}
0 commit comments