Skip to content

Commit 82844cf

Browse files
committed
KTOR-6569 Add option for Bearer auth to not cache client token
1 parent 3fa4a7a commit 82844cf

File tree

2 files changed

+277
-14
lines changed
  • ktor-client/ktor-client-plugins/ktor-client-auth/common

2 files changed

+277
-14
lines changed

ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/providers/BearerAuthProvider.kt

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import io.ktor.client.statement.*
1111
import io.ktor.http.*
1212
import io.ktor.http.auth.*
1313
import io.ktor.utils.io.*
14+
import kotlinx.coroutines.Job
15+
import kotlinx.coroutines.sync.Mutex
16+
import kotlinx.coroutines.sync.withLock
1417

1518
/**
1619
* Installs the client's [BearerAuthProvider].
@@ -19,7 +22,7 @@ import io.ktor.utils.io.*
1922
*/
2023
public fun AuthConfig.bearer(block: BearerAuthConfig.() -> Unit) {
2124
with(BearerAuthConfig().apply(block)) {
22-
this@bearer.providers.add(BearerAuthProvider(refreshTokens, loadTokens, sendWithoutRequest, realm))
25+
this@bearer.providers.add(BearerAuthProvider(refreshTokens, loadTokens, sendWithoutRequest, realm, cacheTokens))
2326
}
2427
}
2528

