Skip to content

Commit c525236

Browse files
authored
Refresh jwks when validation fails (#334)
OKTA-689349 Refresh jwks when validation fails Remove async in tokenRequest method
1 parent 40cbcf5 commit c525236

File tree

3 files changed

+112
-41
lines changed

3 files changed

+112
-41
lines changed

auth-foundation/src/main/java/com/okta/authfoundation/client/OAuth2Client.kt

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ import com.okta.authfoundation.credential.TokenType
3232
import com.okta.authfoundation.jwt.Jwks
3333
import com.okta.authfoundation.jwt.SerializableJwks
3434
import com.okta.authfoundation.util.CoalescingOrchestrator
35-
import kotlinx.coroutines.async
36-
import kotlinx.coroutines.coroutineScope
3735
import kotlinx.serialization.json.JsonObject
3836
import okhttp3.FormBody
3937
import okhttp3.Request
@@ -232,12 +230,8 @@ class OAuth2Client private constructor(
232230

233231
private fun jwksCoalescingOrchestrator(): CoalescingOrchestrator<OAuth2ClientResult<Jwks>> =
234232
CoalescingOrchestrator(
235-
factory = {
236-
actualJwks()
237-
},
238-
keepDataInMemory = { result ->
239-
result is OAuth2ClientResult.Success
240-
}
233+
factory = ::actualJwks,
234+
keepDataInMemory = { result -> result is OAuth2ClientResult.Success }
241235
)
242236

243237
private suspend fun actualJwks(): OAuth2ClientResult<Jwks> {
@@ -266,6 +260,7 @@ class OAuth2Client private constructor(
266260
is OAuth2ClientResult.Error -> {
267261
null
268262
}
263+
269264
is OAuth2ClientResult.Success -> {
270265
result.result
271266
}
@@ -281,53 +276,59 @@ class OAuth2Client private constructor(
281276
maxAge: Int? = null,
282277
requestToken: Token? = null,
283278
): OAuth2ClientResult<Token> {
284-
return coroutineScope {
285-
val isRefreshRequest = requestToken != null
286-
val tokenId = requestToken?.id ?: UUID.randomUUID().toString()
287-
val tokenDeferred =
288-
async {
289-
performRequest(SerializableToken.serializer(), request) { serializableToken ->
290-
serializableToken.asToken(id = tokenId, oidcConfiguration = configuration)
291-
}
292-
}
293-
val jwksDeferred =
294-
async {
295-
endpointsOrNull()?.jwksUri ?: return@async null
296-
jwks()
279+
val isRefreshRequest = requestToken != null
280+
val tokenId = requestToken?.id ?: UUID.randomUUID().toString()
281+
val tokenResult =
282+
performRequest(SerializableToken.serializer(), request) { serializableToken ->
283+
serializableToken.asToken(id = tokenId, oidcConfiguration = configuration)
284+
}
285+
286+
if (tokenResult is OAuth2ClientResult.Success) {
287+
suspend fun validateToken(
288+
token: Token,
289+
jwksResult: OAuth2ClientResult<Jwks>?,
290+
) {
291+
TokenValidator(this@OAuth2Client, token, nonce, maxAge, jwksResult).validate()
292+
configuration.eventCoordinator.sendEvent(TokenCreatedEvent(token))
293+
if (isRefreshRequest) {
294+
Credential.credentialDataSource().replaceToken(token)
297295
}
298-
val tokenResult = tokenDeferred.await()
299-
if (tokenResult is OAuth2ClientResult.Success) {
300-
val token = tokenResult.result
301-
302-
try {
303-
TokenValidator(this@OAuth2Client, token, nonce, maxAge, jwksDeferred.await()).validate()
304-
configuration.eventCoordinator.sendEvent(TokenCreatedEvent(token))
305-
if (isRefreshRequest) {
306-
Credential.credentialDataSource().replaceToken(token)
307-
}
308-
} catch (e: Exception) {
309-
return@coroutineScope OAuth2ClientResult.Error(e)
296+
}
297+
298+
runCatching {
299+
validateToken(tokenResult.result, endpointsOrNull()?.jwksUri?.run { jwks() })
300+
}.recoverCatching { throwable ->
301+
if (throwable is IdTokenValidator.Error) {
302+
val refreshJwksOrchestrator = CoalescingOrchestrator(::actualJwks, { result -> result is OAuth2ClientResult.Success }).get()
303+
if (refreshJwksOrchestrator is OAuth2ClientResult.Error) throw throwable else validateToken(tokenResult.result, refreshJwksOrchestrator)
304+
} else {
305+
throw throwable
310306
}
307+
}.onFailure { e ->
308+
return OAuth2ClientResult.Error(e as Exception)
311309
}
312-
return@coroutineScope tokenResult
313310
}
311+
return tokenResult
314312
}
315313

316314
private fun Token.getTokenOfType(tokenType: TokenType): Result<String> =
317315
when (tokenType) {
318316
TokenType.ACCESS_TOKEN -> {
319317
Result.success(accessToken)
320318
}
319+
321320
TokenType.REFRESH_TOKEN -> {
322321
refreshToken?.let {
323322
Result.success(it)
324323
} ?: Result.failure(IllegalStateException("No refresh token."))
325324
}
325+
326326
TokenType.ID_TOKEN -> {
327327
idToken?.let {
328328
Result.success(it)
329329
} ?: Result.failure(IllegalStateException("No id token."))
330330
}
331+
331332
TokenType.DEVICE_SECRET -> {
332333
deviceSecret?.let {
333334
Result.success(it)

auth-foundation/src/test/java/com/okta/authfoundation/client/OAuth2ClientTest.kt

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
package com.okta.authfoundation.client
1717

1818
import com.google.common.truth.Truth.assertThat
19+
import com.okta.authfoundation.client.IdTokenValidator.Error.Companion.INVALID_JWT_SIGNATURE
1920
import com.okta.authfoundation.client.dto.OidcIntrospectInfo
2021
import com.okta.authfoundation.client.dto.OidcUserInfo
2122
import com.okta.authfoundation.client.events.TokenCreatedEvent
23+
import com.okta.authfoundation.client.internal.performRequest
2224
import com.okta.authfoundation.credential.Credential
2325
import com.okta.authfoundation.credential.RevokeTokenType
26+
import com.okta.authfoundation.credential.SerializableToken
2427
import com.okta.authfoundation.credential.Token
2528
import com.okta.authfoundation.credential.TokenType
2629
import com.okta.authfoundation.credential.createToken
@@ -36,13 +39,17 @@ import com.okta.testhelpers.RequestMatchers.path
3639
import com.okta.testhelpers.RequestMatchers.query
3740
import com.okta.testhelpers.testBodyFromFile
3841
import io.mockk.coEvery
42+
import io.mockk.coVerify
3943
import io.mockk.mockk
44+
import io.mockk.mockkConstructor
4045
import io.mockk.mockkObject
46+
import io.mockk.mockkStatic
47+
import io.mockk.spyk
4148
import io.mockk.unmockkAll
4249
import kotlinx.coroutines.runBlocking
50+
import kotlinx.serialization.DeserializationStrategy
4351
import kotlinx.serialization.SerialName
4452
import kotlinx.serialization.Serializable
45-
import kotlinx.serialization.encodeToString
4653
import okhttp3.HttpUrl.Companion.toHttpUrl
4754
import org.junit.After
4855
import org.junit.Before
@@ -533,6 +540,12 @@ class OAuth2ClientTest {
533540
val keys = createJwks().toSerializableJwks()
534541
response.setBody(oktaRule.configuration.json.encodeToString(keys))
535542
}
543+
oktaRule.enqueue(
544+
path("/oauth2/default/v1/keys")
545+
) { response ->
546+
val keys = createJwks().toSerializableJwks()
547+
response.setBody(oktaRule.configuration.json.encodeToString(keys))
548+
}
536549
oktaRule.enqueue(
537550
method("POST"),
538551
path("/oauth2/default/v1/token")
@@ -554,7 +567,64 @@ class OAuth2ClientTest {
554567
assertThat(exception).hasMessageThat().isEqualTo("Invalid id_token signature")
555568
assertThat(exception).isInstanceOf(IdTokenValidator.Error::class.java)
556569
val idTokenValidatorError = exception as IdTokenValidator.Error
557-
assertThat(idTokenValidatorError.identifier).isEqualTo(IdTokenValidator.Error.INVALID_JWT_SIGNATURE)
570+
assertThat(idTokenValidatorError.identifier).isEqualTo(INVALID_JWT_SIGNATURE)
571+
}
572+
573+
@Test
574+
fun `when validation fails, retry to with new jwks, expect success result returned`() =
575+
runBlocking {
576+
// arrange
577+
mockkStatic("com.okta.authfoundation.client.internal.NetworkUtilsKt")
578+
579+
val token = createToken(idToken = "dummyIdToken", refreshToken = "dummyRefreshToken")
580+
coEvery {
581+
any<OAuth2Client>().performRequest(any<DeserializationStrategy<SerializableToken>>(), any(), any(), any<Function1<SerializableToken, Token>>())
582+
} returns OAuth2ClientResult.Success(token)
583+
584+
val spyClient = spyk(oktaRule.createOAuth2Client(oktaRule.createEndpoints(includeJwks = true)))
585+
val jwksSuccess = OAuth2ClientResult.Success(createJwks())
586+
587+
coEvery { spyClient.jwks() } returns jwksSuccess
588+
mockkConstructor(TokenValidator::class)
589+
coEvery { anyConstructed<TokenValidator>().validate() } throws IdTokenValidator.Error("fail", INVALID_JWT_SIGNATURE) andThen Unit
590+
591+
// Act
592+
val result = spyClient.tokenRequest(mockk(), requestToken = token)
593+
594+
// Assert
595+
coVerify(exactly = 2) { anyConstructed<TokenValidator>().validate() }
596+
assertThat(result).isInstanceOf(OAuth2ClientResult.Success::class.java)
597+
unmockkAll()
598+
}
599+
600+
@Test
601+
fun `when validation fails, retry to with new jwks but still fails expect failure result returned`() =
602+
runBlocking {
603+
// arrange
604+
mockkStatic("com.okta.authfoundation.client.internal.NetworkUtilsKt")
605+
606+
val token = createToken(idToken = "dummyIdToken", refreshToken = "dummyRefreshToken")
607+
coEvery {
608+
any<OAuth2Client>().performRequest(any<DeserializationStrategy<SerializableToken>>(), any(), any(), any<Function1<SerializableToken, Token>>())
609+
} returns OAuth2ClientResult.Success(token)
610+
611+
val spyClient = spyk(oktaRule.createOAuth2Client(oktaRule.createEndpoints(includeJwks = true)))
612+
val jwksSuccess = OAuth2ClientResult.Success(createJwks())
613+
614+
coEvery { spyClient.jwks() } returns jwksSuccess
615+
mockkConstructor(TokenValidator::class)
616+
coEvery { anyConstructed<TokenValidator>().validate() } throws IdTokenValidator.Error("fail", INVALID_JWT_SIGNATURE)
617+
618+
// Act
619+
val result = spyClient.tokenRequest(mockk(), requestToken = token)
620+
621+
// Assert
622+
coVerify(exactly = 2) { anyConstructed<TokenValidator>().validate() }
623+
assertThat(result).isInstanceOf(OAuth2ClientResult.Error::class.java)
624+
assertThat((result as OAuth2ClientResult.Error).exception)
625+
.hasMessageThat()
626+
.isEqualTo("fail")
627+
unmockkAll()
558628
}
559629
}
560630

buildSrc/src/main/java/Configuration.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ import org.gradle.api.JavaVersion
22
import org.gradle.api.Project
33
import java.util.Properties
44

5-
const val BOM_VERSION = "2.0.3"
6-
const val AUTH_FOUNDATION_VERSION = "2.0.3"
7-
const val OAUTH2_VERSION = "2.0.3"
8-
const val WEB_AUTHENTICATION_UI_VERSION = "2.0.3"
9-
const val LEGACY_TOKEN_MIGRATION_VERSION = "2.0.3"
5+
const val BOM_VERSION = "2.0.4"
6+
const val AUTH_FOUNDATION_VERSION = "2.0.4"
7+
const val OAUTH2_VERSION = "2.0.4"
8+
const val WEB_AUTHENTICATION_UI_VERSION = "2.0.4"
9+
const val LEGACY_TOKEN_MIGRATION_VERSION = "2.0.4"
1010
const val MIN_SDK = 23
1111
const val COMPILE_SDK = 35
1212
const val TARGET_SDK = 35

0 commit comments

Comments
 (0)