Skip to content

Commit 44ac392

Browse files
KeesVraman-m
authored andcommitted
Return the auth challenge in the WWW-Authenticate header on authentication failure
1 parent d310508 commit 44ac392

File tree

8 files changed

+253
-151
lines changed

8 files changed

+253
-151
lines changed

src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using Ocelot.Configuration;
44
using Ocelot.Logging;
55
using Ocelot.Middleware;
6+
using System.Runtime.Remoting.Contexts;
7+
using System.Threading.Tasks;
68

79
namespace Ocelot.Authentication.Middleware
810
{
@@ -36,6 +38,7 @@ public async Task Invoke(HttpContext httpContext)
3638

3739
if (result.Principal?.Identity == null)
3840
{
41+
await ChallengeAsync(httpContext, downstreamRoute);
3942
SetUnauthenticatedError(httpContext, path, null);
4043
return;
4144
}
@@ -49,6 +52,7 @@ public async Task Invoke(HttpContext httpContext)
4952
return;
5053
}
5154

55+
await ChallengeAsync(httpContext, downstreamRoute);
5256
SetUnauthenticatedError(httpContext, path, httpContext.User.Identity.Name);
5357
}
5458

@@ -59,6 +63,18 @@ private void SetUnauthenticatedError(HttpContext httpContext, string path, strin
5963
httpContext.Items.SetError(error);
6064
}
6165

66+
private async Task ChallengeAsync(HttpContext context, DownstreamRoute route)
67+
{
68+
// Perform a challenge. This populates the WWW-Authenticate header on the response
69+
await context.ChallengeAsync(route.AuthenticationOptions.AuthenticationProviderKey);
70+
71+
// Since the response gets re-created down the pipeline, we store the challenge in the Items, so we can re-apply it when sending the response
72+
if (context.Response.Headers.TryGetValue("WWW-Authenticate", out var authenticateHeader))
73+
{
74+
context.Items.SetAuthChallenge(authenticateHeader);
75+
}
76+
}
77+
6278
private async Task<AuthenticateResult> AuthenticateAsync(HttpContext context, DownstreamRoute route)
6379
{
6480
var options = route.AuthenticationOptions;

src/Ocelot/Middleware/HttpItemsExtensions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ public static void SetError(this IDictionary<object, object> input, Error error)
4343
input.Upsert("Errors", errors);
4444
}
4545