@@ -61,6 +64,13 @@ public class BearerAuthConfig {
6164
internal var sendWithoutRequest: (HttpRequestBuilder) -> Boolean = { true }
6265

6366
public var realm: String? = null
67+
68+
/**
69+
* Controls whether to cache tokens between requests.
70+
* When set to false, the provider will call [loadTokens] for each request.
71+
* Default value is true.
72+
*/
73+
public var cacheTokens: Boolean = true
6474

6575
/**
6676
* Configures a callback that refreshes a token when the 401 status code is received.
@@ -97,23 +107,34 @@ public class BearerAuthConfig {
97107
* As an example, these tokens can be used as a part of OAuth flow to authorize users of your application
98108
* by using external providers, such as Google, Facebook, Twitter, and so on.
99109
*
110+
* You can control whether tokens are cached between requests with the [cacheTokens] parameter:
111+
* - When `true` (default), tokens are cached after the first request and reused.
112+
* - When `false`, [loadTokens] is called for each request, and the token is never cached.
113+
*
100114
* You can learn more from [Bearer authentication](https://ktor.io/docs/bearer-client.html).
101115
*
102116
* [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.auth.providers.BearerAuthProvider)
103117
*/
104118
public class BearerAuthProvider(
105119
private val refreshTokens: suspend RefreshTokensParams.() -> BearerTokens?,
106-
loadTokens: suspend () -> BearerTokens?,
120+
private val loadTokensCallback: suspend () -> BearerTokens?,
107121
private val sendWithoutRequestCallback: (HttpRequestBuilder) -> Boolean = { true },
108-
private val realm: String?
122+
private val realm: String?,
123+
private val cacheTokens: Boolean = true
109124
) : AuthProvider {
110125

111126
@Suppress("OverridingDeprecatedMember")
112127
@Deprecated("Please use sendWithoutRequest function instead", level = DeprecationLevel.ERROR)
113128
override val sendWithoutRequest: Boolean
114129
get() = error("Deprecated")
115130

116-
private val tokensHolder = AuthTokenHolder(loadTokens)
131+
// Only create the tokens holder if caching is enabled
132+
private val tokensHolder = if (cacheTokens) AuthTokenHolder(loadTokensCallback) else null
133+
134+
// When caching is disabled, we still need to store the current refreshed token
135+
// so it can be used in the retry mechanism during the current request cycle
136+
private val currentRefreshTokenMutex = Mutex()
137+
private var currentRefreshedToken: BearerTokens? = null
117138

118139
override fun sendWithoutRequest(request: HttpRequestBuilder): Boolean = sendWithoutRequestCallback(request)
119140

@@ -144,24 +165,65 @@ public class BearerAuthProvider(
144165
* [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.auth.providers.BearerAuthProvider.addRequestHeaders)
145166
*/
146167
override suspend fun addRequestHeaders(request: HttpRequestBuilder, authHeader: HttpAuthHeader?) {
147-
val token = tokensHolder.loadToken() ?: return
168+
// If the request has the circuit breaker attribute, don't add any auth headers
169+
if (request.attributes.contains(AuthCircuitBreaker)) {
170+
LOGGER.trace("Circuit breaker active - no auth header will be added")
171+
return
172+
}
173+
174+
// Get the appropriate token based on caching settings
175+
val token = currentRefreshTokenMutex.withLock {
176+
if (currentRefreshedToken != null) {
177+
// If we have a refreshed token for the current retry cycle, use that
178+
LOGGER.trace("Using refreshed token for request: ${currentRefreshedToken!!.accessToken}")
179+
return@withLock currentRefreshedToken
180+
} else if (cacheTokens) {
181+
// If caching is enabled, use tokensHolder to cache between requests
182+
return@withLock tokensHolder!!.loadToken()
183+
} else {
184+
// If caching is disabled, load a fresh token
185+
val freshToken = loadTokensCallback()
186+
LOGGER.trace("Using fresh token for request: ${freshToken?.accessToken}")
187+
return@withLock freshToken
188+
}
189+
} ?: return
148190

149191
request.headers {
150192
val tokenValue = "Bearer ${token.accessToken}"
151193
if (contains(HttpHeaders.Authorization)) {
152194
remove(HttpHeaders.Authorization)
153195
}
154-
if (request.attributes.contains(AuthCircuitBreaker).not()) {
155-
append(HttpHeaders.Authorization, tokenValue)
156-
}
196+
append(HttpHeaders.Authorization, tokenValue)
157197
}
158198
}
159199

160200
public override suspend fun refreshToken(response: HttpResponse): Boolean {
161-
val newToken = tokensHolder.setToken {
162-
refreshTokens(RefreshTokensParams(response.call.client, response, tokensHolder.loadToken()))
201+
return if (cacheTokens) {
202+
// With caching enabled, use the token holder for persistent caching
203+
val newToken = tokensHolder!!.setToken {
204+
refreshTokens(RefreshTokensParams(response.call.client, response, tokensHolder.loadToken()))
205+
}
206+
newToken != null
207+
} else {
208+
// Thread-safe access to currentRefreshedToken
209+
currentRefreshTokenMutex.withLock {
210+
// Get the current token (used as oldTokens in RefreshTokensParams)
211+
val currentToken = loadTokensCallback()
212+
213+
// Get the new token from the refresh function
214+
val newToken = refreshTokens(RefreshTokensParams(response.call.client, response, currentToken))
215+
216+
// Store the refreshed token for use in the retry process
217+
if (newToken != null) {
218+
LOGGER.trace("Setting refreshed token: ${newToken.accessToken}")
219+
currentRefreshedToken = newToken
220+
true
221+
} else {
222+
LOGGER.trace("No refreshed token returned")
223+
false
224+
}
225+
}
163226
}
164-
return newToken != null
165227
}
166228

167229
/**
@@ -171,13 +233,22 @@ public class BearerAuthProvider(
171233
* - When access or refresh tokens have been updated externally
172234
* - When you want to clear sensitive token data (for example, during logout)
173235
*
174-
* Note: The result of `loadTokens` invocation is cached internally.
236+
* Note: The result of `loadTokens` invocation is cached internally when [cacheTokens] is true.
175237
* Calling this method will force the next authentication attempt to fetch fresh tokens
176238
* through the configured `loadTokens` function.
177239
*
240+
* If [cacheTokens] is false, this method will clear any temporarily stored refresh token.
241+
*
178242
* [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.auth.providers.BearerAuthProvider.clearToken)
179243
*/
180-
public fun clearToken() {
181-
tokensHolder.clearToken()
244+
public suspend fun clearToken() {
245+
if (cacheTokens) {
246+
tokensHolder!!.clearToken()
247+
} else {
248+
// Thread-safe access to clear any temporarily stored refreshed token
249+
currentRefreshTokenMutex.withLock {
250+
currentRefreshedToken = null
251+
}
252+
}
182253
}
183254
}

ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,198 @@ class AuthTest : ClientLoader() {
696696
assertEquals(2, loadCount)
697697
}
698698
}
699+
700+
@Test
701+
fun testDisableTokenCaching() = testWithEngine(MockEngine) {
702+
var loadCount = 0
703+
config {
704+
install(Auth) {
705+
bearer {
706+
loadTokens {
707+
loadCount++
708+
BearerTokens("valid-token", "refresh-token")
709+
}
710+
// Disable token caching
711+
cacheTokens = false
712+
}
713+
}
714+
715+
engine {
716+
addHandler { request ->
717+
val token = request.headers[HttpHeaders.Authorization]?.removePrefix("Bearer ")
718+
if (token == "valid-token") {
719+
respond("OK", HttpStatusCode.OK)
720+
} else {
721+
respond("Unauthorized", HttpStatusCode.Unauthorized)
722+
}
723+
}
724+
}
725+
}
726+
727+
test { client ->
728+
loadCount = 0
729+
730+
// First request should call loadTokens
731+
val response1 = client.get("/")
732+
assertEquals(HttpStatusCode.OK, response1.status)
733+
assertEquals(1, loadCount)
734+
735+
// Second request should call loadTokens again since caching is disabled
736+
val response2 = client.get("/")
737+
assertEquals(HttpStatusCode.OK, response2.status)
738+
assertEquals(2, loadCount)
739+
740+
// Third request should call loadTokens once more
741+
val response3 = client.get("/")
742+
assertEquals(HttpStatusCode.OK, response3.status)
743+
assertEquals(3, loadCount)
744+
}
745+
}
746+
747+
@Test
748+
fun testRefreshWithCachingDisabled() = testWithEngine(MockEngine) {
749+
var loadCount = 0
750+
var refreshCount = 0
751+
752+
// Storage of the current token to mimic Firebase token cache
753+
var tokenCache = "initial-token"
754+
755+
config {
756+
install(Auth) {
757+
bearer {
758+
loadTokens {
759+
loadCount++
760+
// Return token from cache
761+
println("loadTokens called (#$loadCount) - returning $tokenCache")
762+
BearerTokens(tokenCache, "refresh")
763+
}
764+
refreshTokens {
765+
refreshCount++
766+
// Update token in cache
767+
val newToken = "refreshed-token-$refreshCount"
768+
tokenCache = newToken
769+
println("refreshTokens called (#$refreshCount) - updating cache to $newToken")
770+
BearerTokens(newToken, "refresh")
771+
}
772+
// Disable token caching
773+
cacheTokens = false
774+
}
775+
}
776+
777+
engine {
778+
addHandler { request ->
779+
val token = request.headers[HttpHeaders.Authorization]?.removePrefix("Bearer ")
780+
println("Received request with token: ${token ?: "none"}")
781+
782+
if (token == "initial-token") {
783+
// Initial token is rejected
784+
println("Returning 401 for initial-token")
785+
respond("Unauthorized", HttpStatusCode.Unauthorized,
786+
headers = headersOf(HttpHeaders.WWWAuthenticate, "Bearer"))
787+
} else if (token?.startsWith("refreshed-token") == true) {
788+
// Refreshed tokens are accepted
789+
println("Returning 200 for refreshed token: $token")
790+
respond("OK", HttpStatusCode.OK)
791+
} else {
792+
println("Returning 401 for unexpected token: $token")
793+
respond("No token", HttpStatusCode.Unauthorized,
794+
headers = headersOf(HttpHeaders.WWWAuthenticate, "Bearer"))
795+
}
796+
}
797+
}
798+
}
799+
800+
test { client ->
801+
loadCount = 0
802+
refreshCount = 0
803+
804+
// Token is loaded from cache first, then refreshed during the 401 scenario
805+
val response1 = client.get("/first")
806+
807+
// Wait a moment to ensure processing completes
808+
delay(100)
809+
810+
// The retry should use the refreshed token from our cache
811+
assertEquals(HttpStatusCode.OK, response1.status)
812+
813+
// Verify refresh was called once
814+
assertEquals(1, refreshCount, "Expected refresh to be called once")
815+
816+
// Second request should use the fresh token from our external cache
817+
val response2 = client.get("/second")
818+
assertEquals(HttpStatusCode.OK, response2.status)
819+
820+
// Loadtokens should be called at least twice - once for each request
821+
assertTrue(loadCount >= 2, "Expected loadTokens to be called at least twice")
822+
823+
// No additional refresh needed for second request since our cache has the valid token
824+
assertEquals(1, refreshCount, "No additional refresh needed for second request")
825+
}
826+
}
827+
828+
@Test
829+
fun testRefreshWithCachingEnabled() = testWithEngine(MockEngine) {
830+
var loadCount = 0
831+
var refreshCount = 0
832+
833+
config {
834+
install(Auth) {
835+
bearer {
836+
loadTokens {
837+
loadCount++
838+
// Always return the same token from loadTokens
839+
BearerTokens("initial-token", "refresh")
840+
}
841+
refreshTokens {
842+
refreshCount++
843+
// Return a different token from refreshTokens
844+
BearerTokens("refreshed-token-$refreshCount", "refresh")
845+
}
846+
// Enable token caching (default)
847+
cacheTokens = true
848+
}
849+
}
850+
851+
engine {
852+
addHandler { request ->
853+
val token = request.headers[HttpHeaders.Authorization]?.removePrefix("Bearer ")
854+
if (token == "initial-token") {
855+
// Initial token is rejected
856+
respond("Unauthorized", HttpStatusCode.Unauthorized,
857+
headers = headersOf(HttpHeaders.WWWAuthenticate, "Bearer"))
858+
} else if (token?.startsWith("refreshed-token") == true) {
859+
// Refreshed tokens are accepted
860+
respond("OK", HttpStatusCode.OK)
861+
} else {
862+
respond("No token", HttpStatusCode.Unauthorized)
863+
}
864+
}
865+
}
866+
}
867+
868+
test { client ->
869+
loadCount = 0
870+
refreshCount = 0
871+
872+
// First request:
873+
// 1. loadTokens called -> returns "initial-token"
874+
// 2. Request with "initial-token" gets 401
875+
// 3. refreshTokens called -> returns "refreshed-token-1"
876+
// 4. Request with "refreshed-token-1" succeeds
877+
val response1 = client.get("/")
878+
assertEquals(HttpStatusCode.OK, response1.status)
879+
assertEquals(1, loadCount)
880+
assertEquals(1, refreshCount)
881+
882+
// Second request:
883+
// Since caching is enabled, it uses the cached "refreshed-token-1" directly
884+
// No 401 response, no refresh needed
885+
val response2 = client.get("/")
886+
assertEquals(HttpStatusCode.OK, response2.status)
887+
assertEquals(1, loadCount) // Not called again because caching is enabled
888+
assertEquals(1, refreshCount) // No refresh needed
889+
}
890+
}
699891

700892
@Test
701893
fun testMultipleChallengesInHeader() = clientTests {

0 commit comments

Comments
 (0)