|
7 | 7 | import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator; |
8 | 8 | import com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer; |
9 | 9 | import com.azure.spring.cloud.autoconfigure.implementation.context.AzureGlobalPropertiesAutoConfiguration; |
| 10 | +import com.nimbusds.jose.jwk.source.JWKSourceBuilder; |
10 | 11 | import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector; |
11 | 12 | import org.junit.jupiter.api.Test; |
12 | 13 | import org.springframework.boot.autoconfigure.AutoConfigurations; |
|
32 | 33 | import org.springframework.security.web.SecurityFilterChain; |
33 | 34 | import org.springframework.test.util.ReflectionTestUtils; |
34 | 35 |
|
| 36 | +import java.time.Duration; |
35 | 37 | import java.util.Collection; |
36 | 38 | import java.util.LinkedHashMap; |
37 | 39 | import java.util.List; |
@@ -67,6 +69,44 @@ void testCreateJwtDecoderByJwkKeySetUri() { |
67 | 69 | }); |
68 | 70 | } |
69 | 71 |
|
| 72 | + @Test |
| 73 | + void testJwtDecoderTimeoutDefaultValues() { |
| 74 | + resourceServerContextRunner() |
| 75 | + .withPropertyValues("spring.cloud.azure.active-directory.enabled=true") |
| 76 | + .run(context -> { |
| 77 | + AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class); |
| 78 | + assertThat(properties.getJwtConnectTimeout()) |
| 79 | + .isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT)); |
| 80 | + assertThat(properties.getJwtReadTimeout()) |
| 81 | + .isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT)); |
| 82 | + // Verify the default timeouts are applied to the RestTemplate used by the JwtDecoder |
| 83 | + final JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); |
| 84 | + verifyJwtDecoderRestTemplateTimeouts(jwtDecoder, |
| 85 | + JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT, |
| 86 | + JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT); |
| 87 | + }); |
| 88 | + } |
| 89 | + |
| 90 | + @Test |
| 91 | + void testJwtDecoderTimeoutCustomValues() { |
| 92 | + resourceServerContextRunner() |
| 93 | + .withPropertyValues( |
| 94 | + "spring.cloud.azure.active-directory.enabled=true", |
| 95 | + "spring.cloud.azure.active-directory.jwt-connect-timeout=2000", |
| 96 | + "spring.cloud.azure.active-directory.jwt-read-timeout=3000") |
| 97 | + .run(context -> { |
| 98 | + AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class); |
| 99 | + assertThat(properties.getJwtConnectTimeout()).isEqualTo(Duration.ofMillis(2000)); |
| 100 | + assertThat(properties.getJwtReadTimeout()).isEqualTo(Duration.ofMillis(3000)); |
| 101 | + // Verify JwtDecoder is still created successfully with custom timeouts |
| 102 | + final JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class); |
| 103 | + assertThat(jwtDecoder).isNotNull(); |
| 104 | + assertThat(jwtDecoder).isExactlyInstanceOf(NimbusJwtDecoder.class); |
| 105 | + // Verify the configured timeouts are applied to the RestTemplate used by the JwtDecoder |
| 106 | + verifyJwtDecoderRestTemplateTimeouts(jwtDecoder, 2000, 3000); |
| 107 | + }); |
| 108 | + } |
| 109 | + |
70 | 110 | @Test |
71 | 111 | void testNotAudienceDefaultValidator() { |
72 | 112 | resourceServerRunner() |
@@ -364,4 +404,52 @@ public Collection<GrantedAuthority> convert(Jwt source) { |
364 | 404 | return null; |
365 | 405 | } |
366 | 406 | } |
| 407 | + |
| 408 | + /** |
| 409 | + * Verifies that the RestTemplate used by the NimbusJwtDecoder for JWK retrieval |
| 410 | + * has the expected connect and read timeouts applied to its ClientHttpRequestFactory. |
| 411 | + */ |
| 412 | + @SuppressWarnings("unchecked") |
| 413 | + private static void verifyJwtDecoderRestTemplateTimeouts(JwtDecoder jwtDecoder, |
| 414 | + int expectedConnectTimeoutMs, |
| 415 | + int expectedReadTimeoutMs) { |
| 416 | + // NimbusJwtDecoder -> jwtProcessor (DefaultJWTProcessor) |
| 417 | + Object jwtProcessor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor"); |
| 418 | + assertThat(jwtProcessor).isInstanceOf(com.nimbusds.jwt.proc.DefaultJWTProcessor.class); |
| 419 | + |
| 420 | + // DefaultJWTProcessor -> JWSKeySelector (JWSVerificationKeySelector) |
| 421 | + com.nimbusds.jose.proc.JWSKeySelector<?> keySelector = |
| 422 | + ((com.nimbusds.jwt.proc.DefaultJWTProcessor<?>) jwtProcessor).getJWSKeySelector(); |
| 423 | + assertThat(keySelector).isInstanceOf(com.nimbusds.jose.proc.JWSVerificationKeySelector.class); |
| 424 | + |
| 425 | + // JWSVerificationKeySelector -> JWKSource (JWKSetBasedJWKSource) |
| 426 | + com.nimbusds.jose.jwk.source.JWKSource<?> jwkSource = |
| 427 | + ((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource(); |
| 428 | + assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class); |
| 429 | + |
| 430 | + // JWKSetBasedJWKSource -> JWKSetSource (CachingJWKSetSource -> JWKSetSourceWrapper -> actual source) |
| 431 | + Object jwkSetSource = |
| 432 | + ((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource(); |
| 433 | + |
| 434 | + // Unwrap JWKSetSourceWrapper chain to find the source with restOperations |
| 435 | + while (jwkSetSource instanceof com.nimbusds.jose.jwk.source.JWKSetSourceWrapper<?> wrapper) { |
| 436 | + jwkSetSource = wrapper.getSource(); |
| 437 | + } |
| 438 | + |
| 439 | + // actual source -> restOperations (RestTemplate) |
| 440 | + Object restOperations = ReflectionTestUtils.getField(jwkSetSource, "restOperations"); |
| 441 | + assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class); |
| 442 | + |
| 443 | + // RestTemplate -> ClientHttpRequestFactory |
| 444 | + org.springframework.http.client.ClientHttpRequestFactory requestFactory = |
| 445 | + ((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory(); |
| 446 | + |
| 447 | + // Verify timeouts on the request factory (may be stored as Duration or int) |
| 448 | + Object connectTimeoutValue = ReflectionTestUtils.getField(requestFactory, "connectTimeout"); |
| 449 | + Object readTimeoutValue = ReflectionTestUtils.getField(requestFactory, "readTimeout"); |
| 450 | + int connectTimeout = connectTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) connectTimeoutValue; |
| 451 | + int readTimeout = readTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) readTimeoutValue; |
| 452 | + assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs); |
| 453 | + assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs); |
| 454 | + } |
367 | 455 | } |
0 commit comments