46+
public static void SetAuthChallenge(this IDictionary<object, object> input, string challengeString) =>
47+
input.Upsert("AuthChallenge", challengeString);
48+
49+
public static string AuthChallenge(this IDictionary<object, object> input) =>
50+
input.Get<string>("AuthChallenge");
51+
4652
public static void SetIInternalConfiguration(this IDictionary<object, object> input, IInternalConfiguration config)
4753
{
4854
input.Upsert("IInternalConfiguration", config);

src/Ocelot/Multiplexer/MultiplexingMiddleware.cs

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
using Microsoft.AspNetCore.Http;
1+
using Microsoft.AspNetCore.Http;
22
using Microsoft.Extensions.Primitives;
33
using Newtonsoft.Json.Linq;
4-
using Ocelot.Configuration;
4+
using Ocelot.Configuration;
55
using Ocelot.Configuration.File;
6-
using Ocelot.DownstreamRouteFinder.UrlMatcher;
7-
using Ocelot.Logging;
8-
using Ocelot.Middleware;
6+
using Ocelot.DownstreamRouteFinder.UrlMatcher;
7+
using Ocelot.Logging;
8+
using Ocelot.Middleware;
99
using System.Collections;
1010
using Route = Ocelot.Configuration.Route;
11-
11+
1212
namespace Ocelot.Multiplexer;
1313

1414
public class MultiplexingMiddleware : OcelotMiddleware
1515
{
1616
private readonly RequestDelegate _next;
1717
private readonly IResponseAggregatorFactory _factory;
1818
private const string RequestIdString = "RequestId";
19-
19+
2020
public MultiplexingMiddleware(RequestDelegate next,
2121
IOcelotLoggerFactory loggerFactory,
2222
IResponseAggregatorFactory factory)
@@ -25,7 +25,7 @@ public MultiplexingMiddleware(RequestDelegate next,
2525
_factory = factory;
2626
_next = next;
2727
}
28-
28+
2929
public async Task Invoke(HttpContext httpContext)
3030
{
3131
var downstreamRouteHolder = httpContext.Items.DownstreamRouteHolder();
@@ -38,37 +38,37 @@ public async Task Invoke(HttpContext httpContext)
3838
await ProcessSingleRouteAsync(httpContext, downstreamRoutes[0]);
3939
return;
4040
}
41-
41+
4242
// Case 2: if no downstream routes
4343
if (downstreamRoutes.Count == 0)
4444
{
4545
return;
4646
}
47-
47+
4848
// Case 3: if multiple downstream routes
4949
var routeKeysConfigs = route.DownstreamRouteConfig;
5050
if (routeKeysConfigs == null || routeKeysConfigs.Count == 0)
5151
{
5252
await ProcessRoutesAsync(httpContext, route);
5353
return;
5454
}
55-
55+
5656
// Case 4: if multiple downstream routes with route keys
5757
var mainResponseContext = await ProcessMainRouteAsync(httpContext, downstreamRoutes[0]);
5858
if (mainResponseContext == null)
5959
{
6060
return;
6161
}
62-
62+
6363
var responsesContexts = await ProcessRoutesWithRouteKeysAsync(httpContext, downstreamRoutes, routeKeysConfigs, mainResponseContext);
6464
if (responsesContexts.Length == 0)
6565
{
6666
return;
6767
}
68-
68+
6969
await MapResponsesAsync(httpContext, route, mainResponseContext, responsesContexts);
7070
}
71-
71+
7272
/// <summary>
7373
/// Helper method to determine if only the first downstream route should be processed.
7474
/// It is the case if the request is a websocket request or if there is only one downstream route.
@@ -78,7 +78,7 @@ public async Task Invoke(HttpContext httpContext)
7878
/// <returns>True if only the first downstream route should be processed.</returns>
7979
private static bool ShouldProcessSingleRoute(HttpContext context, ICollection routes)
8080
=> context.WebSockets.IsWebSocketRequest || routes.Count == 1;
81-
81+
8282
/// <summary>
8383
/// Processing a single downstream route (no route keys).
8484
/// In that case, no need to make copies of the http context.
@@ -89,9 +89,10 @@ private static bool ShouldProcessSingleRoute(HttpContext context, ICollection ro
8989
protected virtual Task ProcessSingleRouteAsync(HttpContext context, DownstreamRoute route)
9090
{
9191
context.Items.UpsertDownstreamRoute(route);
92+
context.Items.SetAuthChallenge(/*finished*/context.Items.AuthChallenge());
9293
return _next.Invoke(context);
9394
}
94-
95+
9596
/// <summary>
9697
/// Processing the downstream routes (no route keys).
9798
/// </summary>
@@ -105,7 +106,7 @@ private async Task ProcessRoutesAsync(HttpContext context, Route route)
105106
var contexts = await Task.WhenAll(tasks);
106107
await MapAsync(context, route, new(contexts));
107108
}
108-
109+
109110
/// <summary>
110111
/// When using route keys, the first route is the main route and the rest are additional routes.
111112
/// Since we need to break if the main route response is null, we must process the main route first.
@@ -119,7 +120,7 @@ private async Task<HttpContext> ProcessMainRouteAsync(HttpContext context, Downs
119120
await _next.Invoke(context);
120121
return context;
121122
}
122-
123+
123124
/// <summary>
124125
/// Processing the downstream routes with route keys except the main route that has already been processed.
125126
/// </summary>
@@ -133,7 +134,7 @@ protected virtual async Task<HttpContext[]> ProcessRoutesWithRouteKeysAsync(Http
133134
var processing = new List<Task<HttpContext>>();
134135
var content = await mainResponse.Items.DownstreamResponse().Content.ReadAsStringAsync();
135136
var jObject = JToken.Parse(content);
136-
137+
137138
foreach (var downstreamRoute in routes.Skip(1))
138139
{
139140
var matchAdvancedAgg = routeKeysConfigs.FirstOrDefault(q => q.RouteKey == downstreamRoute.Key);
@@ -142,13 +143,13 @@ protected virtual async Task<HttpContext[]> ProcessRoutesWithRouteKeysAsync(Http
142143
processing.AddRange(ProcessRouteWithComplexAggregation(matchAdvancedAgg, jObject, context, downstreamRoute));
143144
continue;
144145
}
145-
146+
146147
processing.Add(ProcessRouteAsync(context, downstreamRoute));
147148
}
148-
149+
149150
return await Task.WhenAll(processing);
150151
}
151-
152+
152153
/// <summary>
153154
/// Mapping responses.
154155
/// </summary>
@@ -158,7 +159,7 @@ private Task MapResponsesAsync(HttpContext context, Route route, HttpContext mai
158159
contexts.AddRange(responsesContexts);
159160
return MapAsync(context, route, contexts);
160161
}
161-
162+
162163
/// <summary>
163164
/// Processing a route with aggregation.
164165
/// </summary>
@@ -173,7 +174,7 @@ private IEnumerable<Task<HttpContext>> ProcessRouteWithComplexAggregation(Aggreg
173174
tPnv.Add(new PlaceholderNameAndValue('{' + matchAdvancedAgg.Parameter + '}', value));
174175
processing.Add(ProcessRouteAsync(httpContext, downstreamRoute, tPnv));
175176
}
176-
177+
177178
return processing;
178179
}
179180

