Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sdk/spring/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ This section includes changes in `spring-cloud-azure-autoconfigure` module.

#### Bugs Fixed

- Fixed Redis Lettuce passwordless autoconfiguration so a user-defined `LettuceClientConfigurationBuilderCustomizer` no longer suppresses the Azure customizer bean that configures Azure Redis credentials and RESP2 support.
- Applied `jwt-connect-timeout` and `jwt-read-timeout` properties to the RestTemplate used by the JWT decoder in AAD and B2C resource server configurations, preventing indefinite hanging when fetching JWK keys ([#49329](https://github.com/Azure/azure-sdk-for-java/pull/49329)).
- Fixed the missing bean name in `@ConditionalOnMissingBean` for `LettuceClientConfigurationBuilderCustomizer` ([#49290](https://github.com/Azure/azure-sdk-for-java/issues/49290)).
Comment thread
rujche marked this conversation as resolved.
- Fixed the AAD and B2C resource server JWT decoder not honoring the `spring.cloud.azure.active-directory.jwt-connect-timeout`, `spring.cloud.azure.active-directory.jwt-read-timeout`, `spring.cloud.azure.active-directory.b2c.jwt-connect-timeout`, and `spring.cloud.azure.active-directory.b2c.jwt-read-timeout` configuration properties, which could cause indefinite hangs when fetching JWK keys ([#49329](https://github.com/Azure/azure-sdk-for-java/pull/49329)).
- Fixed AAD resource server JWK retrieval not honoring the `spring.cloud.azure.active-directory.jwk-set-cache-lifespan` and `spring.cloud.azure.active-directory.jwk-set-cache-refresh-time` configuration properties ([#42159](https://github.com/Azure/azure-sdk-for-java/issues/42159), [#49293](https://github.com/Azure/azure-sdk-for-java/issues/49293)).

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.conditions.ResourceServerCondition;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadResourceServerProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jose.RestOperationsResourceRetriever;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.constants.AadJwtClaimNames;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadTrustedIssuerRepository;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.properties.AadAuthorizationServerEndpoints;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand All @@ -32,12 +37,14 @@
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.util.StringUtils;

import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer.aadResourceServer;
import static com.azure.spring.cloud.autoconfigure.implementation.aad.utils.AadRestTemplateCreator.createRestTemplate;

@Configuration(proxyBeanMethods = false)
@Conditional(ResourceServerCondition.class)
Expand All @@ -56,16 +63,31 @@ JwtDecoder jwtDecoder(AadAuthenticationProperties aadAuthenticationProperties) {
AadAuthorizationServerEndpoints identityEndpoints = new AadAuthorizationServerEndpoints(
aadAuthenticationProperties.getProfile().getEnvironment().getActiveDirectoryEndpoint(), tenantId);
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder
.withJwkSetUri(identityEndpoints.getJwkSetEndpoint())
.restOperations(createRestTemplate(restTemplateBuilder
.connectTimeout(aadAuthenticationProperties.getJwtConnectTimeout())
.readTimeout(aadAuthenticationProperties.getJwtReadTimeout())))
.withJwkSource(createJwkSource(identityEndpoints.getJwkSetEndpoint(), aadAuthenticationProperties))
.build();
List<OAuth2TokenValidator<Jwt>> validators = createDefaultValidator(aadAuthenticationProperties);
nimbusJwtDecoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(validators));
return nimbusJwtDecoder;
}

private JWKSource<SecurityContext> createJwkSource(String jwkSetEndpoint,
AadAuthenticationProperties aadAuthenticationProperties) {
RestTemplateBuilder jwtRestTemplateBuilder = restTemplateBuilder
.connectTimeout(aadAuthenticationProperties.getJwtConnectTimeout())
.readTimeout(aadAuthenticationProperties.getJwtReadTimeout());
ResourceRetriever resourceRetriever = new RestOperationsResourceRetriever(jwtRestTemplateBuilder);
try {
URL jwkSetUrl = URI.create(jwkSetEndpoint).toURL();
return JWKSourceBuilder.create(jwkSetUrl, resourceRetriever)
.cache(aadAuthenticationProperties.getJwkSetCacheLifespan().toMillis(),
aadAuthenticationProperties.getJwkSetCacheRefreshTime().toMillis())
.refreshAheadCache(false)
.build();
Comment thread
rujche marked this conversation as resolved.
Outdated
} catch (MalformedURLException e) {
Comment thread
rujche marked this conversation as resolved.
Outdated
throw new IllegalStateException("Invalid JWK Set endpoint: " + jwkSetEndpoint, e);
}
}

List<OAuth2TokenValidator<Jwt>> createDefaultValidator(AadAuthenticationProperties aadAuthenticationProperties) {
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
List<String> validAudiences = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

import com.azure.identity.extensions.implementation.template.AzureAuthenticationTemplate;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jose.RestOperationsResourceRetriever;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer;
import com.azure.spring.cloud.autoconfigure.implementation.context.AzureGlobalPropertiesAutoConfiguration;
import com.nimbusds.jose.jwk.source.CachingJWKSetSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.jwk.source.JWKSetSourceWrapper;
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down Expand Up @@ -87,6 +91,22 @@ void testJwtDecoderTimeoutDefaultValues() {
});
}

@Test
void testJwtDecoderCacheDefaultValues() {
resourceServerContextRunner()
.withPropertyValues("spring.cloud.azure.active-directory.enabled=true")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwkSetCacheLifespan()).isEqualTo(Duration.ofMinutes(5));
assertThat(properties.getJwkSetCacheRefreshTime()).isEqualTo(Duration.ofMinutes(5));

JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
verifyJwtDecoderCacheDurations(jwtDecoder,
Duration.ofMinutes(5).toMillis(),
Duration.ofMinutes(5).toMillis());
});
}

@Test
void testJwtDecoderTimeoutCustomValues() {
resourceServerContextRunner()
Expand All @@ -107,6 +127,25 @@ void testJwtDecoderTimeoutCustomValues() {
});
}

@Test
void testJwtDecoderCacheCustomValues() {
resourceServerContextRunner()
.withPropertyValues(
"spring.cloud.azure.active-directory.enabled=true",
"spring.cloud.azure.active-directory.jwk-set-cache-lifespan=12m",
"spring.cloud.azure.active-directory.jwk-set-cache-refresh-time=34s")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwkSetCacheLifespan()).isEqualTo(Duration.ofMinutes(12));
assertThat(properties.getJwkSetCacheRefreshTime()).isEqualTo(Duration.ofSeconds(34));

JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
verifyJwtDecoderCacheDurations(jwtDecoder,
Duration.ofMinutes(12).toMillis(),
Duration.ofSeconds(34).toMillis());
});
}

