Skip to content

Commit 3ffc60f

Browse files
committed
Fix test failures: unwrap CachingJWKSetSource chain and handle Duration timeout fields
1 parent e7e55fd commit 3ffc60f

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

sdk/spring/spring-cloud-azure-autoconfigure/src/test/java/com/azure/spring/cloud/autoconfigure/implementation/aad/configuration/AadResourceServerConfigurationTests.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -427,21 +427,28 @@ private static void verifyJwtDecoderRestTemplateTimeouts(JwtDecoder jwtDecoder,
427427
((com.nimbusds.jose.proc.JWSVerificationKeySelector<?>) keySelector).getJWKSource();
428428
assertThat(jwkSource).isInstanceOf(com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource.class);
429429

430-
// JWKSetBasedJWKSource -> JWKSetSource (SpringJWKSource)
430+
// JWKSetBasedJWKSource -> JWKSetSource (CachingJWKSetSource -> JWKSetSourceWrapper -> actual source)
431431
Object jwkSetSource =
432432
((com.nimbusds.jose.jwk.source.JWKSetBasedJWKSource<?>) jwkSource).getJWKSetSource();
433433

434-
// SpringJWKSource -> restOperations (RestTemplate)
434+
// Unwrap JWKSetSourceWrapper chain to find the source with restOperations
435+
while (jwkSetSource instanceof com.nimbusds.jose.jwk.source.JWKSetSourceWrapper<?> wrapper) {
436+
jwkSetSource = wrapper.getSource();
437+
}
438+
439+
// actual source -> restOperations (RestTemplate)
435440
Object restOperations = ReflectionTestUtils.getField(jwkSetSource, "restOperations");
436441
assertThat(restOperations).isInstanceOf(org.springframework.web.client.RestTemplate.class);
437442

438443
// RestTemplate -> ClientHttpRequestFactory
439444
org.springframework.http.client.ClientHttpRequestFactory requestFactory =
440445
((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory();
441446

442-
// Verify timeouts on the request factory
443-
int connectTimeout = (int) ReflectionTestUtils.getField(requestFactory, "connectTimeout");
444-
int readTimeout = (int) ReflectionTestUtils.getField(requestFactory, "readTimeout");
447+
// Verify timeouts on the request factory (may be stored as Duration or int)
448+
Object connectTimeoutValue = ReflectionTestUtils.getField(requestFactory, "connectTimeout");
449+
Object readTimeoutValue = ReflectionTestUtils.getField(requestFactory, "readTimeout");
450+
int connectTimeout = connectTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) connectTimeoutValue;
451+
int readTimeout = readTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) readTimeoutValue;
445452
assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs);
446453
assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs);
447454
}

sdk/spring/spring-cloud-azure-autoconfigure/src/test/java/com/azure/spring/cloud/autoconfigure/implementation/aadb2c/configuration/AadB2cResourceServerAutoConfigurationTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,13 @@ private static void verifyResourceRetrieverRestTemplateTimeouts(ApplicationConte
354354
org.springframework.http.client.ClientHttpRequestFactory requestFactory =
355355
((org.springframework.web.client.RestTemplate) restOperations).getRequestFactory();
356356

357-
// Verify timeouts on the request factory
358-
int connectTimeout = (int) org.springframework.test.util.ReflectionTestUtils
357+
// Verify timeouts on the request factory (may be stored as Duration or int)
358+
Object connectTimeoutValue = org.springframework.test.util.ReflectionTestUtils
359359
.getField(requestFactory, "connectTimeout");
360-
int readTimeout = (int) org.springframework.test.util.ReflectionTestUtils
360+
Object readTimeoutValue = org.springframework.test.util.ReflectionTestUtils
361361
.getField(requestFactory, "readTimeout");
362+
int connectTimeout = connectTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) connectTimeoutValue;
363+
int readTimeout = readTimeoutValue instanceof java.time.Duration d ? (int) d.toMillis() : (int) readTimeoutValue;
362364
assertThat(connectTimeout).isEqualTo(expectedConnectTimeoutMs);
363365
assertThat(readTimeout).isEqualTo(expectedReadTimeoutMs);
364366
}

0 commit comments

Comments
 (0)