Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sdk/spring/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ This section includes changes in `spring-cloud-azure-autoconfigure` module.

#### Bugs Fixed

- Fixed Redis Lettuce passwordless autoconfiguration so a user-defined `LettuceClientConfigurationBuilderCustomizer` no longer suppresses the Azure customizer bean that configures Azure Redis credentials and RESP2 support.
- Applied `jwt-connect-timeout` and `jwt-read-timeout` properties to the RestTemplate used by the JWT decoder in AAD and B2C resource server configurations, preventing indefinite hanging when fetching JWK keys ([#49329](https://github.com/Azure/azure-sdk-for-java/pull/49329)).
- Fixed the missing bean name in `@ConditionalOnMissingBean` for `LettuceClientConfigurationBuilderCustomizer` ([#49290](https://github.com/Azure/azure-sdk-for-java/issues/49290)).
Comment thread
rujche marked this conversation as resolved.
- Fixed the AAD and B2C resource server JWT decoder not honoring the `spring.cloud.azure.active-directory.jwt-connect-timeout`, `spring.cloud.azure.active-directory.jwt-read-timeout`, `spring.cloud.azure.active-directory.b2c.jwt-connect-timeout`, and `spring.cloud.azure.active-directory.b2c.jwt-read-timeout` configuration properties ([#49329](https://github.com/Azure/azure-sdk-for-java/pull/49329)).
- Fixed AAD resource server JWK retrieval not honoring the `spring.cloud.azure.active-directory.jwk-set-cache-lifespan` and `spring.cloud.azure.active-directory.jwk-set-cache-refresh-time` configuration properties ([#42159](https://github.com/Azure/azure-sdk-for-java/issues/42159)).

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ stages:
TestOptions: ${{ parameters.TestOptions }}
MatrixConfigs: ${{ parameters.MatrixConfigs }}
PreGenerationSteps:
- script: |
python -m pip install termcolor
displayName: 'Install python module'
- script: |
python ./sdk/spring/scripts/compatibility_update_supported_version_matrix_json.py -mcp ${{ parameters.MatrixConfigs[0].Path }}
displayName: 'Update supported Spring versions'
Expand All @@ -62,9 +59,6 @@ stages:
- 'sdk/spring'
PreSteps:
- ${{ parameters.PreSteps }}
- script: |
python -m pip install termcolor requests
displayName: 'Install python modules'
- bash: |
echo "##vso[task.setVariable variable=SPRING_CLOUD_AZURE_TEST_SUPPORTED_SPRING_CLOUD_VERSION]$(python ./sdk/spring/scripts/compatibility_get_spring_cloud_version.py -b $(SPRING_CLOUD_AZURE_TEST_SUPPORTED_SPRING_BOOT_VERSION))"
displayName: 'Set supported Spring version to environment variables'
Expand Down
7 changes: 6 additions & 1 deletion sdk/spring/scripts/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@


import os
from termcolor import colored

try:
from termcolor import colored
except ImportError:
def colored(content, _):
return content
Comment thread
rujche marked this conversation as resolved.


class Log:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.conditions.ResourceServerCondition;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadResourceServerProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jose.RestOperationsResourceRetriever;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.constants.AadJwtClaimNames;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadTrustedIssuerRepository;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.properties.AadAuthorizationServerEndpoints;
import com.nimbusds.jose.jwk.source.DefaultJWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand All @@ -32,12 +39,15 @@
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.util.StringUtils;

import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;

import static com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer.aadResourceServer;
import static com.azure.spring.cloud.autoconfigure.implementation.aad.utils.AadRestTemplateCreator.createRestTemplate;

@Configuration(proxyBeanMethods = false)
@Conditional(ResourceServerCondition.class)
Expand All @@ -56,16 +66,32 @@ JwtDecoder jwtDecoder(AadAuthenticationProperties aadAuthenticationProperties) {
AadAuthorizationServerEndpoints identityEndpoints = new AadAuthorizationServerEndpoints(
aadAuthenticationProperties.getProfile().getEnvironment().getActiveDirectoryEndpoint(), tenantId);
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder
.withJwkSetUri(identityEndpoints.getJwkSetEndpoint())
.restOperations(createRestTemplate(restTemplateBuilder
.connectTimeout(aadAuthenticationProperties.getJwtConnectTimeout())
.readTimeout(aadAuthenticationProperties.getJwtReadTimeout())))
.withJwkSource(createJwkSource(identityEndpoints.getJwkSetEndpoint(), aadAuthenticationProperties))
.build();
List<OAuth2TokenValidator<Jwt>> validators = createDefaultValidator(aadAuthenticationProperties);
nimbusJwtDecoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(validators));
return nimbusJwtDecoder;
}

@SuppressWarnings("deprecation")
private JWKSource<SecurityContext> createJwkSource(String jwkSetEndpoint,
AadAuthenticationProperties aadAuthenticationProperties) {
RestTemplateBuilder jwtRestTemplateBuilder = restTemplateBuilder
.connectTimeout(aadAuthenticationProperties.getJwtConnectTimeout())
.readTimeout(aadAuthenticationProperties.getJwtReadTimeout());
ResourceRetriever resourceRetriever = new RestOperationsResourceRetriever(jwtRestTemplateBuilder);
JWKSetCache jwkSetCache = new DefaultJWKSetCache(
aadAuthenticationProperties.getJwkSetCacheLifespan().toMillis(),
aadAuthenticationProperties.getJwkSetCacheRefreshTime().toMillis(),
TimeUnit.MILLISECONDS);
try {
URL jwkSetUrl = URI.create(jwkSetEndpoint).toURL();
return new RemoteJWKSet<>(jwkSetUrl, resourceRetriever, jwkSetCache);
} catch (IllegalArgumentException | MalformedURLException e) {
throw new IllegalStateException("Invalid JWK Set endpoint: " + jwkSetEndpoint, e);
}
}

List<OAuth2TokenValidator<Jwt>> createDefaultValidator(AadAuthenticationProperties aadAuthenticationProperties) {
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
List<String> validAudiences = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import com.azure.spring.cloud.autoconfigure.implementation.aad.RecordingClientHttpRequestFactoryBuilderConfiguration;
import com.azure.spring.cloud.autoconfigure.implementation.aad.RecordingClientHttpRequestFactoryBuilderConfiguration.RecordingClientHttpRequestFactoryBuilder;
import com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jose.RestOperationsResourceRetriever;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer;
import com.azure.spring.cloud.autoconfigure.implementation.context.AzureGlobalPropertiesAutoConfiguration;
import com.nimbusds.jose.jwk.source.DefaultJWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down Expand Up @@ -92,6 +96,22 @@ void testJwtDecoderTimeoutDefaultValues() {
});
}

@Test
void testJwtDecoderCacheDefaultValues() {
resourceServerContextRunner()
.withPropertyValues("spring.cloud.azure.active-directory.enabled=true")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwkSetCacheLifespan()).isEqualTo(Duration.ofMinutes(5));
assertThat(properties.getJwkSetCacheRefreshTime()).isEqualTo(Duration.ofMinutes(5));

JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
verifyJwtDecoderCacheDurations(jwtDecoder,
Duration.ofMinutes(5).toMillis(),
Duration.ofMinutes(5).toMillis());
});
}

@Test
void testJwtDecoderTimeoutCustomValues() {
resourceServerContextRunner()
Expand All @@ -113,6 +133,25 @@ void testJwtDecoderTimeoutCustomValues() {
});
}

@Test
void testJwtDecoderCacheCustomValues() {
resourceServerContextRunner()
.withPropertyValues(
"spring.cloud.azure.active-directory.enabled=true",
"spring.cloud.azure.active-directory.jwk-set-cache-lifespan=12m",
"spring.cloud.azure.active-directory.jwk-set-cache-refresh-time=34s")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwkSetCacheLifespan()).isEqualTo(Duration.ofMinutes(12));
assertThat(properties.getJwkSetCacheRefreshTime()).isEqualTo(Duration.ofSeconds(34));

JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
verifyJwtDecoderCacheDurations(jwtDecoder,
Duration.ofMinutes(12).toMillis(),
Duration.ofSeconds(34).toMillis());
});
}

