1616package com.okta.authfoundation.client
1717
1818import com.google.common.truth.Truth.assertThat
19+ import com.okta.authfoundation.client.IdTokenValidator.Error.Companion.INVALID_JWT_SIGNATURE
1920import com.okta.authfoundation.client.dto.OidcIntrospectInfo
2021import com.okta.authfoundation.client.dto.OidcUserInfo
2122import com.okta.authfoundation.client.events.TokenCreatedEvent
23+ import com.okta.authfoundation.client.internal.performRequest
2224import com.okta.authfoundation.credential.Credential
2325import com.okta.authfoundation.credential.RevokeTokenType
26+ import com.okta.authfoundation.credential.SerializableToken
2427import com.okta.authfoundation.credential.Token
2528import com.okta.authfoundation.credential.TokenType
2629import com.okta.authfoundation.credential.createToken
@@ -36,13 +39,17 @@ import com.okta.testhelpers.RequestMatchers.path
3639import com.okta.testhelpers.RequestMatchers.query
3740import com.okta.testhelpers.testBodyFromFile
3841import io.mockk.coEvery
42+ import io.mockk.coVerify
3943import io.mockk.mockk
44+ import io.mockk.mockkConstructor
4045import io.mockk.mockkObject
46+ import io.mockk.mockkStatic
47+ import io.mockk.spyk
4148import io.mockk.unmockkAll
4249import kotlinx.coroutines.runBlocking
50+ import kotlinx.serialization.DeserializationStrategy
4351import kotlinx.serialization.SerialName
4452import kotlinx.serialization.Serializable
45- import kotlinx.serialization.encodeToString
4653import okhttp3.HttpUrl.Companion.toHttpUrl
4754import org.junit.After
4855import org.junit.Before
@@ -533,6 +540,12 @@ class OAuth2ClientTest {
533540 val keys = createJwks().toSerializableJwks()
534541 response.setBody(oktaRule.configuration.json.encodeToString(keys))
535542 }
543+ oktaRule.enqueue(
544+ path(" /oauth2/default/v1/keys" )
545+ ) { response ->
546+ val keys = createJwks().toSerializableJwks()
547+ response.setBody(oktaRule.configuration.json.encodeToString(keys))
548+ }
536549 oktaRule.enqueue(
537550 method(" POST" ),
538551 path(" /oauth2/default/v1/token" )
@@ -554,7 +567,64 @@ class OAuth2ClientTest {
554567 assertThat(exception).hasMessageThat().isEqualTo(" Invalid id_token signature" )
555568 assertThat(exception).isInstanceOf(IdTokenValidator .Error ::class .java)
556569 val idTokenValidatorError = exception as IdTokenValidator .Error
557- assertThat(idTokenValidatorError.identifier).isEqualTo(IdTokenValidator .Error .INVALID_JWT_SIGNATURE )
570+ assertThat(idTokenValidatorError.identifier).isEqualTo(INVALID_JWT_SIGNATURE )
571+ }
572+
573+ @Test
574+ fun `when validation fails, retry to with new jwks, expect success result returned` () =
575+ runBlocking {
576+ // arrange
577+ mockkStatic(" com.okta.authfoundation.client.internal.NetworkUtilsKt" )
578+
579+ val token = createToken(idToken = " dummyIdToken" , refreshToken = " dummyRefreshToken" )
580+ coEvery {
581+ any<OAuth2Client >().performRequest(any<DeserializationStrategy <SerializableToken >>(), any(), any(), any<Function1 <SerializableToken , Token >>())
582+ } returns OAuth2ClientResult .Success (token)
583+
584+ val spyClient = spyk(oktaRule.createOAuth2Client(oktaRule.createEndpoints(includeJwks = true )))
585+ val jwksSuccess = OAuth2ClientResult .Success (createJwks())
586+
587+ coEvery { spyClient.jwks() } returns jwksSuccess
588+ mockkConstructor(TokenValidator ::class )
589+ coEvery { anyConstructed<TokenValidator >().validate() } throws IdTokenValidator .Error (" fail" , INVALID_JWT_SIGNATURE ) andThen Unit
590+
591+ // Act
592+ val result = spyClient.tokenRequest(mockk(), requestToken = token)
593+
594+ // Assert
595+ coVerify(exactly = 2 ) { anyConstructed<TokenValidator >().validate() }
596+ assertThat(result).isInstanceOf(OAuth2ClientResult .Success ::class .java)
597+ unmockkAll()
598+ }
599+
600+ @Test
601+ fun `when validation fails, retry to with new jwks but still fails expect failure result returned` () =
602+ runBlocking {
603+ // arrange
604+ mockkStatic(" com.okta.authfoundation.client.internal.NetworkUtilsKt" )
605+
606+ val token = createToken(idToken = " dummyIdToken" , refreshToken = " dummyRefreshToken" )
607+ coEvery {
608+ any<OAuth2Client >().performRequest(any<DeserializationStrategy <SerializableToken >>(), any(), any(), any<Function1 <SerializableToken , Token >>())
609+ } returns OAuth2ClientResult .Success (token)
610+
611+ val spyClient = spyk(oktaRule.createOAuth2Client(oktaRule.createEndpoints(includeJwks = true )))
612+ val jwksSuccess = OAuth2ClientResult .Success (createJwks())
613+
614+ coEvery { spyClient.jwks() } returns jwksSuccess
615+ mockkConstructor(TokenValidator ::class )
616+ coEvery { anyConstructed<TokenValidator >().validate() } throws IdTokenValidator .Error (" fail" , INVALID_JWT_SIGNATURE )
617+
618+ // Act
619+ val result = spyClient.tokenRequest(mockk(), requestToken = token)
620+
621+ // Assert
622+ coVerify(exactly = 2 ) { anyConstructed<TokenValidator >().validate() }
623+ assertThat(result).isInstanceOf(OAuth2ClientResult .Error ::class .java)
624+ assertThat((result as OAuth2ClientResult .Error ).exception)
625+ .hasMessageThat()
626+ .isEqualTo(" fail" )
627+ unmockkAll()
558628 }
559629}
560630
0 commit comments