@@ -186,11 +187,11 @@ private async Task<HttpContext> ProcessRouteAsync(HttpContext sourceContext, Dow
186187
var newHttpContext = await CreateThreadContextAsync(sourceContext, route);
187188
CopyItemsToNewContext(newHttpContext, sourceContext, placeholders);
188189
newHttpContext.Items.UpsertDownstreamRoute(route);
189-
190+
190191
await _next.Invoke(newHttpContext);
191192
return newHttpContext;
192193
}
193-
194+
194195
/// <summary>
195196
/// Copying some needed parameters to the Http context items.
196197
/// </summary>
@@ -247,7 +248,7 @@ protected virtual async Task<HttpContext> CreateThreadContextAsync(HttpContext s
247248
target.Response.RegisterForDisposeAsync(bodyStream); // manage Stream lifetime by HttpResponse object
248249
return target;
249250
}
250-
251+
251252
protected virtual Task MapAsync(HttpContext httpContext, Route route, List<HttpContext> contexts)
252253
{
253254
if (route.DownstreamRoute.Count == 1)
@@ -282,4 +283,4 @@ protected virtual async Task<Stream> CloneRequestBodyAsync(HttpRequest request,
282283

283284
return targetBuffer;
284285
}
285-
}
286+
}

src/Ocelot/Responder/HttpContextResponder.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.Extensions.Primitives;
44
using Ocelot.Headers;
55
using Ocelot.Middleware;
6+
using System.Runtime.Remoting.Messaging;
67

78
namespace Ocelot.Responder;
89

@@ -77,6 +78,11 @@ public async Task SetErrorResponseOnContext(HttpContext context, DownstreamRespo
7778
}
7879
}
7980

81+
public void SetAuthChallengeOnContext(HttpContext context, string challenge)
82+
{
83+
AddHeaderIfDoesntExist(context, new Header("WWW-Authenticate", new[] { challenge }));
84+
}
85+
8086
private static void SetStatusCode(HttpContext context, int statusCode)
8187
{
8288
if (!context.Response.HasStarted)

src/Ocelot/Responder/IHttpResponder.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Microsoft.AspNetCore.Http;
22
using Ocelot.Middleware;
3-
3+
44
namespace Ocelot.Responder
55
{
66
public interface IHttpResponder
@@ -10,5 +10,7 @@ public interface IHttpResponder
1010
void SetErrorResponseOnContext(HttpContext context, int statusCode);
1111

1212
Task SetErrorResponseOnContext(HttpContext context, DownstreamResponse response);
13+
14+
void SetAuthChallengeOnContext(HttpContext context, string challenge);
1315
}
1416
}

src/Ocelot/Responder/Middleware/ResponderMiddleware.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ private async Task SetErrorResponse(HttpContext context, List<Error> errors)
6161
var statusCode = _codeMapper.Map(errors);
6262
_responder.SetErrorResponseOnContext(context, statusCode);
6363

