Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/spring/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This section includes changes in `spring-cloud-azure-autoconfigure` module.
#### Bugs Fixed

- Fixed Redis Lettuce passwordless autoconfiguration so a user-defined `LettuceClientConfigurationBuilderCustomizer` no longer suppresses the Azure customizer bean that configures Azure Redis credentials and RESP2 support.
- Applied `jwt-connect-timeout` and `jwt-read-timeout` properties to the RestTemplate used by the JWT decoder in AAD and B2C resource server configurations, preventing indefinite hanging when fetching JWK keys ([#49329](https://github.com/Azure/azure-sdk-for-java/pull/49329)).

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ JwtDecoder jwtDecoder(AadAuthenticationProperties aadAuthenticationProperties) {
aadAuthenticationProperties.getProfile().getEnvironment().getActiveDirectoryEndpoint(), tenantId);
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder
.withJwkSetUri(identityEndpoints.getJwkSetEndpoint())
.restOperations(createRestTemplate(restTemplateBuilder))
.build();
.restOperations(createRestTemplate(restTemplateBuilder
.connectTimeout(aadAuthenticationProperties.getJwtConnectTimeout())
.readTimeout(aadAuthenticationProperties.getJwtReadTimeout())))
.build();
List<OAuth2TokenValidator<Jwt>> validators = createDefaultValidator(aadAuthenticationProperties);
nimbusJwtDecoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(validators));
return nimbusJwtDecoder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package com.azure.spring.cloud.autoconfigure.implementation.aad.configuration.properties;

import com.azure.spring.cloud.autoconfigure.implementation.aad.security.properties.AuthorizationClientProperties;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
Expand Down Expand Up @@ -90,25 +90,22 @@ public class AadAuthenticationProperties implements InitializingBean {
private final Map<String, Object> authenticateAdditionalParameters = new HashMap<>();

/**
* Connection Timeout (duration) for the JWKSet Remote URL call. The default value is `500s`.
* @deprecated If you want to configure this, please provide a 'RestOperations' bean.
* Connection Timeout (duration) for the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_CONNECT_TIMEOUT} milliseconds.
*/
@Deprecated
private Duration jwtConnectTimeout = Duration.ofMillis(RemoteJWKSet.DEFAULT_HTTP_CONNECT_TIMEOUT);
private Duration jwtConnectTimeout = Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT);

Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
/**
* Read Timeout (duration) for the JWKSet Remote URL call. The default value is `500s`.
* @deprecated If you want to configure this, please provide a 'RestOperations' bean.
* Read Timeout (duration) for the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_READ_TIMEOUT} milliseconds.
*/
@Deprecated
private Duration jwtReadTimeout = Duration.ofMillis(RemoteJWKSet.DEFAULT_HTTP_READ_TIMEOUT);
private Duration jwtReadTimeout = Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT);

Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
/**
* Size limit in Bytes of the JWKSet Remote URL call. The default value is `51200`.
* @deprecated If you want to configure this, please provide a 'RestOperations' bean.
* Size limit in Bytes of the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_SIZE_LIMIT} bytes.
*/
@Deprecated
private int jwtSizeLimit = RemoteJWKSet.DEFAULT_HTTP_SIZE_LIMIT; /* bytes */
private int jwtSizeLimit = JWKSourceBuilder.DEFAULT_HTTP_SIZE_LIMIT; /* bytes */

/**
* The lifespan (duration) of the cached JWK set before it expires.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ public class AadB2cResourceServerAutoConfiguration {

AadB2cResourceServerAutoConfiguration(AadB2cProperties properties, RestTemplateBuilder restTemplateBuilder) {
this.properties = properties;
this.restTemplateBuilder = restTemplateBuilder;
this.restTemplateBuilder = restTemplateBuilder
.connectTimeout(properties.getJwtConnectTimeout())
.readTimeout(properties.getJwtReadTimeout());
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
package com.azure.spring.cloud.autoconfigure.implementation.aadb2c.configuration.properties;

import com.azure.spring.cloud.autoconfigure.implementation.aadb2c.security.exception.AadB2cConfigurationException;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.CollectionUtils;
Expand Down Expand Up @@ -60,25 +60,22 @@ public class AadB2cProperties implements InitializingBean {
private String appIdUri;

/**
* Connection Timeout(duration) for the JWKSet Remote URL call. The default value is `500s`.
* @deprecated If you want to configure this, please provide a RestOperations bean.
* Connection Timeout (duration) for the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_CONNECT_TIMEOUT} milliseconds.
*/
@Deprecated
private Duration jwtConnectTimeout = Duration.ofMillis(RemoteJWKSet.DEFAULT_HTTP_CONNECT_TIMEOUT);
private Duration jwtConnectTimeout = Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT);

Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
/**
* Read Timeout(duration) for the JWKSet Remote URL call. The default value is `500s`.
* @deprecated If you want to configure this, please provide a RestOperations bean.
* Read Timeout (duration) for the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_READ_TIMEOUT} milliseconds.
*/
@Deprecated
private Duration jwtReadTimeout = Duration.ofMillis(RemoteJWKSet.DEFAULT_HTTP_READ_TIMEOUT);
private Duration jwtReadTimeout = Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT);

Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
/**
* Size limit in Bytes of the JWKSet Remote URL call. The default value is `50*1024`.
* @deprecated If you want to configure this, please provide a RestOperations bean.
* Size limit in Bytes of the JWKSet Remote URL call.
* The default value is {@value com.nimbusds.jose.jwk.source.JWKSourceBuilder#DEFAULT_HTTP_SIZE_LIMIT} bytes.
*/
@Deprecated
private int jwtSizeLimit = RemoteJWKSet.DEFAULT_HTTP_SIZE_LIMIT; /* bytes */
private int jwtSizeLimit = JWKSourceBuilder.DEFAULT_HTTP_SIZE_LIMIT; /* bytes */

