Skip to content

Commit 0d41e22

Browse files
committed
Redirect to dex only for browsers, otherwise return 401.
Adds a test to validate that all endpoints require auth
1 parent 1efa110 commit 0d41e22

3 files changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package net.neoforged.meta.security;
2+
3+
import jakarta.servlet.ServletException;
4+
import jakarta.servlet.http.HttpServletRequest;
5+
import jakarta.servlet.http.HttpServletResponse;
6+
import org.springframework.http.MediaType;
7+
import org.springframework.security.core.AuthenticationException;
8+
import org.springframework.security.web.AuthenticationEntryPoint;
9+
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
10+
11+
import java.io.IOException;
12+
13+
/**
14+
* Custom authentication entry point that redirects browsers to OAuth login
15+
* but returns 401 Unauthorized for non-browser clients (API clients).
16+
* <p>
17+
* Determines if a request is from a browser by checking if the Accept header
18+
* contains "text/html".
19+
*/
20+
public class BrowserAwareAuthenticationEntryPoint implements AuthenticationEntryPoint {
21+
22+
private final LoginUrlAuthenticationEntryPoint browserEntryPoint;
23+
24+
public BrowserAwareAuthenticationEntryPoint(String loginFormUrl) {
25+
this.browserEntryPoint = new LoginUrlAuthenticationEntryPoint(loginFormUrl);
26+
}
27+
28+
@Override
29+
public void commence(HttpServletRequest request, HttpServletResponse response,
30+
AuthenticationException authException) throws IOException, ServletException {
31+
32+
// Check if the request accepts HTML (indicates a browser)
33+
if (isBrowserDocumentRequest(request)) {
34+
// Redirect to OAuth login page for browsers
35+
browserEntryPoint.commence(request, response, authException);
36+
} else {
37+
// Return 401 Unauthorized for API clients
38+
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized");
39+
}
40+
}
41+
42+
/**
43+
* We heuristically consider any request that explicitly asks for text/html without weighting a browser request.
44+
*/
45+
private boolean isBrowserDocumentRequest(HttpServletRequest request) {
46+
var mediaTypes = MediaType.parseMediaTypes(request.getHeader("Accept"));
47+
return mediaTypes.contains(MediaType.TEXT_HTML);
48+
}
49+
}

src/main/java/net/neoforged/meta/security/SecurityConfiguration.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ public SecurityFilterChain uiSecurityFilterChain(HttpSecurity http) {
7878
.sessionManagement(session -> session
7979
.sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED)
8080
)
81+
.exceptionHandling(exceptions -> exceptions
82+
.authenticationEntryPoint(new BrowserAwareAuthenticationEntryPoint("/oauth2/authorization/dex"))
83+
)
8184
.build();
8285
}
8386
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package net.neoforged.meta;
2+
3+
import org.junit.jupiter.params.ParameterizedTest;
4+
import org.junit.jupiter.params.provider.Arguments;
5+
import org.junit.jupiter.params.provider.MethodSource;
6+
import org.springframework.beans.factory.annotation.Autowired;
7+
import org.springframework.boot.test.context.SpringBootTest;
8+
import org.springframework.boot.test.web.server.LocalServerPort;
9+
import org.springframework.http.HttpMethod;
10+
import org.springframework.test.annotation.DirtiesContext;
11+
import org.springframework.test.context.ActiveProfiles;
12+
import org.springframework.web.bind.annotation.RequestMethod;
13+
import org.springframework.web.client.RestClient;
14+
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
15+
import org.springframework.web.util.UrlPathHelper;
16+
17+
import java.util.HashSet;
18+
import java.util.Set;
19+
import java.util.stream.Stream;
20+
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertNotNull;
23+
24+
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
25+
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS)
26+
@ActiveProfiles("test")
27+
public class TestAuthenticatedEndpoints {
28+
29+
@LocalServerPort
30+
private int port;
31+
32+
@ParameterizedTest
33+
@MethodSource("provideEndpoints")
34+
void testAllEndpointsAuthenticatedForApiConsumers(String method, String path) throws Exception {
35+
var restClient = RestClient.create();
36+
var baseUrl = "http://localhost:" + port;
37+
38+
try (var response = restClient
39+
.method(HttpMethod.valueOf(method))
40+
.uri(baseUrl + path)
41+
.exchange((_, clientResponse) -> clientResponse)) {
42+
int statusCode = response.getStatusCode().value();
43+
assertEquals(401, statusCode,
44+
"Expected 401 Unauthorized for " + method + " " + path + " but got " + statusCode);
45+
}
46+
}
47+
48+
@ParameterizedTest
49+
@MethodSource("provideEndpoints")
50+
void testAllEndpointsAuthenticatedForBrowsers(String method, String path) throws Exception {
51+
var restClient = RestClient.create();
52+
var baseUrl = "http://localhost:" + port;
53+
54+
try (var response = restClient
55+
.method(HttpMethod.valueOf(method))
56+
.uri(baseUrl + path)
57+
.header("Accept", "text/html")
58+
.exchange((_, clientResponse) -> clientResponse)) {
59+
int statusCode = response.getStatusCode().value();
60+
assertEquals(302, statusCode,
61+
"Expected redirect for " + method + " " + path + " but got " + statusCode);
62+
String location = response.getHeaders().getFirst("Location");
63+
assertNotNull(location, "location");
64+
location = new UrlPathHelper().removeSemicolonContent(location); // sometimes ;jsessionid= is appended
65+
assertEquals(baseUrl + "/oauth2/authorization/dex", location);
66+
}
67+
}
68+
69+
/**
70+
* Automatically discover all registered API endpoints from Spring.
71+
*/
72+
static Stream<Arguments> provideEndpoints(
73+
@Autowired RequestMappingHandlerMapping requestMappingHandlerMapping) {
74+
return requestMappingHandlerMapping.getHandlerMethods().entrySet().stream()
75+
.flatMap(entry -> {
76+
var mapping = entry.getKey();
77+
var patterns = new HashSet<>(mapping.getPatternValues());
78+
var httpMethods = mapping.getMethodsCondition().getMethods();
79+
80+
patterns.remove("/"); // The index page is white-listed to be accessible anonymously
81+
82+
// If no methods specified, default to GET
83+
final var finalMethods = httpMethods.isEmpty()
84+
? Set.of(RequestMethod.GET)
85+
: httpMethods;
86+
87+
return patterns.stream()
88+
.flatMap(pattern -> finalMethods.stream()
89+
.map(requestMethod -> {
90+
// Replace path variables with dummy values
91+
var resolvedPath = pattern.replaceAll("\\{[^}]+}", "123");
92+
return Arguments.of(requestMethod.name(), resolvedPath);
93+
}));
94+
})
95+
.distinct();
96+
}
97+
}

0 commit comments

Comments
 (0)