@Test
void testNotAudienceDefaultValidator() {
resourceServerRunner()
Expand Down Expand Up @@ -409,35 +448,14 @@ public Collection<GrantedAuthority> convert(Jwt source) {
* Verifies that the RestTemplate used by the NimbusJwtDecoder for JWK retrieval
* has the expected connect and read timeouts applied to its ClientHttpRequestFactory.
*/
Comment thread
rujche marked this conversation as resolved.
@SuppressWarnings("unchecked")
private static void verifyJwtDecoderRestTemplateTimeouts(JwtDecoder jwtDecoder,
int expectedConnectTimeoutMs,
int expectedReadTimeoutMs) {
// NimbusJwtDecoder -> jwtProcessor (DefaultJWTProcessor)
Object jwtProcessor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
assertThat(jwtProcessor).isInstanceOf(com.nimbusds.jwt.proc.DefaultJWTProcessor.class);

// DefaultJWTProcessor -> JWSKeySelector (JWSVerificationKeySelector)
com.nimbusds.jose.proc.JWSKeySelector<?> keySelector =
((com.nimbusds.jwt.proc.DefaultJWTProcessor<?>) jwtProcessor).getJWSKeySelector();
assertThat(keySelector).isInstanceOf(com.nimbusds.jose.proc.JWSVerificationKeySelector.class);

// JWSVerificationKeySelector -> JWKSource (JWKSetBasedJWKSource)
com.nimbusds.jose.jwk.source.JWKSource<?> jwkSource =
((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource();
assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class);
URLBasedJWKSetSource<?> urlBasedJwkSetSource = getUrlBasedJwkSetSource(jwtDecoder);
Object resourceRetriever = ReflectionTestUtils.getField(urlBasedJwkSetSource, "resourceRetriever");
assertThat(resourceRetriever).isInstanceOf(RestOperationsResourceRetriever.class);

// JWKSetBasedJWKSource -> JWKSetSource (CachingJWKSetSource -> JWKSetSourceWrapper -> actual source)
Object jwkSetSource =
((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource();

// Unwrap JWKSetSourceWrapper chain to find the source with restOperations
while (jwkSetSource instanceof com.nimbusds.jose.jwk.source.JWKSetSourceWrapper<?> wrapper) {
jwkSetSource = wrapper.getSource();
}

// actual source -> restOperations (RestTemplate)
Object restOperations = ReflectionTestUtils.getField(jwkSetSource, "restOperations");
Object restOperations = ReflectionTestUtils.getField(resourceRetriever, "restOperations");
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);

// RestTemplate -> ClientHttpRequestFactory
Expand All @@ -452,4 +470,42 @@ private static void verifyJwtDecoderRestTemplateTimeouts(JwtDecoder jwtDecoder,
assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs);
assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs);
}

private static void verifyJwtDecoderCacheDurations(JwtDecoder jwtDecoder,
long expectedCacheLifespanMs,
long expectedCacheRefreshTimeoutMs) {
CachingJWKSetSource<?> cachingJwkSetSource = getCachingJwkSetSource(jwtDecoder);
assertThat(cachingJwkSetSource.getTimeToLive()).isEqualTo(expectedCacheLifespanMs);
assertThat(cachingJwkSetSource.getCacheRefreshTimeout()).isEqualTo(expectedCacheRefreshTimeoutMs);
Comment thread
rujche marked this conversation as resolved.
Outdated
}

private static CachingJWKSetSource<?> getCachingJwkSetSource(JwtDecoder jwtDecoder) {
Object jwkSetSource = getJwkSetSource(jwtDecoder);
assertThat(jwkSetSource).isInstanceOf(CachingJWKSetSource.class);
return (CachingJWKSetSource<?>) jwkSetSource;
}

private static URLBasedJWKSetSource<?> getUrlBasedJwkSetSource(JwtDecoder jwtDecoder) {
Object jwkSetSource = getJwkSetSource(jwtDecoder);
while (jwkSetSource instanceof JWKSetSourceWrapper<?> wrapper) {
jwkSetSource = wrapper.getSource();
}
assertThat(jwkSetSource).isInstanceOf(URLBasedJWKSetSource.class);
return (URLBasedJWKSetSource<?>) jwkSetSource;
}

private static Object getJwkSetSource(JwtDecoder jwtDecoder) {
Object jwtProcessor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
assertThat(jwtProcessor).isInstanceOf(com.nimbusds.jwt.proc.DefaultJWTProcessor.class);

com.nimbusds.jose.proc.JWSKeySelector<?> keySelector =
((com.nimbusds.jwt.proc.DefaultJWTProcessor<?>) jwtProcessor).getJWSKeySelector();
assertThat(keySelector).isInstanceOf(com.nimbusds.jose.proc.JWSVerificationKeySelector.class);

com.nimbusds.jose.jwk.source.JWKSource<?> jwkSource =
((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource();
assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class);

return ((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource();
}
}