@@ -142,6 +142,10 @@ protected virtual async Task InvokeAsync(HttpContext context, RequestDelegate ne
142
142
return ;
143
143
}
144
144
145
+ // Perform CSRF protection if necessary
146
+ if ( await HandleCsrfProtectionAsync ( context , next ) )
147
+ return ;
148
+
145
149
// Authenticate request if necessary
146
150
if ( await HandleAuthorizeAsync ( context , next ) )
147
151
return ;
@@ -484,7 +488,36 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re
484
488
}
485
489
486
490
/// <summary>
487
- /// Perform authentication, if required, and return <see langword="true"/> if the
491
+ /// Performs CSRF protection, if required, and returns <see langword="true"/> if the
492
+ /// request was handled (typically by returning an error message). If <see langword="false"/>
493
+ /// is returned, the request is processed normally.
494
+ /// </summary>
495
+ protected virtual async ValueTask < bool > HandleCsrfProtectionAsync ( HttpContext context , RequestDelegate next )
496
+ {
497
+ if ( ! _options . CsrfProtectionEnabled )
498
+ return false ;
499
+ if ( context . Request . Headers . TryGetValue ( "Content-Type" , out var contentTypes ) && contentTypes . Count > 0 && contentTypes [ 0 ] != null )
500
+ {
501
+ var contentType = contentTypes [ 0 ] ! ;
502
+ if ( contentType . IndexOf ( ';' ) > 0 )
503
+ {
504
+ contentType = contentType . Substring ( 0 , contentType . IndexOf ( ';' ) ) ;
505
+ }
506
+ contentType = contentType . Trim ( ) . ToLowerInvariant ( ) ;
507
+ if ( ! ( contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data" ) )
508
+ return false ;
509
+ }
510
+ foreach ( var header in _options . CsrfProtectionHeaders )
511
+ {
512
+ if ( context . Request . Headers . TryGetValue ( header , out var values ) && values . Count > 0 && values [ 0 ] ? . Length > 0 )
513
+ return false ;
514
+ }
515
+ await HandleCsrfProtectionErrorAsync ( context , next ) ;
516
+ return true ;
517
+ }
518
+
519
+ /// <summary>
520
+ /// Perform authentication, if required, and returns <see langword="true"/> if the
488
521
/// request was handled (typically by returning an error message). If <see langword="false"/>
489
522
/// is returned, the request is processed normally.
490
523
/// </summary>
@@ -1034,21 +1067,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque
1034
1067
/// </summary>
1035
1068
protected virtual async ValueTask < bool > HandleDeserializationErrorAsync ( HttpContext context , RequestDelegate next , Exception exception )
1036
1069
{
1037
- await WriteErrorResponseAsync ( context , HttpStatusCode . BadRequest , new JsonInvalidError ( exception ) ) ;
1070
+ await WriteErrorResponseAsync ( context , new JsonInvalidError ( exception ) ) ;
1038
1071
return true ;
1039
1072
}
1040
1073
1074
+ /// <summary>
1075
+ /// Writes a '.' message to the output.
1076
+ /// </summary>
1077
+ protected virtual async Task HandleCsrfProtectionErrorAsync ( HttpContext context , RequestDelegate next )
1078
+ {
1079
+ await WriteErrorResponseAsync ( context , new CsrfProtectionError ( _options . CsrfProtectionHeaders ) ) ;
1080
+ }
1081
+
1041
1082
/// <summary>
1042
1083
/// Writes a '400 Batched requests are not supported.' message to the output.
1043
1084
/// </summary>
1044
1085
protected virtual Task HandleBatchedRequestsNotSupportedAsync ( HttpContext context , RequestDelegate next )
1045
- => WriteErrorResponseAsync ( context , HttpStatusCode . BadRequest , new BatchedRequestsNotSupportedError ( ) ) ;
1086
+ => WriteErrorResponseAsync ( context , new BatchedRequestsNotSupportedError ( ) ) ;
1046
1087
1047
1088
/// <summary>
1048
1089
/// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output.
1049
1090
/// </summary>
1050
1091
protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync ( HttpContext context , RequestDelegate next )
1051
- => WriteErrorResponseAsync ( context , HttpStatusCode . BadRequest , new WebSocketSubProtocolNotSupportedError ( context . WebSockets . WebSocketRequestedProtocols ) ) ;
1092
+ => WriteErrorResponseAsync ( context , new WebSocketSubProtocolNotSupportedError ( context . WebSockets . WebSocketRequestedProtocols ) ) ;
1052
1093
1053
1094
/// <summary>
1054
1095
/// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output.
@@ -1079,6 +1120,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re
1079
1120
return next ( context ) ;
1080
1121
}
1081
1122
1123
+ /// <summary>
1124
+ /// Writes the specified error as a JSON-formatted GraphQL response.
1125
+ /// </summary>
1126
+ protected virtual Task WriteErrorResponseAsync ( HttpContext context , ExecutionError executionError )
1127
+ => WriteErrorResponseAsync ( context , executionError is IHasPreferredStatusCode withCode ? withCode . PreferredStatusCode : HttpStatusCode . BadRequest , executionError ) ;
1128
+
1082
1129
/// <summary>
1083
1130
/// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code.
1084
1131
/// </summary>
0 commit comments