/**
* Redirect URL after logout.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.jwt.AadJwtIssuerValidator;
import com.azure.spring.cloud.autoconfigure.implementation.aad.security.AadResourceServerHttpSecurityConfigurer;
import com.azure.spring.cloud.autoconfigure.implementation.context.AzureGlobalPropertiesAutoConfiguration;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand All @@ -32,6 +33,7 @@
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.test.util.ReflectionTestUtils;

import java.time.Duration;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -67,6 +69,44 @@ void testCreateJwtDecoderByJwkKeySetUri() {
});
}

@Test
void testJwtDecoderTimeoutDefaultValues() {
resourceServerContextRunner()
.withPropertyValues("spring.cloud.azure.active-directory.enabled=true")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwtConnectTimeout())
.isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT));
assertThat(properties.getJwtReadTimeout())
.isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT));
// Verify the default timeouts are applied to the RestTemplate used by the JwtDecoder
final JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
verifyJwtDecoderRestTemplateTimeouts(jwtDecoder,
JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT,
JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT);
});
}

@Test
void testJwtDecoderTimeoutCustomValues() {
resourceServerContextRunner()
.withPropertyValues(
"spring.cloud.azure.active-directory.enabled=true",
"spring.cloud.azure.active-directory.jwt-connect-timeout=2000",
"spring.cloud.azure.active-directory.jwt-read-timeout=3000")
.run(context -> {
AadAuthenticationProperties properties = context.getBean(AadAuthenticationProperties.class);
assertThat(properties.getJwtConnectTimeout()).isEqualTo(Duration.ofMillis(2000));
assertThat(properties.getJwtReadTimeout()).isEqualTo(Duration.ofMillis(3000));
// Verify JwtDecoder is still created successfully with custom timeouts
final JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
assertThat(jwtDecoder).isNotNull();
assertThat(jwtDecoder).isExactlyInstanceOf(NimbusJwtDecoder.class);
// Verify the configured timeouts are applied to the RestTemplate used by the JwtDecoder
verifyJwtDecoderRestTemplateTimeouts(jwtDecoder, 2000, 3000);
});
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
}

@Test
void testNotAudienceDefaultValidator() {
resourceServerRunner()
Expand Down Expand Up @@ -364,4 +404,45 @@ public Collection<GrantedAuthority> convert(Jwt source) {
return null;
}
}

/**
* Verifies that the RestTemplate used by the NimbusJwtDecoder for JWK retrieval
* has the expected connect and read timeouts applied to its ClientHttpRequestFactory.
*/
@SuppressWarnings("unchecked")
private static void verifyJwtDecoderRestTemplateTimeouts(JwtDecoder jwtDecoder,
int expectedConnectTimeoutMs,
int expectedReadTimeoutMs) {
// NimbusJwtDecoder -> jwtProcessor (DefaultJWTProcessor)
Object jwtProcessor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
assertThat(jwtProcessor).isInstanceOf(com.nimbusds.jwt.proc.DefaultJWTProcessor.class);

// DefaultJWTProcessor -> JWSKeySelector (JWSVerificationKeySelector)
com.nimbusds.jose.proc.JWSKeySelector<?> keySelector =
((com.nimbusds.jwt.proc.DefaultJWTProcessor<?>) jwtProcessor).getJWSKeySelector();
assertThat(keySelector).isInstanceOf(com.nimbusds.jose.proc.JWSVerificationKeySelector.class);

// JWSVerificationKeySelector -> JWKSource (JWKSetBasedJWKSource)
com.nimbusds.jose.jwk.source.JWKSource<?> jwkSource =
((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource();
assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class);

// JWKSetBasedJWKSource -> JWKSetSource (SpringJWKSource)
Object jwkSetSource =
((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource();

// SpringJWKSource -> restOperations (RestTemplate)
Object restOperations = ReflectionTestUtils.getField(jwkSetSource, "restOperations");
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);

// RestTemplate -> ClientHttpRequestFactory
org.springframework.http.client.ClientHttpRequestFactory requestFactory =
((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory();

// Verify timeouts on the request factory
int connectTimeout = (int) ReflectionTestUtils.getField(requestFactory, "connectTimeout");
int readTimeout = (int) ReflectionTestUtils.getField(requestFactory, "readTimeout");
assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs);
assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs);
Comment thread
rujche marked this conversation as resolved.
Outdated
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.azure.spring.cloud.autoconfigure.implementation.aadb2c.configuration.properties.AadB2cProperties;
import com.azure.spring.cloud.autoconfigure.implementation.aadb2c.security.jwt.AadB2cTrustedIssuerRepository;
import com.azure.spring.cloud.autoconfigure.implementation.context.AzureGlobalPropertiesAutoConfiguration;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector;
Expand All @@ -29,6 +30,7 @@
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;

import java.time.Duration;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -141,6 +143,36 @@ void testB2COnlyResourceServerBean() {
getResourceServerContextRunner().run(b2CResourceServerBean());
}

@Test
void testB2CTimeoutDefaultValues() {
getResourceServerContextRunner().run(context -> {
AadB2cProperties properties = context.getBean(AadB2cProperties.class);
assertThat(properties.getJwtConnectTimeout())
.isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT));
assertThat(properties.getJwtReadTimeout())
.isEqualTo(Duration.ofMillis(JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT));
// Verify the default timeouts are applied to the RestTemplate used by the ResourceRetriever
verifyResourceRetrieverRestTemplateTimeouts(context,
JWKSourceBuilder.DEFAULT_HTTP_CONNECT_TIMEOUT,
JWKSourceBuilder.DEFAULT_HTTP_READ_TIMEOUT);
});
}

@Test
void testB2CTimeoutCustomValues() {
getResourceServerContextRunner()
.withPropertyValues(
"spring.cloud.azure.active-directory.b2c.jwt-connect-timeout=2000",
"spring.cloud.azure.active-directory.b2c.jwt-read-timeout=3000")
.run(context -> {
AadB2cProperties properties = context.getBean(AadB2cProperties.class);
assertThat(properties.getJwtConnectTimeout()).isEqualTo(Duration.ofMillis(2000));
assertThat(properties.getJwtReadTimeout()).isEqualTo(Duration.ofMillis(3000));
// Verify the custom timeouts are applied to the RestTemplate used by the ResourceRetriever
verifyResourceRetrieverRestTemplateTimeouts(context, 2000, 3000);
});
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
Comment thread
rujche marked this conversation as resolved.
}

@Test
void testResourceServerConditionsIsInvokedWhenAADB2CEnableFileExists() {
try (MockedStatic<BeanUtils> beanUtils = mockStatic(BeanUtils.class, Mockito.CALLS_REAL_METHODS)) {
Expand Down Expand Up @@ -301,4 +333,33 @@ void testExistJWTClaimsSetAwareJWSKeySelectorBean() {
assertThat(jwsKeySelector).isExactlyInstanceOf(AadIssuerJwsKeySelector.class);
});
}

/**
* Verifies that the RestTemplate used by the ResourceRetriever for JWK retrieval
* has the expected connect and read timeouts applied to its ClientHttpRequestFactory.
*/
private static void verifyResourceRetrieverRestTemplateTimeouts(ApplicationContext context,
int expectedConnectTimeoutMs,
int expectedReadTimeoutMs) {
com.nimbusds.jose.util.ResourceRetriever resourceRetriever =
context.getBean(com.nimbusds.jose.util.ResourceRetriever.class);
assertThat(resourceRetriever).isNotNull();

// RestOperationsResourceRetriever -> restOperations (RestTemplate)
Object restOperations = org.springframework.test.util.ReflectionTestUtils
.getField(resourceRetriever, "restOperations");
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);

// RestTemplate -> ClientHttpRequestFactory
org.springframework.http.client.ClientHttpRequestFactory requestFactory =
((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory();

// Verify timeouts on the request factory
int connectTimeout = (int) org.springframework.test.util.ReflectionTestUtils
.getField(requestFactory, "connectTimeout");
int readTimeout = (int) org.springframework.test.util.ReflectionTestUtils
.getField(requestFactory, "readTimeout");
assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs);
assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs);
Comment thread
rujche marked this conversation as resolved.
Outdated
}
}
Loading