64-
if (errors.All(e => e.Code != OcelotErrorCode.QuotaExceededError))
64+
if (errors.Any(e => e.Code == OcelotErrorCode.QuotaExceededError))
6565
{
66-
return;
66+
var downstreamResponse = context.Items.DownstreamResponse();
67+
await _responder.SetErrorResponseOnContext(context, downstreamResponse);
6768
}
6869

69-
var downstreamResponse = context.Items.DownstreamResponse();
70-
await _responder.SetErrorResponseOnContext(context, downstreamResponse);
70+
if (errors.Any(e => e.Code == OcelotErrorCode.UnauthenticatedError))
71+
{
72+
var challenge = context.Items.AuthChallenge();
73+
if (!string.IsNullOrEmpty(challenge))
74+
{
75+
_responder.SetAuthChallengeOnContext(context, challenge);
76+
}
77+
}
7178
}
7279
}
7380
}

test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using IdentityServer4.Models;
33
using Microsoft.AspNetCore.Hosting;
44
using Microsoft.AspNetCore.Http;
5+
using Ocelot.Configuration.File;
6+
using System.Net.Http;
57

68
namespace Ocelot.AcceptanceTests.Authentication
79
{
@@ -112,6 +114,68 @@ public void Should_return_201_using_identity_server_reference_token()
112114
.BDDfy();
113115
}
114116

117+
[Fact]
118+
[Trait("Feat", "1387")]
119+
public void Should_return_www_authenticate_header_on_401()
120+
{
121+
var port = PortFinder.GetRandomPort();
122+
var route = GivenDefaultAuthRoute(port);
123+
var configuration = GivenConfiguration(route);
124+
this.Given(x => GivenThereIsAConfiguration(configuration))
125+
.And(x => GivenOcelotIsRunningWithJwtAuth("Test"))
126+
.And(x => GivenIHaveNoTokenForMyRequest())
127+
.When(x => WhenIGetUrlOnTheApiGateway("/"))
128+
.Then(x => ThenTheStatusCodeShouldBe(HttpStatusCode.Unauthorized))
129+
.And(x => ThenTheResponseShouldContainAuthChallenge())
130+
.BDDfy();
131+
}
132+
133+
public void GivenOcelotIsRunningWithJwtAuth(string authenticationProviderKey)
134+
{
135+
var builder = new ConfigurationBuilder()
136+
.SetBasePath(Directory.GetCurrentDirectory())
137+
.AddJsonFile("appsettings.json", optional: true, reloadOnChange: false)
138+
.AddJsonFile("ocelot.json", false, false)
139+
.AddEnvironmentVariables();
140+
141+
var configuration = builder.Build();
142+
_webHostBuilder = new WebHostBuilder();
143+
_webHostBuilder.ConfigureServices(s =>
144+
{
145+
s.AddSingleton(_webHostBuilder);
146+
});
147+
148+
_ocelotServer = new TestServer(_webHostBuilder
149+
.UseConfiguration(configuration)
150+
.ConfigureServices(s =>
151+
{
152+
s.AddAuthentication().AddJwtBearer(authenticationProviderKey, options =>
153+
{
154+
});
155+
s.AddOcelot(configuration);
156+
})
157+
.ConfigureLogging(l =>
158+
{
159+
l.AddConsole();
160+
l.AddDebug();
161+
})
162+
.Configure(a =>
163+
{
164+
a.UseOcelot().Wait();
165+
}));
166+
167+
_ocelotClient = _ocelotServer.CreateClient();
168+
}
169+
public void GivenIHaveNoTokenForMyRequest()
170+
{
171+
_ocelotClient.DefaultRequestHeaders.Authorization = null;
172+
}
173+
public void ThenTheResponseShouldContainAuthChallenge()
174+
{
175+
_response.Headers.TryGetValues("WWW-Authenticate", out var headerValue).ShouldBeTrue();
176+
headerValue.ShouldNotBeEmpty();
177+
}
178+
115179
[IgnorePublicMethod]
116180
public async Task GivenThereIsAnIdentityServerOn(string url, AccessTokenType tokenType)
117181
{

0 commit comments

Comments
 (0)