Skip to content

Commit

Permalink
implement token change listener (#5423)
Browse files Browse the repository at this point in the history
  • Loading branch information
samgst-amazon authored Feb 28, 2025
1 parent a6e7685 commit a74f884
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.info
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.isDeveloperMode
import software.aws.toolkits.jetbrains.services.amazonq.lsp.auth.DefaultAuthCredentialsService
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.createExtendedClientMetadata
import software.aws.toolkits.jetbrains.services.telemetry.ClientMetadata
Expand Down Expand Up @@ -303,6 +304,8 @@ private class AmazonQServerInstance(private val project: Project, private val cs
}
languageServer.initialized(InitializedParams())
}

DefaultAuthCredentialsService(project, encryptionManager, this)
}

override fun dispose() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

package software.aws.toolkits.jetbrains.services.amazonq.lsp.auth

import com.intellij.openapi.Disposable
import com.intellij.openapi.project.Project
import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProviderListener
import software.aws.toolkits.jetbrains.services.amazonq.lsp.AmazonQLspService
import software.aws.toolkits.jetbrains.services.amazonq.lsp.encryption.JwtEncryptionManager
import software.aws.toolkits.jetbrains.services.amazonq.lsp.model.aws.credentials.BearerCredentials
Expand All @@ -15,7 +21,31 @@ import java.util.concurrent.CompletableFuture
class DefaultAuthCredentialsService(
private val project: Project,
private val encryptionManager: JwtEncryptionManager,
) : AuthCredentialsService {
serverInstance: Disposable,
) : AuthCredentialsService,
BearerTokenProviderListener {
init {
project.messageBus.connect(serverInstance).subscribe(BearerTokenProviderListener.TOPIC, this)
}

override fun onChange(providerId: String, newScopes: List<String>?) {
val connection = ToolkitConnectionManager.getInstance(project)
.activeConnectionForFeature(QConnection.getInstance())
?: return

val provider = (connection.getConnectionSettings() as? TokenConnectionSettings)
?.tokenProvider
?.delegate as? BearerTokenProvider
?: return

provider.currentToken()?.accessToken?.let { token ->
updateTokenCredentials(token, false)
}
}

override fun invalidate(providerId: String) {
deleteTokenCredentials()
}

override fun updateTokenCredentials(accessToken: String, encrypted: Boolean): CompletableFuture<ResponseMessage> {
val token = if (encrypted) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@

package software.aws.toolkits.jetbrains.services.amazonq.lsp.auth

import com.intellij.openapi.Disposable
import com.intellij.openapi.components.serviceIfCreated
import com.intellij.openapi.project.Project
import com.intellij.util.messages.MessageBus
import com.intellij.util.messages.MessageBusConnection
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.runs
import io.mockk.verify
import org.eclipse.lsp4j.jsonrpc.messages.ResponseMessage
import org.junit.Before
Expand Down Expand Up @@ -42,7 +47,14 @@ class DefaultAuthCredentialsServiceTest {
func.invoke(mockLanguageServer)
}

sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager)
// Mock message bus
val messageBus = mockk<MessageBus>()
every { project.messageBus } returns messageBus
val mockConnection = mockk<MessageBusConnection>()
every { messageBus.connect(any<Disposable>()) } returns mockConnection
every { mockConnection.subscribe(any(), any()) } just runs

sut = DefaultAuthCredentialsService(project, this.mockEncryptionManager, mockk())
}

@Test
Expand Down

0 comments on commit a74f884

Please sign in to comment.