@Test
void testNotAudienceDefaultValidator() {
resourceServerRunner()
Expand Down Expand Up @@ -415,53 +454,56 @@ public Collection<GrantedAuthority> convert(Jwt source) {
* Verifies that the RestTemplate used by the NimbusJwtDecoder for JWK retrieval
* has the expected connect and read timeouts applied to its ClientHttpRequestFactory.
*/
Comment thread
rujche marked this conversation as resolved.
@SuppressWarnings("unchecked")
@SuppressWarnings("deprecation")
private static void verifyJwtDecoderRestTemplateTimeouts(ApplicationContext context,
JwtDecoder jwtDecoder,
int expectedConnectTimeoutMs,
int expectedReadTimeoutMs) {
// NimbusJwtDecoder -> jwtProcessor (DefaultJWTProcessor)
RemoteJWKSet<?> remoteJwkSet = getRemoteJwkSet(jwtDecoder);
Object resourceRetriever = ReflectionTestUtils.getField(remoteJwkSet, "jwkSetRetriever");
assertThat(resourceRetriever).isInstanceOf(RestOperationsResourceRetriever.class);

Object restOperations = ReflectionTestUtils.getField(resourceRetriever, "restOperations");
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);
Comment thread
rujche marked this conversation as resolved.
Outdated

HttpClientSettings clientSettings = context.getBean(RecordingClientHttpRequestFactoryBuilder.class)
.getClientSettings();
assertThat(clientSettings).isNotNull();
assertThat(clientSettings.connectTimeout()).isEqualTo(Duration.ofMillis(expectedConnectTimeoutMs));
assertThat(clientSettings.readTimeout()).isEqualTo(Duration.ofMillis(expectedReadTimeoutMs));
}

@SuppressWarnings("deprecation")
private static void verifyJwtDecoderCacheDurations(JwtDecoder jwtDecoder,
long expectedCacheLifespanMs,
long expectedCacheRefreshTimeMs) {
JWKSetCache jwkSetCache = getRemoteJwkSet(jwtDecoder).getJWKSetCache();
assertThat(jwkSetCache).isInstanceOf(DefaultJWKSetCache.class);

DefaultJWKSetCache defaultJwkSetCache = (DefaultJWKSetCache) jwkSetCache;
assertThat(defaultJwkSetCache.getLifespan(java.util.concurrent.TimeUnit.MILLISECONDS))
.isEqualTo(expectedCacheLifespanMs);
assertThat(defaultJwkSetCache.getRefreshTime(java.util.concurrent.TimeUnit.MILLISECONDS))
.isEqualTo(expectedCacheRefreshTimeMs);
}

@SuppressWarnings("deprecation")
private static RemoteJWKSet<?> getRemoteJwkSet(JwtDecoder jwtDecoder) {
Object jwkSource = getJwkSource(jwtDecoder);
assertThat(jwkSource).isInstanceOf(RemoteJWKSet.class);
return (RemoteJWKSet<?>) jwkSource;
}

private static Object getJwkSource(JwtDecoder jwtDecoder) {
Object jwtProcessor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
assertThat(jwtProcessor).isInstanceOf(com.nimbusds.jwt.proc.DefaultJWTProcessor.class);

// DefaultJWTProcessor -> JWSKeySelector (JWSVerificationKeySelector)
com.nimbusds.jose.proc.JWSKeySelector<?> keySelector =
((com.nimbusds.jwt.proc.DefaultJWTProcessor<?>) jwtProcessor).getJWSKeySelector();
assertThat(keySelector).isInstanceOf(com.nimbusds.jose.proc.JWSVerificationKeySelector.class);

// JWSVerificationKeySelector -> JWKSource (JWKSetBasedJWKSource)
com.nimbusds.jose.jwk.source.JWKSource<?> jwkSource =
((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource();
assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class);

// JWKSetBasedJWKSource -> JWKSetSource (CachingJWKSetSource -> JWKSetSourceWrapper -> actual source)
Object jwkSetSource =
((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource();

// Unwrap JWKSetSourceWrapper chain to find the source with restOperations
while (jwkSetSource instanceof com.nimbusds.jose.jwk.source.JWKSetSourceWrapper<?> wrapper) {
jwkSetSource = wrapper.getSource();
}

// actual source -> restOperations (RestTemplate)
Object restOperations = ReflectionTestUtils.getField(jwkSetSource, "restOperations");
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);

// RestTemplate -> ClientHttpRequestFactory
org.springframework.http.client.ClientHttpRequestFactory requestFactory =
((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory();
assertThat(requestFactory).isNotNull();

assertRecordedHttpClientSettings(context, expectedConnectTimeoutMs, expectedReadTimeoutMs);
}

private static void assertRecordedHttpClientSettings(ApplicationContext context,
int expectedConnectTimeoutMs,
int expectedReadTimeoutMs) {
HttpClientSettings clientSettings = context.getBean(RecordingClientHttpRequestFactoryBuilder.class)
.getClientSettings();
assertThat(clientSettings).isNotNull();
assertThat(clientSettings.connectTimeout()).isEqualTo(Duration.ofMillis(expectedConnectTimeoutMs));
assertThat(clientSettings.readTimeout()).isEqualTo(Duration.ofMillis(expectedReadTimeoutMs));
return jwkSource;
}
}
Loading